Skip to content

Alignment Losses

xaytune supports six alignment methods. Each has a dedicated loss function and can be selected via create_alignment_loss_fn().

| Method | Function | Paper |
| --- | --- | --- |
| DPO | `dpo_loss` | Rafailov et al., 2023 |
| SimPO | `simpo_loss` | Meng et al., 2024 |
| ORPO | `orpo_loss` | Hong et al., 2024 |
| GRPO | `grpo_loss` | Shao et al., 2024 |
| PPO | `ppo_clip_loss` | Schulman et al., 2017 |
| REINFORCE | `reinforce_loss` | Williams, 1992 |

Loss Dispatch

create_alignment_loss_fn(*, method, ref_model=None, beta=0.1, kl_coeff=0.04, lambda_weight=1.0, gamma=0.5, clip_eps=0.2)

Create a loss function for the given alignment method.

Returns a callable (model, batch, outputs) -> loss that handles forward passes on chosen/rejected pairs and reference model inference.

Source code in xaytune/recipes/align/loss_dispatch.py
def create_alignment_loss_fn(
    *,
    method: str,
    ref_model: Any | None = None,
    beta: float = 0.1,
    kl_coeff: float = 0.04,
    lambda_weight: float = 1.0,
    gamma: float = 0.5,
    clip_eps: float = 0.2,
) -> Callable[..., torch.Tensor]:
    """Build the per-step loss callable for the alignment ``method``.

    The hyperparameters are captured in a closure and each is forwarded only
    to the step helper that takes it:

    * ``"dpo"`` -- ``ref_model``, ``beta``
    * ``"grpo"`` -- ``ref_model``, ``kl_coeff``
    * ``"orpo"`` -- ``lambda_weight`` (and the ``outputs`` passed at call time)
    * ``"simpo"`` -- ``beta``, ``gamma``
    * ``"ppo"`` -- ``clip_eps``
    * ``"reinforce"`` -- no extra hyperparameters

    NOTE(review): the shared ``beta`` default here is 0.1 and is what reaches
    the SimPO step, while ``simpo_loss`` itself defaults to ``beta=2.0`` --
    confirm that callers selecting SimPO set ``beta`` explicitly.

    Returns:
        A callable ``(model, batch, outputs) -> loss``.

    Raises:
        ValueError: Raised by the returned callable (not by this factory)
            when ``_has_alignment_fields`` accepts the batch but ``method``
            matches none of the names listed above.
    """

    def loss_fn(
        model: Any,
        batch: dict[str, Any],
        outputs: Any,
    ) -> torch.Tensor:
        # Batches lacking the fields this method needs fall back to the loss
        # carried by ``outputs`` (or to ``outputs`` itself when it has no
        # ``loss`` attribute).
        if not _has_alignment_fields(method, batch):
            fallback: torch.Tensor = getattr(outputs, "loss", outputs)
            return fallback

        if method == "dpo":
            return _dpo_step(model, batch, ref_model, beta=beta)
        if method == "grpo":
            return _grpo_step(model, batch, ref_model, kl_coeff=kl_coeff)
        if method == "orpo":
            return _orpo_step(model, batch, outputs, lambda_weight=lambda_weight)
        if method == "simpo":
            return _simpo_step(model, batch, beta=beta, gamma=gamma)
        if method == "ppo":
            return _ppo_step(model, batch, clip_eps=clip_eps)
        if method == "reinforce":
            return _reinforce_step(model, batch)

        raise ValueError(f"Unknown alignment method: {method}")

    return loss_fn

is_alignment_method(method)

Return whether method is a known alignment method.

Source code in xaytune/recipes/align/loss_dispatch.py
def is_alignment_method(method: str) -> bool:
    """Tell whether ``method`` names a supported alignment method.

    Membership is tested against the module-level ``ALIGNMENT_METHODS``
    collection.
    """
    known = method in ALIGNMENT_METHODS
    return known

DPO

dpo_loss(*, policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1)

Compute Direct Preference Optimization loss (Rafailov et al., 2023).

