GSPO-token
In the paper Group Sequence Policy Optimization, the authors propose GSPO-token, a token-level variant of the GSPO objective. To use GSPO-token, use the GRPOTrainer class from trl.experimental.gspo_token.
Usage
```python
from trl.experimental.gspo_token import GRPOTrainer
from trl import GRPOConfig

training_args = GRPOConfig(
    importance_sampling_level="sequence_token",
    ...
)
```

To leverage GSPO-token, the user will need to provide the per-token advantage \( \hat{A}_{i,t} \) for each token \( t \) in the sequence \( i \) (i.e., make \( \hat{A}_{i,t} \) vary with \( t \), which is not the case here, where \( \hat{A}_{i,t} = \hat{A}_i \)). Otherwise, the GSPO-token gradient is equivalent to the original GSPO implementation.
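For reference, the relationship between the two objectives can be sketched as follows. This is a paraphrase of the formulation in the paper, where \( s_i(\theta) \) is the sequence-level importance ratio, \( s_{i,t}(\theta) \) its token-level counterpart, and \( \mathrm{sg}[\cdot] \) denotes stop-gradient; consult the paper for the exact notation.

$$
s_i(\theta) = \left( \frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)} \right)^{1/|y_i|},
\qquad
s_{i,t}(\theta) = \mathrm{sg}\big[ s_i(\theta) \big] \cdot \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\mathrm{sg}\big[ \pi_\theta(y_{i,t} \mid x, y_{i,<t}) \big]}
$$

$$
\mathcal{J}_{\mathrm{GSPO\text{-}token}}(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \min \Big( s_{i,t}(\theta)\, \hat{A}_{i,t},\ \mathrm{clip}\big(s_{i,t}(\theta), 1-\varepsilon, 1+\varepsilon\big)\, \hat{A}_{i,t} \Big) \right]
$$

Since \( s_{i,t}(\theta) \) has the same value and the same per-token gradient as \( s_i(\theta) \), setting \( \hat{A}_{i,t} = \hat{A}_i \) for all \( t \) makes the gradient coincide with plain GSPO, which is why per-token advantages are needed for the variant to behave differently.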
GRPOTrainer
class trl.GRPOTrainer(
    model: str | transformers.modeling_utils.PreTrainedModel,
    reward_funcs: str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]] | list[str | transformers.modeling_utils.PreTrainedModel | collections.abc.Callable[[list, list], list[float]]],
    args: trl.trainer.grpo_config.GRPOConfig | None = None,
    train_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None,
    eval_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | dict[str, datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset] | None = None,
    processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.processing_utils.ProcessorMixin | None = None,
    reward_processing_classes: transformers.tokenization_utils_base.PreTrainedTokenizerBase | list[transformers.tokenization_utils_base.PreTrainedTokenizerBase] | None = None,
    callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None,
    optimizers: tuple = (None, None),
    peft_config: PeftConfig | None = None,
    rollout_func: collections.abc.Callable[[list[str], 'GRPOTrainer'], dict[str, typing.Any]] | None = None,
)
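As an illustrative sketch only (the model name, reward function, and dataset below are placeholders, not part of the documentation above), a minimal construction consistent with this signature might look like:

```python
from datasets import Dataset
from trl import GRPOConfig
from trl.experimental.gspo_token import GRPOTrainer

# Tiny placeholder prompt dataset; GRPOTrainer expects a "prompt" column.
train_dataset = Dataset.from_dict(
    {"prompt": ["Write a haiku about autumn.", "Explain overfitting in one sentence."]}
)

# Toy reward function for illustration: longer completions score higher.
def reward_len(completions, **kwargs):
    return [float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-gspo-token-demo",           # hypothetical output directory
    importance_sampling_level="sequence_token",  # GSPO-token importance ratio
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # any causal LM checkpoint id works here
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
```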
train
train(
    resume_from_checkpoint: typing.Union[str, bool, NoneType] = None,
    trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None,
    ignore_keys_for_eval: typing.Optional[list[str]] = None,
    **kwargs: typing.Any,
)
Parameters
- resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
- kwargs (dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments.
Main training entry point.
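For instance, assuming a trainer built as in the sketch above, a first run and a later resumed run might look like this:

```python
# Train from scratch; checkpoints are written under args.output_dir.
trainer.train()

# Resume from the most recent checkpoint found in args.output_dir.
trainer.train(resume_from_checkpoint=True)
```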
save_model
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
push_to_hub
push_to_hub(
    commit_message: typing.Optional[str] = 'End of training',
    blocking: bool = True,
    token: typing.Optional[str] = None,
    revision: typing.Optional[str] = None,
    **kwargs,
)
Parameters
- commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
- blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
- token (str, optional, defaults to None) — Token with write permission to overwrite Trainer's original args.
- revision (str, optional) — The git revision to commit from. Defaults to the head of the "main" branch.
- kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
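For example, assuming hub_model_id is set in the training arguments and a token with write access is available, a call might look like this:

```python
# Push the trained model and processing class to the Hub.
# blocking=True (the default) waits until the git push has finished.
trainer.push_to_hub(commit_message="GSPO-token fine-tuned model")
```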