TRL documentation

GSPO-token

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v0.25.1).

In the paper Group Sequence Policy Optimization, the authors propose a token-level variant of the GSPO objective, called GSPO-token. To use GSPO-token, use the GRPOTrainer class from trl.experimental.gspo_token.

Usage

from trl.experimental.gspo_token import GRPOTrainer
from trl import GRPOConfig

training_args = GRPOConfig(
    # Select the GSPO-token importance-sampling level
    importance_sampling_level="sequence_token",
    ...
)

To leverage GSPO-token, you need to provide a per-token advantage $\hat{A}_{i,t}$ for each token $t$ in sequence $i$, that is, make $\hat{A}_{i,t}$ vary with $t$. This is not the case here, where $\hat{A}_{i,t} = \hat{A}_{i}$ for every token. Without per-token advantages, the GSPO-token gradient is equivalent to that of the original GSPO implementation.
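The equivalence can be checked numerically. The following is a minimal sketch, not the TRL implementation: it assumes the unclipped objectives from the paper, models the stop-gradient by passing a frozen copy of the log-probabilities, and uses finite differences instead of autograd. All function names and the toy values are hypothetical.

```python
import math

def seq_ratio(logps, old_logps):
    """Length-normalized sequence importance ratio s_i (GSPO paper definition)."""
    return math.exp(sum(lp - olp for lp, olp in zip(logps, old_logps)) / len(logps))

def gspo_objective(logps, old_logps, advantage):
    """Unclipped sequence-level term: A_i * s_i (clipping omitted for brevity)."""
    return advantage * seq_ratio(logps, old_logps)

def gspo_token_objective(logps, frozen_logps, old_logps, advantages):
    """Unclipped token-level term: mean over t of A_{i,t} * s_{i,t}.

    `frozen_logps` stands in for the stop-gradient copy: it stays fixed
    while `logps` is perturbed, so s_{i,t} = sg[s_i] * pi_t / sg[pi_t].
    """
    s_frozen = seq_ratio(frozen_logps, old_logps)
    terms = [a * s_frozen * math.exp(lp - flp)
             for a, lp, flp in zip(advantages, logps, frozen_logps)]
    return sum(terms) / len(terms)

def grad(f, x, eps=1e-6):
    """Central finite-difference gradient of f at the list x."""
    g = []
    for k in range(len(x)):
        up = x[:k] + [x[k] + eps] + x[k + 1:]
        down = x[:k] + [x[k] - eps] + x[k + 1:]
        g.append((f(up) - f(down)) / (2 * eps))
    return g

# Toy per-token log-probabilities for one completion (made-up values).
logps = [-1.2, -0.7, -2.1]
old_logps = [-1.0, -0.9, -2.0]
A = 0.5

g_seq = grad(lambda lp: gspo_objective(lp, old_logps, A), logps)
# Same advantage on every token: gradient matches sequence-level GSPO.
g_tok_const = grad(lambda lp: gspo_token_objective(lp, logps, old_logps, [A] * 3), logps)
# Advantage varies with t: gradient differs token by token.
g_tok_vary = grad(lambda lp: gspo_token_objective(lp, logps, old_logps, [1.0, 0.5, -0.25]), logps)

print(g_seq, g_tok_const, g_tok_vary)
```

With a constant advantage the two gradients agree up to finite-difference error; only a token-dependent advantage changes the update.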

GRPOTrainer

class trl.GRPOTrainer

(
    model: str | transformers.modeling_utils.PreTrainedModel
    reward_funcs: str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]] | list[str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]]]
    args: trl.trainer.grpo_config.GRPOConfig | None = None
    train_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None
    eval_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | dict[str, datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset] | None = None
    processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.processing_utils.ProcessorMixin | None = None
    reward_processing_classes: transformers.tokenization_utils_base.PreTrainedTokenizerBase | list[transformers.tokenization_utils_base.PreTrainedTokenizerBase] | None = None
    callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None
    optimizers: tuple = (None, None)
    peft_config: PeftConfig | None = None
    rollout_func: collections.abc.Callable[[list[str], 'GRPOTrainer'], dict[str, typing.Any]] | None = None
)
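The reward_funcs argument accepts a callable that takes two lists and returns a list of floats. A minimal sketch of such a callable follows, assuming the two lists are the prompts and the completions, that completions arrive as plain strings, and that extra keyword arguments may be passed and can be ignored. The name length_reward and the target length are hypothetical.

```python
TARGET_LEN = 20  # hypothetical target length, in characters

def length_reward(prompts, completions, **kwargs):
    """Return one float per completion: 0.0 at TARGET_LEN characters, lower otherwise."""
    return [-float(abs(TARGET_LEN - len(completion))) for completion in completions]

# Stand-alone check with toy data; no trainer is needed to exercise the function.
prompts = ["Say hi.", "Say hi."]
completions = ["Hello there!", "Hi"]
rewards = length_reward(prompts=prompts, completions=completions)
print(rewards)
```

A function of this shape could then be passed as reward_funcs=length_reward, or inside a list alongside other reward functions or reward models.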

train

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs: typing.Any )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments.

Main training entry point.
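When resume_from_checkpoint is True, the last checkpoint in args.output_dir is loaded. The sketch below illustrates what "last" means under the assumption that checkpoints are saved in subdirectories named checkpoint-<step>. It is a stand-alone illustration using only the standard library, not the Trainer's own lookup code, and find_last_checkpoint is a hypothetical helper.

```python
import os
import re
import tempfile

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def find_last_checkpoint(output_dir):
    """Return the path of the highest-step `checkpoint-<step>` subdirectory, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        # Compare steps numerically: a string sort would rank "checkpoint-90" last.
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path

# Demonstrate on a throwaway directory with three fake checkpoints.
with tempfile.TemporaryDirectory() as output_dir:
    for step in (500, 1000, 90):
        os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))
    last = find_last_checkpoint(output_dir)
    last_name = os.path.basename(last)

print(last_name)
```

A path found this way can also be passed explicitly as a string, as in trainer.train(resume_from_checkpoint=last), which is the str form of the parameter described above.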

save_model

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
