
Callbacks

xaytune provides an event-driven callback system for hooking into the training loop. With callbacks you can log custom metrics, implement early stopping, send notifications, or run arbitrary code at specific training events.

TrainState

TrainState is a dataclass that tracks the current state of training. It is passed to every callback and returned by recipe functions.

from xaytune.trainer.callbacks import TrainState
Field        Type            Default  Description
step         int             0        Current step within the epoch
epoch        int             0        Current epoch number
global_step  int             0        Total steps across all epochs
num_epochs   int             0        Total number of epochs
max_steps    int             -1       Maximum number of steps (-1 = unlimited)
metrics      dict[str, Any]  {}       Current metrics (loss, learning rate, etc.)
should_stop  bool            False    Set to True to stop training after the current step
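For reference, the table corresponds to a dataclass shaped roughly like the stand-in below. It is an illustrative mirror written from the documented fields and defaults, not xaytune's source. In particular, it assumes metrics is declared with field(default_factory=dict), which is how a dataclass gives each instance its own dictionary instead of one shared mutable default.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TrainStateSketch:
    """Illustrative stand-in mirroring the documented TrainState fields (not library code)."""

    step: int = 0            # current step within the epoch
    epoch: int = 0           # current epoch number
    global_step: int = 0     # total steps across all epochs
    num_epochs: int = 0      # total number of epochs
    max_steps: int = -1      # -1 means no step limit
    # default_factory gives every instance its own dict instead of a shared mutable default
    metrics: dict[str, Any] = field(default_factory=dict)
    should_stop: bool = False  # set to True to stop after the current step


a = TrainStateSketch()
b = TrainStateSketch()
a.metrics["loss"] = 0.5
print(b.metrics)  # prints {} because b has its own dict
```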

Stopping Training Early

Call state.stop_training() from any callback to stop training after the current step. Setting state.should_stop = True directly has the same documented effect:

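from xaytune.trainer.callbacks import CallbackManager

callbacks = CallbackManager()

@callbacks.on("step_end")
def early_stop(state):
    # A missing loss is treated as infinitely large, so the check never fires by accident
    if state.metrics.get("loss", float("inf")) < 0.01:
        print("Loss target reached, stopping early.")
        state.stop_training()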
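A fixed threshold is not the only stopping rule. The sketch below stops when a metric has not improved for `patience` consecutive calls. EarlyStopping is a hypothetical helper, not something xaytune documents; it reads the "eval_loss" key, which the API reference lists among the metrics, and sets the documented should_stop field.

```python
import math
from types import SimpleNamespace


class EarlyStopping:
    """Hypothetical callback: stop when `monitor` fails to improve for `patience` calls."""

    def __init__(self, monitor: str = "eval_loss", patience: int = 3, min_delta: float = 0.0):
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_calls = 0

    def __call__(self, state) -> None:
        value = state.metrics.get(self.monitor)
        if value is None:
            return  # metric not reported on this call, nothing to compare
        if value < self.best - self.min_delta:
            self.best = value      # improvement: remember it and reset the counter
            self.bad_calls = 0
        else:
            self.bad_calls += 1
            if self.bad_calls >= self.patience:
                state.should_stop = True  # documented flag: stop after the current step


# Simulate a few evaluations with a stand-in state object.
stopper = EarlyStopping(patience=2)
state = SimpleNamespace(metrics={}, should_stop=False)
for loss in [0.9, 0.7, 0.72, 0.71]:
    state.metrics["eval_loss"] = loss
    stopper(state)
print(state.should_stop)  # True
```

To use it, register the instance on the event that produces the metric, for example `callbacks.on("eval_end")(EarlyStopping(patience=3))`, which is what the `@` syntax does for functions. This assumes on() accepts any callable, not only plain functions.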

CallbackManager

CallbackManager manages callback registration and event firing. Use its on() decorator to register callbacks.

from xaytune.trainer.callbacks import CallbackManager

callbacks = CallbackManager()

@callbacks.on("step_end")
def log_loss(state):
    if state.global_step % 100 == 0:
        print(f"Step {state.global_step}: loss={state.metrics.get('loss', 'N/A')}")

@callbacks.on("train_end")
def on_complete(state):
    print(f"Training complete after {state.global_step} steps.")
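The log_loss handler above checks `global_step % 100 == 0` inline. A small decorator can factor that pattern out. every_n_steps is a hypothetical helper, not part of xaytune; it relies only on the handler receiving a state with a global_step attribute.

```python
import functools
from types import SimpleNamespace
from typing import Callable


def every_n_steps(n: int) -> Callable:
    """Hypothetical decorator: run the wrapped handler only when global_step is a multiple of n."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)  # keep the original name and docstring on the wrapper
        def wrapper(state) -> None:
            if state.global_step % n == 0:
                fn(state)

        return wrapper

    return decorator


seen = []


@every_n_steps(100)
def log_loss(state) -> None:
    seen.append(state.global_step)


for step in range(1, 301):
    log_loss(SimpleNamespace(global_step=step, metrics={}))
print(seen)  # [100, 200, 300]
```

When registering, place `@callbacks.on("step_end")` above `@every_n_steps(100)` so that the wrapped function is the one that gets registered.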

Available Events

Event             Fired when
train_start       Training begins
train_end         Training finishes
epoch_start       An epoch begins
epoch_end         An epoch finishes
step_start        A training step begins
step_end          A training step finishes
eval_start        Evaluation begins
eval_end          Evaluation finishes
checkpoint_saved  A checkpoint is saved to disk
error             An error occurs during training
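To see when each event fires in your own setup, you can register one recording handler per event. Because a handler receives only the state and not the event name, the sketch binds the name in a closure; building handlers through a factory function also avoids Python's late-binding pitfall with loop variables. make_tracer and EVENTS are illustrative names; EVENTS simply copies the table above.

```python
from types import SimpleNamespace
from typing import Callable

# Event names copied from the table above.
EVENTS = [
    "train_start", "train_end", "epoch_start", "epoch_end",
    "step_start", "step_end", "eval_start", "eval_end",
    "checkpoint_saved", "error",
]


def make_tracer(event: str, log: list) -> Callable:
    """Return a handler that records (event, global_step) each time it is called."""

    def tracer(state) -> None:
        log.append((event, state.global_step))

    return tracer


trace: list = []
handlers = {name: make_tracer(name, trace) for name in EVENTS}

# Call two handlers directly to show what gets recorded.
handlers["step_start"](SimpleNamespace(global_step=7))
handlers["step_end"](SimpleNamespace(global_step=7))
print(trace)  # [('step_start', 7), ('step_end', 7)]
```

With a real manager, the handlers would be registered with `for name, handler in handlers.items(): callbacks.on(name)(handler)`.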

Firing Events

The trainer fires these events automatically. If you write your own training loop, fire them manually with fire(), passing the event name and the current TrainState:

callbacks.fire("step_end", state)
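A custom loop has to decide where each fire() call goes. The sketch below shows one plausible arrangement: step events around each batch, epoch events around each pass, an error event on exceptions, and a should_stop / max_steps check after the step_end handlers have run. It takes `fire` as a plain function argument so it runs without xaytune (in practice you would pass callbacks.fire). LoopState, run, train_step, and the exact ordering are assumptions for illustration, not a description of xaytune's built-in trainer.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable


@dataclass
class LoopState:
    """Stand-in with the TrainState fields this loop touches (not xaytune's class)."""

    step: int = 0
    epoch: int = 0
    global_step: int = 0
    num_epochs: int = 0
    max_steps: int = -1
    metrics: dict[str, Any] = field(default_factory=dict)
    should_stop: bool = False


def run(
    fire: Callable[[str, LoopState], None],
    batches: Iterable[Any],
    num_epochs: int,
    train_step: Callable[[Any], float],
    max_steps: int = -1,
) -> LoopState:
    """Hypothetical custom loop that fires callback events around each unit of work."""
    state = LoopState(num_epochs=num_epochs, max_steps=max_steps)
    batches = list(batches)  # materialize so every epoch can iterate again

    def done() -> bool:
        # Stop on request, or once a positive step limit has been reached.
        return state.should_stop or (max_steps > 0 and state.global_step >= max_steps)

    fire("train_start", state)
    try:
        for epoch in range(num_epochs):
            state.epoch = epoch
            fire("epoch_start", state)
            for step, batch in enumerate(batches):
                state.step = step
                fire("step_start", state)
                state.metrics["loss"] = train_step(batch)
                state.global_step += 1
                fire("step_end", state)
                if done():  # checked only after step_end handlers have run
                    break
            fire("epoch_end", state)
            if done():
                break
    except Exception as exc:
        # Assumption: store the exception where the Error Handling example looks for it.
        state.metrics["error"] = exc
        fire("error", state)
        raise
    fire("train_end", state)
    return state


events: list[str] = []
final = run(lambda name, s: events.append(name), batches=[1, 2], num_epochs=1,
            train_step=lambda b: 1.0 / b)
print(events)
```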

Examples

Gradient Logging

@callbacks.on("step_end")
def log_gradients(state):
    if state.global_step % 50 == 0:
        grad_norm = state.metrics.get("grad_norm")
        if grad_norm is not None:
            print(f"Step {state.global_step}: grad_norm={grad_norm:.4f}")
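Gradient logging is often paired with a guard. The sketch below stops training when the reported grad_norm becomes NaN or infinite. Like the example above, it assumes the trainer puts a "grad_norm" entry in state.metrics; the handler name is illustrative, and it would be registered with `@callbacks.on("step_end")`.

```python
import math
from types import SimpleNamespace


def stop_on_nonfinite_grad(state) -> None:
    """Stop training if grad_norm is NaN or +/-inf (assumes a 'grad_norm' metric is reported)."""
    grad_norm = state.metrics.get("grad_norm")
    if grad_norm is not None and not math.isfinite(grad_norm):
        print(f"Step {state.global_step}: non-finite grad_norm={grad_norm}, stopping.")
        state.should_stop = True


state = SimpleNamespace(global_step=120, metrics={"grad_norm": float("nan")}, should_stop=False)
stop_on_nonfinite_grad(state)
print(state.should_stop)  # True
```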

Checkpoint Notification

@callbacks.on("checkpoint_saved")
def notify_checkpoint(state):
    print(f"Checkpoint saved at step {state.global_step}")

Error Handling

@callbacks.on("error")
def handle_error(state):
    print(f"Error at step {state.global_step}: {state.metrics.get('error')}")

Full API Reference

TrainState(step=0, epoch=0, global_step=0, num_epochs=0, max_steps=-1, metrics=dict(), should_stop=False)  (dataclass)

Mutable training state passed to every callback.

Attributes:

Name         Type            Description
step         int             Current step within the epoch.
epoch        int             Current epoch index.
global_step  int             Total optimizer steps taken across all epochs.
num_epochs   int             Total number of epochs configured.
max_steps    int             Step limit (-1 = unlimited).
metrics      dict[str, Any]  Dict of latest metrics (loss, eval_loss, learning_rate, etc.).
should_stop  bool            Set to True to stop training after the current step.

CallbackManager()

Registry for event-driven training callbacks.

Register handlers with the on() decorator and trigger them with fire(). Supported events: train_start, train_end, epoch_start, epoch_end, step_start, step_end, eval_start, eval_end, checkpoint_saved, error.

Source code in xaytune/trainer/callbacks.py
def __init__(self) -> None:
    self._callbacks: dict[str, list[Callable]] = defaultdict(list)
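The constructor above stores handlers in a defaultdict(list) keyed by event name. The stand-in below sketches how on() and fire() could sit on top of that structure. It is written from the documented usage for illustration, not copied from xaytune; whether the real class validates event names or catches exceptions raised by handlers is not covered on this page, and this sketch does neither.

```python
from collections import defaultdict
from typing import Any, Callable


class CallbackManagerSketch:
    """Illustrative event registry mirroring the documented on()/fire() usage (not xaytune's code)."""

    def __init__(self) -> None:
        # event name -> handlers, in registration order
        self._callbacks: dict[str, list[Callable]] = defaultdict(list)

    def on(self, event: str) -> Callable[[Callable], Callable]:
        """Decorator factory: register the decorated function for `event` and return it unchanged."""

        def decorator(fn: Callable) -> Callable:
            self._callbacks[event].append(fn)
            return fn

        return decorator

    def fire(self, event: str, state: Any) -> None:
        """Call every handler registered for `event`, in registration order."""
        # .get() avoids creating empty defaultdict entries for events with no handlers
        for handler in self._callbacks.get(event, []):
            handler(state)


manager = CallbackManagerSketch()
calls = []


@manager.on("step_end")
def first(state) -> None:
    calls.append(("first", state))


@manager.on("step_end")
def second(state) -> None:
    calls.append(("second", state))


manager.fire("step_end", "s1")
manager.fire("eval_end", "s2")  # no handlers registered, so nothing happens
print(calls)  # [('first', 's1'), ('second', 's1')]
```

Returning the function unchanged from the decorator means the decorated name still refers to the original handler, so it can also be called directly, for example in unit tests.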