Source code in xaytune/recipes/align/dpo.py
def dpo_loss(
    *,
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """Direct Preference Optimization loss (Rafailov et al., 2023).

    Each response gets an implicit reward ``beta * (policy_logp - ref_logp)``.
    The loss is the mean negative log-sigmoid of the chosen-minus-rejected
    reward margin.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses.
        policy_rejected_logps: Policy log-probabilities of the rejected
            responses.
        ref_chosen_logps: Reference log-probabilities of the chosen responses.
        ref_rejected_logps: Reference log-probabilities of the rejected
            responses.
        beta: Scale applied to the policy/reference log-ratio.

    Returns:
        The loss averaged over all elements (a scalar tensor).
    """
    # Implicit rewards: scaled log-ratio of policy to reference.
    chosen_delta = policy_chosen_logps - ref_chosen_logps
    rejected_delta = policy_rejected_logps - ref_rejected_logps
    reward_chosen = beta * chosen_delta
    reward_rejected = beta * rejected_delta

    margin = reward_chosen - reward_rejected
    log_pref = F.logsigmoid(margin)
    return -log_pref.mean()

SimPO

simpo_loss(*, policy_chosen_logps, policy_rejected_logps, chosen_lengths, rejected_lengths, beta=2.0, gamma=0.5)

Compute Simple Preference Optimization loss (Meng et al., 2024).

Source code in xaytune/recipes/align/simpo.py
def simpo_loss(
    *,
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,
    rejected_lengths: torch.Tensor,
    beta: float = 2.0,
    gamma: float = 0.5,
) -> torch.Tensor:
    """Simple Preference Optimization loss (Meng et al., 2024).

    Log-probabilities are divided by their lengths, and the chosen-minus-
    rejected gap is scaled by ``beta`` and reduced by the target margin
    ``gamma`` before the negative log-sigmoid is averaged. No reference
    model is involved.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses.
        policy_rejected_logps: Policy log-probabilities of the rejected
            responses.
        chosen_lengths: Divisors for the chosen log-probabilities (cast to
            float); a zero length is not guarded against.
        rejected_lengths: Divisors for the rejected log-probabilities.
        beta: Scale applied to the length-normalised gap.
        gamma: Margin subtracted from the scaled gap.

    Returns:
        The loss averaged over all elements (a scalar tensor).
    """
    len_chosen = chosen_lengths.float()
    len_rejected = rejected_lengths.float()

    # Length-normalised log-probability of each response.
    norm_chosen = policy_chosen_logps / len_chosen
    norm_rejected = policy_rejected_logps / len_rejected

    margin = beta * (norm_chosen - norm_rejected) - gamma
    return -F.logsigmoid(margin).mean()

ORPO

orpo_loss(*, sft_loss, policy_chosen_logps, policy_rejected_logps, lambda_weight=1.0)

Compute Odds Ratio Preference Optimization loss (Hong et al., 2024).

Source code in xaytune/recipes/align/orpo.py
def orpo_loss(
    *,
    sft_loss: torch.Tensor,
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    lambda_weight: float = 1.0,
) -> torch.Tensor:
    """Odds Ratio Preference Optimization loss (Hong et al., 2024).

    Adds a weighted odds-ratio penalty to the supervised loss::

        loss = sft_loss + lambda_weight * mean(-log sigmoid(log OR))

    where ``log OR = log odds(chosen) - log odds(rejected)`` and
    ``log odds(p) = log p - log(1 - p)``.

    The log-odds are computed directly in log space. Forming
    ``p / (1 - p)`` from ``exp(logp)`` underflows to ``0 / 0 = NaN`` as soon
    as the log-probabilities are very negative, which is routine for
    sequence-level values.

    Args:
        sft_loss: Supervised (NLL) loss term that the penalty is added to.
        policy_chosen_logps: Policy log-probabilities of the chosen
            responses. Must be strictly negative so that ``1 - p > 0``.
        policy_rejected_logps: Policy log-probabilities of the rejected
            responses, with the same constraint.
        lambda_weight: Weight of the odds-ratio term.

    Returns:
        ``sft_loss`` plus the weighted, mean-reduced odds-ratio loss.
    """
    # log(1 - p) evaluated as log1p(-exp(logp)) stays finite when exp(logp)
    # underflows to zero.
    chosen_log_odds = policy_chosen_logps - torch.log1p(
        -torch.exp(policy_chosen_logps)
    )
    rejected_log_odds = policy_rejected_logps - torch.log1p(
        -torch.exp(policy_rejected_logps)
    )

    log_odds_ratio = chosen_log_odds - rejected_log_odds

    or_loss = -F.logsigmoid(log_odds_ratio).mean()

    return sft_loss + lambda_weight * or_loss

GRPO

grpo_loss(*, logprobs, ref_logprobs, advantages, kl_coeff=0.04)

Compute Group Relative Policy Optimization loss (Shao et al., 2024).

Source code in xaytune/recipes/align/grpo.py
def grpo_loss(
    *,
    logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    kl_coeff: float = 0.04,
) -> torch.Tensor:
    """Group Relative Policy Optimization loss (Shao et al., 2024).

    The loss is an advantage-weighted policy-gradient term plus a KL penalty
    toward the reference policy::

        loss = -mean(logprobs * advantages) + kl_coeff * mean(kl)
        kl   = exp(d) - d - 1,   d = ref_logprobs - logprobs

    ``kl`` is the per-element estimator used in the GRPO paper. It is always
    non-negative and is zero exactly where the two policies agree. The plain
    difference ``logprobs - ref_logprobs`` is unsuitable as a penalty: its
    gradient with respect to the policy does not depend on ``ref_logprobs``
    at all, and its value can be negative.

    Args:
        logprobs: Log-probabilities under the current policy.
        ref_logprobs: Log-probabilities of the same samples under the
            reference policy (presumably computed without gradients --
            verify against the caller).
        advantages: Advantage weights, broadcastable against ``logprobs``.
        kl_coeff: Weight of the KL penalty.

    Returns:
        The scalar loss.
    """
    policy_loss = -(logprobs * advantages).mean()

    # Non-negative KL estimator; its gradient is (1 - pi_ref / pi) * grad(logp),
    # so it pulls the policy toward the reference.
    log_ratio = ref_logprobs - logprobs
    kl = (torch.exp(log_ratio) - log_ratio - 1.0).mean()

    return policy_loss + kl_coeff * kl

compute_group_advantages(rewards)

Normalize rewards to zero-mean unit-variance advantages.

Source code in xaytune/recipes/align/grpo.py
def compute_group_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Standardise ``rewards`` into zero-mean, unit-variance advantages.

    The mean and standard deviation are taken over every element of
    ``rewards``. An all-zero tensor of the same shape is returned when there
    are fewer than two rewards, or when their standard deviation is below
    ``1e-8`` (all rewards effectively equal).

    Args:
        rewards: Reward values to normalise.

    Returns:
        ``(rewards - mean) / (std + 1e-8)``, or zeros in the degenerate cases.
    """
    if rewards.numel() > 1:
        spread = rewards.std()
        # Identical rewards carry no relative signal, so skip the division.
        if not spread < 1e-8:
            centered = rewards - rewards.mean()
            return centered / (spread + 1e-8)

    return torch.zeros_like(rewards)

PPO / REINFORCE

ppo_clip_loss(*, logprobs, old_logprobs, advantages, clip_eps=0.2)

Compute PPO clipped surrogate objective (Schulman et al., 2017).

Source code in xaytune/recipes/align/ppo.py
def ppo_clip_loss(
    *,
    logprobs: torch.Tensor,
    old_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    """PPO clipped surrogate loss (Schulman et al., 2017).

    The probability ratio ``exp(logprobs - old_logprobs)`` is multiplied by
    the advantages both as-is and after clamping to
    ``[1 - clip_eps, 1 + clip_eps]``. The element-wise minimum of the two is
    averaged and negated so that it can be minimised.

    Args:
        logprobs: Log-probabilities under the current policy.
        old_logprobs: Log-probabilities under the policy that produced the
            samples.
        advantages: Advantage estimates.
        clip_eps: Half-width of the clipping interval around a ratio of 1.

    Returns:
        The scalar loss.
    """
    importance = (logprobs - old_logprobs).exp()
    lower, upper = 1.0 - clip_eps, 1.0 + clip_eps

    raw_objective = importance * advantages
    clipped_objective = importance.clamp(lower, upper) * advantages

    # The minimum keeps the pessimistic bound of the two objectives.
    surrogate = torch.min(raw_objective, clipped_objective)
    return -surrogate.mean()

ppo_value_loss(*, values, returns)

Compute PPO value function MSE loss.

Source code in xaytune/recipes/align/ppo.py
def ppo_value_loss(
    *,
    values: torch.Tensor,
    returns: torch.Tensor,
) -> torch.Tensor:
    """Mean squared error between predicted ``values`` and ``returns``.

    Args:
        values: Value-function predictions.
        returns: Targets, broadcastable against ``values``.

    Returns:
        The mean of the squared differences (a scalar tensor).
    """
    residual = values - returns
    squared_error = residual.pow(2)
    return squared_error.mean()

reinforce_loss(*, logprobs, advantages)

Compute REINFORCE policy gradient loss.

Source code in xaytune/recipes/align/ppo.py
def reinforce_loss(
    *,
    logprobs: torch.Tensor,
    advantages: torch.Tensor,
) -> torch.Tensor:
    """REINFORCE policy-gradient loss (Williams, 1992).

    Args:
        logprobs: Log-probabilities of the sampled actions under the policy.
        advantages: Weights (returns or advantages) for each log-probability.

    Returns:
        The negated mean of ``logprobs * advantages`` (a scalar tensor).
    """
    weighted = logprobs * advantages
    return -weighted.mean()

Log-Probabilities

get_per_token_logps(logits, labels)

Compute per-token log probabilities from logits and label ids.

Source code in xaytune/recipes/align/logprobs.py
def get_per_token_logps(
    logits: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Log-probability assigned to each next-token label.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``, so the last logit step and the first label are dropped and the
    result is one step shorter than the input along dim 1.

    Args:
        logits: Indexed as ``[batch, position, vocab]``.
        labels: Token ids indexed as ``[batch, position]``. Every id in
            ``labels[:, 1:]`` must be a valid vocabulary index, because it is
            used directly as a gather index.

    Returns:
        Per-token log-probabilities indexed as ``[batch, position - 1]``.
    """
    shifted_logits = logits[:, :-1, :]
    next_tokens = labels[:, 1:].unsqueeze(2)

    token_log_dist = F.log_softmax(shifted_logits, dim=-1)
    picked = token_log_dist.gather(2, next_tokens)
    return picked.squeeze(2)

get_sequence_logps(logits, labels, mask=None)

Sum per-token log probabilities into a sequence-level log probability.

Source code in xaytune/recipes/align/logprobs.py
def get_sequence_logps(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Total log-probability of each sequence.

    Per-token log-probabilities come from ``get_per_token_logps`` and are
    summed over the last dimension.

    Args:
        logits: Model logits, forwarded to ``get_per_token_logps``.
        labels: Token ids, forwarded to ``get_per_token_logps``.
        mask: Optional per-position weights. Its first column is dropped
            (``mask[:, 1:]``) to line up with the shifted per-token values,
            and the rest is multiplied in before summing.

    Returns:
        One summed log-probability per sequence.
    """
    token_logps = get_per_token_logps(logits, labels)
    if mask is None:
        return token_logps.sum(dim=-1)

    # Masked-out positions contribute zero to the sum.
    kept = token_logps * mask[:, 1:]
    return kept.sum(dim=-1)

get_model_logps(model, input_ids, attention_mask=None, labels=None)

Run a forward pass and return sequence log probabilities (no grad).

Source code in xaytune/recipes/align/logprobs.py
def get_model_logps(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
) -> torch.Tensor:
    """Score sequences with ``model`` without tracking gradients.

    The model is called with ``input_ids`` and ``attention_mask`` as keyword
    arguments inside ``torch.no_grad()``. The ``logits`` attribute of its
    output is then reduced by ``get_sequence_logps``, with ``attention_mask``
    used as the mask.

    Args:
        model: Callable module whose output exposes a ``logits`` attribute.
        input_ids: Token ids fed to the model.
        attention_mask: Optional mask, passed to the model and reused to
            mask the per-token log-probabilities.
        labels: Target ids to score; ``input_ids`` is used when omitted.

    Returns:
        The sequence-level log-probabilities from ``get_sequence_logps``.
    """
    target_ids = input_ids if labels is None else labels

    with torch.no_grad():
        forward_out = model(input_ids=input_ids, attention_mask=attention_mask)

    return get_sequence_logps(forward_out.logits, target_ids, attention_mask)

Rewards

default_reward(prompt, response)

Baseline reward that always returns 0.

Source code in xaytune/recipes/align/rewards.py
@register_reward("default")
def default_reward(prompt: str, response: str) -> float:
    """Score every (prompt, response) pair as ``0.0``.

    Registered under the name ``"default"``; both arguments are ignored.
    """
    neutral_score = 0.0
    return neutral_score

length_penalty_reward(prompt, response, *, target_length=200, penalty_scale=0.001)

Penalize responses that deviate from target_length characters.

Source code in xaytune/recipes/align/rewards.py
@register_reward("length_penalty")
def length_penalty_reward(
    prompt: str,
    response: str,
    *,
    target_length: int = 200,
    penalty_scale: float = 0.001,
) -> float:
    """Penalise responses whose length differs from *target_length*.

    Length is measured in characters (``len(response)``); ``prompt`` is not
    used. Registered under the name ``"length_penalty"``.

    Args:
        prompt: Ignored.
        response: Text whose length is scored.
        target_length: Length that incurs no penalty.
        penalty_scale: Penalty per character of deviation.

    Returns:
        ``-penalty_scale * |len(response) - target_length|``.
    """
    overshoot = len(response) - target_length
    deviation = abs(overshoot)
    return -penalty_scale * deviation

format_check_reward(prompt, response, *, required_markers=None)

Reward based on the fraction of required_markers present in the response.

Source code in xaytune/recipes/align/rewards.py
@register_reward("format_check")
def format_check_reward(
    prompt: str,
    response: str,
    *,
    required_markers: list[str] | None = None,
) -> float:
    """Fraction of *required_markers* that occur in the response.

    Each marker is tested with a plain substring check. ``prompt`` is not
    used. Registered under the name ``"format_check"``.

    Args:
        prompt: Ignored.
        response: Text searched for the markers.
        required_markers: Substrings that should appear; ``None`` and an
            empty list both yield ``0.0``.

    Returns:
        ``matched / len(required_markers)``, or ``0.0`` without markers.
    """
    # A single truthiness test covers both None and an empty list.
    if not required_markers:
        return 0.0

    hits = sum(marker in response for marker in required_markers)
    return hits / len(required_markers)

composite_reward(prompt, response, *, reward_names=None, weights=None)

Weighted combination of multiple registered reward functions.

Source code in xaytune/recipes/align/rewards.py
@register_reward("composite")
def composite_reward(
    prompt: str,
    response: str,
    *,
    reward_names: list[str] | None = None,
    weights: list[float] | None = None,
) -> float:
    """Weighted sum of several registered reward functions.

    Each name in *reward_names* is looked up with ``reward_registry.get``,
    called as ``fn(prompt, response)``, and scaled by the weight at the same
    position. Registered under the name ``"composite"``.

    Args:
        prompt: Forwarded to every component reward.
        response: Forwarded to every component reward.
        reward_names: Registry names of the rewards to combine; ``None`` or
            an empty list yields ``0.0``.
        weights: One weight per reward name; defaults to ``1.0`` for each.

    Returns:
        The weighted total as a float.

    Raises:
        ValueError: If *weights* is given but its length differs from
            *reward_names*. Without this check ``zip`` would silently drop
            the unmatched rewards or weights.
    """
    if not reward_names:
        return 0.0
    if weights is None:
        weights = [1.0] * len(reward_names)
    elif len(weights) != len(reward_names):
        raise ValueError(
            f"composite_reward got {len(reward_names)} reward names "
            f"but {len(weights)} weights"
        )

    total = 0.0
    for name, weight in zip(reward_names, weights):
        fn = reward_registry.get(name)
        total += weight * fn(prompt, response)
    return total