TRL documentation

BEMA for Reference Model


This experimental feature implements the BEMA (Bias-Corrected Exponential Moving Average) algorithm to periodically update the reference model during DPO training.

Usage

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Load a small preference dataset and the policy/reference model pair
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token

# update_ref_model=True makes the callback refresh the reference model with BEMA weights
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    train_dataset=pref_dataset,
    processing_class=tokenizer,
    callbacks=[bema_callback],
)

trainer.train()

DPOTrainer

class trl.DPOTrainer


( *args **kwargs )

train


( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None, trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None, ignore_keys_for_eval: typing.Optional[list[str]] = None, **kwargs: typing.Any )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments.

Main training entry point.

save_model


( output_dir: typing.Optional[str] = None, _internal_call: bool = False )

Saves the model so that it can be reloaded with from_pretrained().

Will only save from the main process.

push_to_hub


( commit_message: typing.Optional[str] = 'End of training', blocking: bool = True, token: typing.Optional[str] = None, revision: typing.Optional[str] = None, **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

BEMACallback

class trl.BEMACallback


( update_freq: int = 400, ema_power: float = 0.5, bias_power: float = 0.2, lag: int = 10, update_after: int = 0, multiplier: float = 1.0, min_ema_multiplier: float = 0.0, device: str = 'cpu', update_ref_model: bool = False, ref_model_update_freq: int = 400, ref_model_update_after: int = 0 )

Parameters

  • update_freq (int, optional, defaults to 400) — Update the BEMA weights every update_freq steps. Denoted ϕ in the paper.
  • ema_power (float, optional, defaults to 0.5) — Power for the EMA decay factor. Denoted κ in the paper. To disable EMA, set this to 0.0.
  • bias_power (float, optional, defaults to 0.2) — Power for the BEMA scaling factor. Denoted η in the paper. To disable BEMA, set this to 0.0.
  • lag (int, optional, defaults to 10) — Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual starting age for the updates. Denoted ρ in the paper.
  • update_after (int, optional, defaults to 0) — Burn-in period before the BEMA weights start being updated. Denoted τ in the paper.
  • multiplier (float, optional, defaults to 1.0) — Initial value for the EMA decay factor. Denoted γ in the paper.
  • min_ema_multiplier (float, optional, defaults to 0.0) — Minimum value for the EMA decay factor.
  • device (str, optional, defaults to "cpu") — Device to use for the BEMA buffers, e.g. "cpu" or "cuda". In most cases this device should be different from the training device in order to avoid OOM.
  • update_ref_model (bool, optional, defaults to False) — Whether to update the reference model with BEMA weights. This creates a lagged, smoothed version of the main model as the reference model.
  • ref_model_update_freq (int, optional, defaults to 400) — Update the reference model with BEMA weights every ref_model_update_freq steps.
  • ref_model_update_after (int, optional, defaults to 0) — Number of steps to wait before starting to update the reference model.

A TrainerCallback that implements BEMA (Bias-Corrected Exponential Moving Average) by Adam Block and Cyril Zhang. Code adapted from https://github.com/abblock/bema (MIT license).

BEMA computes model weights that scale like:

$$\theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \mathrm{EMA}_t$$

where $\theta_t$ is the current model weights, $\theta_0$ is a snapshot of the model weights at the first update_after step, $\mathrm{EMA}_t$ is the exponential moving average of the model weights, and $\alpha_t$ is a scaling factor that decays with the number of steps $t$ as

$$\alpha_t = (\rho + \gamma \cdot t)^{-\eta}.$$

The EMA is computed as:

$$\mathrm{EMA}_t = (1 - \beta_t) \cdot \mathrm{EMA}_{t-1} + \beta_t \cdot \theta_t$$

where $\beta_t$ is a decay factor that decays with the number of steps $t$ as

$$\beta_t = (\rho + \gamma \cdot t)^{-\kappa}.$$
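
The update rules above can be sketched with scalar "weights" in plain Python. This is illustrative only: the real callback applies the same arithmetic to full parameter tensors, and the helper names here are not TRL internals.

```python
def alpha(t, rho=10, gamma=1.0, eta=0.2):
    # Bias-correction scaling factor: alpha_t = (rho + gamma * t) ** (-eta)
    return (rho + gamma * t) ** (-eta)

def beta(t, rho=10, gamma=1.0, kappa=0.5):
    # EMA decay factor: beta_t = (rho + gamma * t) ** (-kappa)
    return (rho + gamma * t) ** (-kappa)

def bema_step(theta_t, theta0, ema_prev, t):
    # EMA_t = (1 - beta_t) * EMA_{t-1} + beta_t * theta_t
    ema_t = (1 - beta(t)) * ema_prev + beta(t) * theta_t
    # theta_t' = alpha_t * (theta_t - theta0) + EMA_t
    return alpha(t) * (theta_t - theta0) + ema_t, ema_t

# Toy trajectory: a "weight" drifting linearly away from its snapshot theta0
theta0 = 1.0
ema = theta0
for t in range(1, 6):
    theta_t = 1.0 + 0.1 * t             # current weight at step t
    bema_t, ema = bema_step(theta_t, theta0, ema, t)
```

Because the EMA lags the drifting weight, the bias-correction term $\alpha_t (\theta_t - \theta_0)$ pushes the BEMA estimate back toward the current iterate.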

Example:

from transformers import Trainer
from trl import BEMACallback

trainer = Trainer(..., callbacks=[BEMACallback()])