BEMA for Reference Model
This feature implements the BEMA algorithm to update the reference model during DPO training.
Usage
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer

# Load a small preference dataset and a tiny model/tokenizer pair.
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token

# With update_ref_model=True, the callback maintains BEMA weights of the policy
# and periodically copies them into the reference model.
bema_callback = BEMACallback(update_ref_model=True)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    train_dataset=pref_dataset,
    processing_class=tokenizer,
    callbacks=[bema_callback],
)
trainer.train()
```

DPOTrainer
train
( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None, trial: typing.Union['optuna.Trial', dict[str, typing.Any], NoneType] = None, ignore_keys_for_eval: typing.Optional[list[str]] = None, **kwargs: typing.Any )
Parameters
- resume_from_checkpoint (`str` or `bool`, optional) — If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in `args.output_dir` as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (`optuna.Trial` or `dict[str, Any]`, optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (`list[str]`, optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
- kwargs (`dict[str, Any]`, optional) — Additional keyword arguments used to hide deprecated arguments.
Main training entry point.
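For example, assuming a previous run already wrote checkpoints to `args.output_dir`, training can be resumed in place with the `trainer` built in the usage example above (a minimal sketch):

```python
# Resume from the last checkpoint found in args.output_dir.
trainer.train(resume_from_checkpoint=True)
```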
save_model

( output_dir: typing.Optional[str] = None, _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.
push_to_hub
( commit_message: typing.Optional[str] = 'End of training', blocking: bool = True, token: typing.Optional[str] = None, revision: typing.Optional[str] = None, **kwargs )
Parameters
- commit_message (`str`, optional, defaults to `"End of training"`) — Message to commit while pushing.
- blocking (`bool`, optional, defaults to `True`) — Whether the function should return only when the `git push` has finished.
- token (`str`, optional, defaults to `None`) — Token with write permission to overwrite Trainer's original args.
- revision (`str`, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
- kwargs (`dict[str, Any]`, optional) — Additional keyword arguments passed along to `~Trainer.create_model_card`.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
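As a minimal sketch (assuming `hub_model_id` and a valid write token are already configured in the training arguments; the commit message below is illustrative):

```python
# Uploads self.model and self.processing_class to the repo set in args.hub_model_id.
trainer.push_to_hub(commit_message="Add BEMA-reference DPO model")
```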
BEMACallback
class trl.BEMACallback
( update_freq: int = 400, ema_power: float = 0.5, bias_power: float = 0.2, lag: int = 10, update_after: int = 0, multiplier: float = 1.0, min_ema_multiplier: float = 0.0, device: str = 'cpu', update_ref_model: bool = False, ref_model_update_freq: int = 400, ref_model_update_after: int = 0 )
Parameters
- update_freq (
int, optional, defaults to400) — Update the BEMA weights every X steps. Denoted this as {@html "ϕ"} in the paper. - ema_power (
float, optional, defaults to0.5) — Power for the EMA decay factor. Denoted {@html "κ"} in the paper. To disable EMA, set this to0.0. - bias_power (
float, optional, defaults to0.2) — Power for the BEMA scaling factor. Denoted {@html "η"} in the paper. To disable BEMA, set this to0.0. - lag (
int, optional, defaults to10) — Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual starting age for the updates. Denoted as {@html "ρ"} in the paper. - update_after (
int, optional, defaults to0) — Burn-in time before starting to update the BEMA weights. Denoted {@html "τ"} in the paper. - multiplier (
float, optional, defaults to1.0) — Initial value for the EMA decay factor. Denoted as {@html "γ"} in the paper. - min_ema_multiplier (
float, optional, defaults to0.0) — Minimum value for the EMA decay factor. - device (
str, optional, defaults to"cpu") — Device to use for the BEMA buffers, e.g."cpu"or"cuda". Note that in most cases, this device SHOULD BE DIFFERENT from the device used for training in order to avoid OOM. - update_ref_model (
bool, optional, defaults toFalse) — Whether to update the reference model with BEMA weights. This creates a lagged, smoothed version of the main model as the reference model. - ref_model_update_freq (
int, optional, defaults to400) — Update the reference model with BEMA weights every this many steps. - ref_model_update_after (
int, optional, defaults to0) — Number of steps to wait before starting to update the reference model.
A TrainerCallback that implements BEMA (Bias-Corrected Exponential Moving Average) by Adam Block and Cyril Zhang. Code from https://github.com/abblock/bema under MIT license.
BEMA computes model weights that scale like

$$
\theta^{\text{BEMA}}_t = \alpha_t \, (\theta_t - \theta_0) + \text{EMA}_t
$$

where \( \theta_t \) is the current model weights, \( \theta_0 \) is a snapshot of the model weights at the first `update_after` step, \( \text{EMA}_t \) is the exponential moving average of the model weights, and \( \alpha_t \) is a scaling factor that decays with the number of steps \( t \) as

$$
\alpha_t = (\rho + \gamma t)^{-\eta}.
$$

The EMA is computed as

$$
\text{EMA}_t = (1 - \beta_t) \, \text{EMA}_{t-1} + \beta_t \, \theta_t,
$$

where \( \beta_t \) is a decay factor that decays with the number of steps as

$$
\beta_t = (\rho + \gamma t)^{-\kappa}.
$$

Here \( \rho \) is `lag`, \( \gamma \) is `multiplier`, \( \eta \) is `bias_power`, and \( \kappa \) is `ema_power`.
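To make the schedule concrete, here is a minimal, illustrative sketch of the update rule above for a single parameter tensor, written in plain PyTorch with the callback's default hyperparameter values; the variable names are ours and do not reflect `BEMACallback`'s internals:

```python
import torch

# Hyperparameters mirroring the BEMACallback defaults (names here are illustrative).
ema_power, bias_power, lag, multiplier = 0.5, 0.2, 10, 1.0

theta_0 = torch.randn(4)   # snapshot of the weights at the first `update_after` step
ema = theta_0.clone()      # EMA buffer, initialized from the snapshot
theta_t = theta_0.clone()  # current model weights

for t in range(1, 401):
    theta_t = theta_t - 0.01 * torch.randn(4)          # stand-in for an optimizer step
    beta_t = (lag + multiplier * t) ** (-ema_power)    # EMA decay factor
    alpha_t = (lag + multiplier * t) ** (-bias_power)  # bias-correction scaling
    ema = (1 - beta_t) * ema + beta_t * theta_t        # EMA_t
    bema = alpha_t * (theta_t - theta_0) + ema         # BEMA weights at step t
```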