TRL documentation


GFPO

This feature implements the GFPO algorithm, which enforces concise reasoning in the model's generations, as proposed in the paper Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning.

Usage

To activate GFPO in GFPOTrainer:

  • set num_remains_in_group in GFPOConfig
  • define a group filter function and pass it as group_filter_func to GFPOTrainer. group_filter_func scores the num_generations completions in each group; the GFPOTrainer then keeps the top num_remains_in_group completions according to these scores as the new group, and the model is trained only on this filtered group.
# train_gfpo.py
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer

# dummy group filter that scores each completion by its index in the group
class GroupFilter:
    def __call__(self, group_completions, group_rewards, **kwargs):
        group_scores = []
        for completions, rewards in zip(group_completions, group_rewards):
            scores = [float(i) for i in range(len(completions))]
            group_scores.append(scores)
        return group_scores

training_args = GFPOConfig(
    output_dir="Qwen3-0.6B-GFPO",
    per_device_train_batch_size=4,
    num_remains_in_group=2,
    bf16=True,
)
trainer = GFPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=...,
    train_dataset=...,
    args=training_args,
    group_filter_func=GroupFilter(),
)
trainer.train()
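
Because GFPO is designed to shorten reasoning, a more realistic filter scores completions by their length (the paper also considers metrics such as reward per token). Below is a minimal sketch of a length-based filter; the class name and the character-count heuristic are illustrative, not part of the TRL API:

# Length-based group filter: shorter completions receive higher scores,
# so only the most concise responses in each group are kept for training.
class ShortestCompletionFilter:
    def __call__(self, group_completions, group_rewards, **kwargs):
        group_scores = []
        for completions, rewards in zip(group_completions, group_rewards):
            # Score each completion by its negative length in characters;
            # token counts or reward-per-token would work the same way.
            scores = [-float(len(str(completion))) for completion in completions]
            group_scores.append(scores)
        return group_scores

An instance of this class is passed exactly like the dummy filter above, via group_filter_func=ShortestCompletionFilter().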

GFPOTrainer

class trl.experimental.gfpo.GFPOTrainer

( model reward_funcs args = None train_dataset = None eval_dataset = None processing_class = None reward_processing_classes = None group_filter_func = None callbacks = None optimizers = (None, None) peft_config = None )

train

( resume_from_checkpoint: str | bool | None = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: list[str] | None = None )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

Main training entry point.
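
For example, to resume an interrupted GFPO run from the most recent checkpoint in args.output_dir (the explicit checkpoint path below is illustrative):

# Resume training from the last checkpoint saved under args.output_dir
trainer.train(resume_from_checkpoint=True)

# Or resume from a specific checkpoint directory
trainer.train(resume_from_checkpoint="Qwen3-0.6B-GFPO/checkpoint-500")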

save_model

( output_dir: str | None = None _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.
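
For instance, after training you can write the final weights to a directory and later reload them with from_pretrained() (the directory name is illustrative):

# Save the trained model (defaults to args.output_dir when no path is given)
trainer.save_model("Qwen3-0.6B-GFPO")

# Reload it later as a regular transformers model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen3-0.6B-GFPO")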

push_to_hub

( commit_message: str | None = 'End of training' blocking: bool = True token: str | None = None revision: str | None = None **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
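
For example, assuming args.hub_model_id points at a repository you can write to, a typical call after training looks like this (the commit message is illustrative):

# Push the trained model and processing class to the Hub repo set in args.hub_model_id
trainer.push_to_hub(commit_message="GFPO fine-tune of Qwen3-0.6B")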

GFPOConfig

class trl.experimental.gfpo.GFPOConfig

( output_dir: str | None = None do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 gradient_accumulation_steps: int = 1 eval_accumulation_steps: int | None = None eval_delay: float = 0 torch_empty_cache_steps: int | None = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict | str | None = None warmup_ratio: float | None = None warmup_steps: float = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: str | None = None logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_total_limit: int | None = None enable_jit_checkpoint: bool = False save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False use_cpu: bool = False seed: int = 42 data_seed: int | None = None bf16: bool | None = None fp16: bool = False bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: bool | None = None local_rank: int = -1 ddp_backend: str | None = None debug: str | list[transformers.debug_utils.DebugOption] = '' dataloader_drop_last: bool = False eval_steps: float | None = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: int | None = None run_name: str | None = None disable_tqdm: bool | None = None remove_unused_columns: bool | None = False label_names: list[str] | None = None load_best_model_at_end: bool = False metric_for_best_model: str | None = None greater_is_better: bool | None = None ignore_data_skip: bool = False fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None fsdp_config: dict[str, typing.Any] | str | None = None accelerator_config: dict | str | None = None parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None deepspeed: dict | str | None = None label_smoothing_factor: float = 0.0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused' optim_args: str | None = None group_by_length: bool = False length_column_name: str = 'length' report_to: None | str | list[str] = 'none' project: str = 'huggingface' trackio_space_id: str | None = 'trackio' ddp_find_unused_parameters: bool | None = None ddp_bucket_cap_mb: int | None = None ddp_broadcast_buffers: bool | None = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True push_to_hub: bool = False resume_from_checkpoint: str | None = None hub_model_id: str | None = None hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save' hub_token: str | None = None hub_private_repo: bool | None = None hub_always_push: bool = False hub_revision: str | None = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None include_for_metrics: list = <factory> eval_do_concat_batches: bool = True auto_find_batch_size: bool = False full_determinism: bool = False ddp_timeout: int = 1800
torch_compile: bool = False torch_compile_backend: str | None = None torch_compile_mode: str | None = None include_num_input_tokens_seen: str | bool = 'no' neftune_noise_alpha: float | None = None optim_target_modules: None | str | list[str] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: bool = False liger_kernel_config: dict[str, bool] | None = None eval_use_gather_object: bool = False average_tokens_across_devices: bool = True use_cache: bool = False model_init_kwargs: dict | str | None = None disable_dropout: bool = False cast_lm_head_to_fp32: bool = False num_generations: int | None = 8 num_generations_eval: int | None = None max_completion_length: int | None = 256 ds3_gather_for_generation: bool = True shuffle_dataset: bool | None = True generation_batch_size: int | None = None steps_per_generation: int | None = None temperature: float = 1.0 top_p: float = 1.0 top_k: int = 0 min_p: float | None = None generation_kwargs: dict | None = None chat_template_kwargs: dict | None = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: str | None = None use_vllm: bool = False vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_enable_sleep_mode: bool = False vllm_structured_outputs_regex: str | None = None vllm_server_base_url: str | None = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_group_port: int = 51216 vllm_gpu_memory_utilization: float = 0.3 vllm_max_model_length: int | None = None vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: float | None = None epsilon_high: float | None = None sapo_temperature_neg: float = 1.05 sapo_temperature_pos: float = 1.0 importance_sampling_level: str = 'token' reward_weights: list[float] | None = None multi_objective_aggregation: str = 'sum_then_normalize' scale_rewards: str = 'group' loss_type: str = 'dapo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 max_tool_calling_iterations: int | None = None vllm_importance_sampling_correction: bool = True vllm_importance_sampling_mode: str = 'sequence_mask' vllm_importance_sampling_cap: float = 3.0 off_policy_mask_threshold: float | None = None use_bias_correction_kl: bool = False log_completions: bool = False num_completions_to_print: int | None = None log_unique_prompts: bool = False log_completions_hub_repo: str | None = None num_remains_in_group: int | None = None )
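
Most of these fields are inherited from the underlying GRPO and transformers training arguments; the GFPO-specific addition is num_remains_in_group, which should be smaller than num_generations so that filtering actually discards completions. A slightly fuller configuration sketch (the values are illustrative, not tuned recommendations):

from trl.experimental.gfpo import GFPOConfig

training_args = GFPOConfig(
    output_dir="Qwen3-0.6B-GFPO",
    per_device_train_batch_size=4,
    num_generations=8,           # completions sampled per prompt ("sample more")
    num_remains_in_group=2,      # completions kept per group after filtering ("think less")
    max_completion_length=256,   # cap on generated tokens per completion
    bf16=True,
)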
