Post-Training Toolkit Integration

Post-Training Toolkit is a diagnostic and observability layer for RLHF training runs. Add one callback to any TRL trainer and get auto-metrics, crash postmortems, and literature-backed heuristics—without writing glue code.

It was built to operationalize the debugging patterns we found most useful when running post-training at scale.

Usage

  1. First, install Post-Training Toolkit:
pip install post-training-toolkit
  2. Add one callback to your trainer. That’s it! The same pattern works for the DPO, PPO, SFT, ORPO, KTO, CPO, and GRPO trainers; the DPO example is shown below.
from post_training_toolkit import DiagnosticsCallback
from trl import DPOTrainer

trainer = DPOTrainer(
    model=model,
    args=training_args,
    callbacks=[DiagnosticsCallback()],  # ← Just add this
    # ... your other DPOTrainer arguments
)
trainer.train()

What You Get

Example output:

[HIGH] DPO loss stuck at ~0.693 (random chance). Model may not be learning preferences.
       Ref: Rafailov et al. (2023) 'DPO', Section 4.2

[RECOMMENDED] Increase learning rate 2-5x, check data quality, or reduce beta.
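The reasoning behind that warning is simple: the DPO loss starts at ln(2) ≈ 0.693 while the policy still agrees with the reference model, so a loss that never moves away from that value suggests the model is not learning preferences. Below is a minimal sketch of such a check, for illustration only; the function name, window size, and tolerance are assumptions, not the toolkit's API:

import math

# Illustrative sketch, not the toolkit's implementation: flag a DPO run whose
# loss has stayed at chance level (ln 2 ≈ 0.693) for the whole recent window.
def dpo_loss_is_stuck(recent_losses, tolerance=0.02, min_steps=100):
    if len(recent_losses) < min_steps:
        return False  # not enough history to judge
    chance_level = math.log(2)
    return all(abs(loss - chance_level) < tolerance for loss in recent_losses[-min_steps:])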

Example Demo

See a full working example with auto-stop in action:

📂 demo/live_demo.ipynb

📂 demo/scripts/custom_heuristic.py

1. Auto-Metrics

The callback automatically captures the algorithm-specific metrics for each supported trainer, chosen to match the quantities emphasized in the RLHF literature and in common practice; a sketch after the table shows how such metrics surface through the standard logging hook:

Trainer   Key Metrics Captured
DPO       loss, win_rate, reward_margin, logps_chosen/rejected
PPO       policy_loss, value_loss, entropy, clip_fraction, KL
GRPO      group rewards, advantages, policy loss, KL
SFT       loss, perplexity, accuracy
ORPO      sft_loss, odds_ratio_loss, log_odds_ratio
KTO       kl, logps for desirable/undesirable
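These values reach the callback through the standard transformers logging hook that every TRL trainer calls. A minimal sketch of reading them yourself follows; the exact log keys vary by trainer and TRL version, and the DPO keys shown are examples rather than a guaranteed schema:

from transformers import TrainerCallback

# Illustrative sketch: TRL trainers report their metrics via `on_log`, so a
# callback can inspect the `logs` dict. Key names differ per trainer/version.
class MetricWatcher(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        margin = logs.get("rewards/margins")       # DPO reward margin (example key)
        win_rate = logs.get("rewards/accuracies")  # DPO win rate (example key)
        if margin is not None and win_rate is not None:
            print(f"step {state.global_step}: margin={margin:.3f}, win_rate={win_rate:.3f}")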

2. Crash Postmortems

If training crashes or gets interrupted, you get a postmortem.json with full context:

{
  "exit_reason": "exception",
  "last_step": 847,
  "timestamp": "2025-12-17T19:26:04Z",
  "final_metrics": {"dpo_loss": 0.693, "win_rate": 0.52}
}

No more “what step did it die on?”
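For a sense of how a record like this can be produced, here is a minimal sketch that wraps trainer.train() and dumps a JSON file on any exception or interrupt. The toolkit handles this for you; the snippet only illustrates the idea, and the field set is an assumption rather than its exact schema:

import json
import time

# Illustrative sketch, not the toolkit's implementation: write a postmortem
# if training raises or is interrupted, then re-raise so the error is visible.
def train_with_postmortem(trainer, path="postmortem.json"):
    try:
        trainer.train()
    except BaseException as err:  # includes KeyboardInterrupt
        record = {
            "exit_reason": "exception",
            "error": repr(err),
            "last_step": trainer.state.global_step,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "final_metrics": trainer.state.log_history[-1] if trainer.state.log_history else {},
        }
        with open(path, "w") as f:
            json.dump(record, f, indent=2)
        raise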

3. Auto-Stop on Critical Issues

Enable automatic training termination when critical issues are detected:

callback = DiagnosticsCallback(stop_on_critical=True)
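This kind of early stop is typically built on the standard transformers control flow: a callback sets control.should_training_stop and the Trainer exits the loop cleanly at the end of the current step. A minimal sketch with a stand-in NaN-loss check (the real critical heuristics belong to the toolkit; this check is only an example):

import math
from transformers import TrainerCallback

# Illustrative sketch: a callback can request a clean stop by setting
# `control.should_training_stop`. The NaN-loss check stands in for the
# toolkit's built-in critical heuristics.
class StopOnNaNLoss(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss")
        if loss is not None and math.isnan(loss):
            control.should_training_stop = True
        return control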

Distributed Training

Works automatically with multi-GPU setups. Zero configuration needed:

accelerate launch --num_processes 8 train.py

Automatically detects stragglers, aggregates metrics across ranks, and tracks memory balance.
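Cross-rank aggregation of this kind boils down to an all-reduce over per-process scalars. A minimal sketch of the general mechanism (not the toolkit's code), assuming the process group that accelerate launch initializes for you:

import torch
import torch.distributed as dist

# Illustrative sketch: average a per-rank scalar (step time, peak memory, ...)
# across processes so imbalance between ranks shows up in one number.
def mean_across_ranks(value: float) -> float:
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process run: nothing to aggregate
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()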
