Alignment Losses
xaytune supports six alignment methods. Each has a dedicated loss function and can be selected via create_alignment_loss_fn().
| Method | Function | Paper |
|---|---|---|
| DPO | dpo_loss | Rafailov et al., 2023 |
| SimPO | simpo_loss | Meng et al., 2024 |
| ORPO | orpo_loss | Hong et al., 2024 |
| GRPO | grpo_loss | Shao et al., 2024 |
| PPO | ppo_clip_loss | Schulman et al., 2017 |
| REINFORCE | reinforce_loss | Williams, 1992 |
Loss Dispatch
create_alignment_loss_fn(*, method, ref_model=None, beta=0.1, kl_coeff=0.04, lambda_weight=1.0, gamma=0.5, clip_eps=0.2)
Create a loss function for the given alignment method.
Returns a callable (model, batch, outputs) -> loss that handles
forward passes on chosen/rejected pairs and reference model inference.
Source code in xaytune/recipes/align/loss_dispatch.py
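The factory pattern described above can be illustrated with a minimal, hypothetical sketch. It is not the xaytune implementation: `make_loss_fn_sketch`, the single-entry dispatch table, and the toy "models" (plain callables that map a sequence to a sequence log-probability) are all assumptions made for illustration.

```python
import math
from typing import Callable, Dict


def make_loss_fn_sketch(*, method: str, ref_model=None, beta: float = 0.1) -> Callable:
    """Return a callable (model, batch, outputs) -> loss for `method` (hypothetical)."""

    def dpo(model, batch, outputs=None):
        if ref_model is None:
            raise ValueError("dpo needs a reference model")
        # Forward passes on the chosen/rejected pair, for policy and reference.
        chosen_gain = model(batch["chosen"]) - ref_model(batch["chosen"])
        rejected_gain = model(batch["rejected"]) - ref_model(batch["rejected"])
        margin = chosen_gain - rejected_gain
        # -log(sigmoid(beta * margin)) written as log(1 + exp(-beta * margin)).
        return math.log1p(math.exp(-beta * margin))

    # Only one entry here; a real dispatcher would register every method.
    table: Dict[str, Callable] = {"dpo": dpo}
    if method not in table:
        raise ValueError(f"unknown alignment method: {method!r}")
    return table[method]
```

The hyperparameters are bound once at construction time, so the training loop only ever sees the uniform `(model, batch, outputs)` signature.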
is_alignment_method(method)
DPO
dpo_loss(*, policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1)
Compute Direct Preference Optimization loss (Rafailov et al., 2023).
Source code in xaytune/recipes/align/dpo.py
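For orientation, the DPO objective from the paper is the negative log-sigmoid of the scaled difference between the policy's and the reference's chosen-vs-rejected log-probability gaps. The sketch below is a pure-Python illustration of that formula with the same keyword names as the signature above. `dpo_loss_sketch` and `log_sigmoid` are hypothetical helpers, and mean reduction over the batch is an assumption.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss_sketch(*, policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Mean DPO loss over a batch of sequence-level log-probabilities."""
    losses = []
    for pc, pr, rc, rr in zip(policy_chosen_logps, policy_rejected_logps,
                              ref_chosen_logps, ref_rejected_logps):
        # How much more the policy prefers chosen over rejected than the reference does.
        margin = (pc - rc) - (pr - rr)
        losses.append(-log_sigmoid(beta * margin))
    return sum(losses) / len(losses)  # assumption: mean reduction
```

When policy and reference agree, the margin is zero and the loss is log 2; it falls below that as the policy shifts probability toward the chosen response.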
SimPO
simpo_loss(*, policy_chosen_logps, policy_rejected_logps, chosen_lengths, rejected_lengths, beta=2.0, gamma=0.5)
Compute Simple Preference Optimization loss (Meng et al., 2024).
Source code in xaytune/recipes/align/simpo.py
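SimPO drops the reference model and compares length-normalized log-probabilities against a target margin. The following pure-Python sketch follows the formula in the paper. `simpo_loss_sketch` is a hypothetical name, and both the mean reduction and the placement of `gamma` outside the `beta` scaling are assumptions about how the parameters are used.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def simpo_loss_sketch(*, policy_chosen_logps, policy_rejected_logps,
                      chosen_lengths, rejected_lengths, beta=2.0, gamma=0.5):
    """Mean SimPO loss using average per-token log-probabilities."""
    losses = []
    for pc, pr, lc, lr in zip(policy_chosen_logps, policy_rejected_logps,
                              chosen_lengths, rejected_lengths):
        # Length normalization: sequence log-prob divided by token count.
        scaled_gap = beta * (pc / lc - pr / lr)
        # gamma is the target reward margin the gap has to exceed.
        losses.append(-log_sigmoid(scaled_gap - gamma))
    return sum(losses) / len(losses)  # assumption: mean reduction
```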
ORPO
orpo_loss(*, sft_loss, policy_chosen_logps, policy_rejected_logps, lambda_weight=1.0)
Compute Odds Ratio Preference Optimization loss (Hong et al., 2024).
Source code in xaytune/recipes/align/orpo.py
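ORPO adds an odds-ratio penalty to the ordinary SFT loss. The sketch below illustrates the formula from the paper in pure Python. `orpo_loss_sketch` and `log_odds` are hypothetical helpers, and the sketch assumes the log-probabilities are strictly negative (for example, averaged per token) so that the odds are well defined.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def log_odds(logp: float) -> float:
    """log(p / (1 - p)) computed from log p; requires logp < 0."""
    return logp - math.log1p(-math.exp(logp))


def orpo_loss_sketch(*, sft_loss, policy_chosen_logps, policy_rejected_logps,
                     lambda_weight=1.0):
    """SFT loss plus a weighted odds-ratio penalty, averaged over the batch."""
    penalties = []
    for pc, pr in zip(policy_chosen_logps, policy_rejected_logps):
        log_odds_ratio = log_odds(pc) - log_odds(pr)
        penalties.append(-log_sigmoid(log_odds_ratio))
    odds_ratio_loss = sum(penalties) / len(penalties)  # assumption: mean reduction
    return sft_loss + lambda_weight * odds_ratio_loss
```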
GRPO
grpo_loss(*, logprobs, ref_logprobs, advantages, kl_coeff=0.04)
Compute Group Relative Policy Optimization loss (Shao et al., 2024).
Source code in xaytune/recipes/align/grpo.py
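Given that the signature takes no old-policy log-probabilities, one plausible reading is an on-policy policy-gradient term plus a KL penalty toward the reference model. The sketch below illustrates that reading with the KL estimator used in Shao et al. (2024). It is an assumption, not the library's code, and `grpo_loss_sketch` is a hypothetical name.

```python
import math


def grpo_loss_sketch(*, logprobs, ref_logprobs, advantages, kl_coeff=0.04):
    """Mean of (-advantage * logprob + kl_coeff * KL estimate) over samples."""
    total = 0.0
    for lp, ref_lp, adv in zip(logprobs, ref_logprobs, advantages):
        # Policy-gradient surrogate (assumes a single on-policy update, ratio = 1).
        policy_term = -adv * lp
        # Non-negative KL estimator: r - log(r) - 1 with r = pi_ref / pi_theta.
        log_ratio = ref_lp - lp
        kl = math.exp(log_ratio) - log_ratio - 1.0
        total += policy_term + kl_coeff * kl
    return total / len(logprobs)
```

The KL term is zero when the policy matches the reference and grows as the two diverge, which is what `kl_coeff` trades off against the advantage-weighted term.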
compute_group_advantages(rewards)
Normalize rewards to zero-mean unit-variance advantages.
Source code in xaytune/recipes/align/grpo.py
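The normalization can be sketched in a few lines of standard-library Python. `group_advantages_sketch` is a hypothetical name, and the use of the population standard deviation and the `eps` guard against zero variance are assumptions.

```python
import statistics


def group_advantages_sketch(rewards, eps=1e-8):
    """Shift rewards to zero mean and scale them to (roughly) unit variance."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # assumption: population standard deviation
    # eps keeps the division finite when every reward in the group is equal.
    return [(r - mean) / (std + eps) for r in rewards]
```

For a group where every completion earns the same reward, all advantages come out as zero, so that group contributes no policy-gradient signal.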
PPO / REINFORCE
ppo_clip_loss(*, logprobs, old_logprobs, advantages, clip_eps=0.2)
Compute PPO clipped surrogate objective (Schulman et al., 2017).
Source code in xaytune/recipes/align/ppo.py
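The clipped surrogate from Schulman et al. (2017) takes the pessimistic minimum of the unclipped and clipped probability-ratio terms. A minimal pure-Python sketch follows. `ppo_clip_loss_sketch` is a hypothetical name, and the mean reduction and the sign convention (negated so that lower is better) are assumptions.

```python
import math


def ppo_clip_loss_sketch(*, logprobs, old_logprobs, advantages, clip_eps=0.2):
    """Negative mean of min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    total = 0.0
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new / pi_old
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: never credit more than the clipped improvement.
        total += -min(ratio * adv, clipped * adv)
    return total / len(logprobs)
```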
ppo_value_loss(*, values, returns)
reinforce_loss(*, logprobs, advantages)
Log-Probabilities
get_per_token_logps(logits, labels)
Compute per-token log probabilities from logits and label ids.
Source code in xaytune/recipes/align/logprobs.py
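Conceptually this is a log-softmax over the vocabulary followed by a lookup at each label id. The sketch below shows that for a single sequence using nested Python lists. `per_token_logps_sketch` is a hypothetical name, and it assumes logits and labels are already aligned position by position (any causal-LM shift is not shown).

```python
import math


def per_token_logps_sketch(logits, labels):
    """Return log softmax(logits[t])[labels[t]] for each position t."""
    out = []
    for row, label in zip(logits, labels):
        # Log-sum-exp with the max subtracted for numerical stability.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        out.append(row[label] - log_norm)
    return out
```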
get_sequence_logps(logits, labels, mask=None)
Sum per-token log probabilities into a sequence-level log probability.
Source code in xaytune/recipes/align/logprobs.py
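The reduction step can be sketched on its own. Unlike the function above, the hypothetical `sequence_logp_sketch` below starts from per-token log-probabilities rather than logits, and it assumes a mask value of 1 keeps a token while 0 drops it (for example, prompt or padding positions).

```python
def sequence_logp_sketch(token_logps, mask=None):
    """Sum per-token log-probabilities, skipping positions where mask is 0."""
    if mask is None:
        mask = [1] * len(token_logps)  # no mask: every token counts
    return sum(lp * m for lp, m in zip(token_logps, mask))
```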
get_model_logps(model, input_ids, attention_mask=None, labels=None)
Run a forward pass and return sequence log probabilities (no grad).
Source code in xaytune/recipes/align/logprobs.py
Rewards
default_reward(prompt, response)
length_penalty_reward(prompt, response, *, target_length=200, penalty_scale=0.001)
Penalize responses that deviate from target_length characters.
Source code in xaytune/recipes/align/rewards.py
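One simple way to realize such a penalty is a linear cost on the character-count deviation. The exact functional form is not stated here, so the hypothetical `length_penalty_reward_sketch` below is an assumption that only mirrors the parameter names and defaults.

```python
def length_penalty_reward_sketch(prompt, response, *,
                                 target_length=200, penalty_scale=0.001):
    """Return 0 at the target length and a linearly growing penalty otherwise."""
    deviation = abs(len(response) - target_length)  # measured in characters
    return -penalty_scale * deviation
```

Under this form, a response 100 characters off target scores about -0.1 with the default scale.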
format_check_reward(prompt, response, *, required_markers=None)
Reward based on the fraction of required_markers present in the response.
Source code in xaytune/recipes/align/rewards.py
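The fraction-of-markers idea can be sketched with plain substring checks. `format_check_reward_sketch` is a hypothetical name, and returning 1.0 when no markers are supplied is an assumption about the default behavior.

```python
def format_check_reward_sketch(prompt, response, *, required_markers=None):
    """Fraction of required markers that occur as substrings of the response."""
    if not required_markers:
        return 1.0  # assumption: nothing required means nothing is missing
    hits = sum(1 for marker in required_markers if marker in response)
    return hits / len(required_markers)
```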
composite_reward(prompt, response, *, reward_names=None, weights=None)
Weighted combination of multiple registered reward functions.
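A weighted combination over a name-keyed registry might look like the sketch below. The `REWARDS` registry, its two toy entries, the equal-weight default, and the choice of a weighted sum (rather than a weighted average) are all assumptions made for illustration.

```python
# Hypothetical registry mapping reward names to (prompt, response) -> float.
REWARDS = {
    "length": lambda prompt, response: -0.001 * abs(len(response) - 200),
    "nonempty": lambda prompt, response: 1.0 if response.strip() else 0.0,
}


def composite_reward_sketch(prompt, response, *, reward_names=None, weights=None):
    """Weighted sum of the named rewards from the registry."""
    names = list(reward_names) if reward_names is not None else list(REWARDS)
    if weights is None:
        weights = [1.0] * len(names)  # assumption: equal weights by default
    if len(weights) != len(names):
        raise ValueError("weights and reward_names must have the same length")
    return sum(w * REWARDS[name](prompt, response)
               for name, w in zip(names, weights))
```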