Skip to content

Trainer

The trainer module contains the training loop, checkpointing, scheduling, distributed strategies, and the LR finder.


Trainer

Trainer(config, callback_manager=None)

Core training loop with mixed precision, gradient accumulation, and callbacks.

Handles optimizer creation, learning rate scheduling, AMP autocast/scaler, gradient clipping, and checkpoint resume. Fires callback events at each lifecycle point so evaluation, logging, and checkpointing are pluggable.

Source code in xaytune/trainer/loop.py
def __init__(
    self,
    config: TrainerConfig,
    callback_manager: CallbackManager | None = None,
) -> None:
    """Store the training configuration and the callback manager.

    Args:
        config: Settings for the training loop.
        callback_manager: Dispatcher for lifecycle callback events. When
            omitted (``None``), a new ``CallbackManager()`` is created.
    """
    self.config = config
    # Compare against None explicitly: ``callback_manager or ...`` would
    # silently replace a caller-supplied manager that happens to be falsy
    # (for example, a container-like manager with no callbacks registered
    # yet), dropping any callbacks the caller adds to it later.
    self.callback_manager = (
        callback_manager if callback_manager is not None else CallbackManager()
    )

Checkpointing

save_checkpoint(*, output_dir, model, optimizer, state, scheduler=None, scaler=None)

Save model, optimizer, scheduler, and scaler state to output_dir.

Writes model.pt, optimizer.pt, optional scheduler.pt and scaler.pt, plus a metadata.json with step/epoch/metrics.

Source code in xaytune/trainer/checkpointing.py
def save_checkpoint(
    *,
    output_dir: str,
    model: Any,
    optimizer: Any,
    state: TrainState,
    scheduler: Any | None = None,
    scaler: Any | None = None,
) -> None:
    """Save model, optimizer, scheduler, and scaler state to *output_dir*.

    Writes ``model.pt``, ``optimizer.pt``, optional ``scheduler.pt`` and
    ``scaler.pt``, plus a ``metadata.json`` with step/epoch/metrics.

    ``metadata.json`` is written last and acts as the "checkpoint complete"
    marker: ``find_latest_checkpoint`` only considers directories that
    contain it. To keep that marker trustworthy, any existing
    ``metadata.json`` is removed before the weights are (re)written, and the
    new one is written to a temporary file and then renamed into place.

    Args:
        output_dir: Checkpoint directory; created (with parents) if needed.
        model: Object whose ``state_dict()`` is saved; ``{}`` is saved if it
            has no ``state_dict``.
        optimizer: Same treatment as *model*.
        state: Training state providing ``global_step``, ``epoch``, ``step``
            and ``metrics`` for the metadata file.
        scheduler: Optional; saved only if it has ``state_dict``.
        scaler: Optional; saved only if it has ``state_dict``.

    Raises:
        TypeError: If ``state.metrics`` is not JSON-serializable. This is
            raised before any file in *output_dir* is touched.
    """
    ckpt_path = Path(output_dir)
    ckpt_path.mkdir(parents=True, exist_ok=True)

    # Serialize the metadata up front so a non-JSON-serializable metric
    # fails fast instead of after all the weights have been written.
    metadata_json = json.dumps(
        {
            "global_step": state.global_step,
            "epoch": state.epoch,
            "step": state.step,
            "metrics": state.metrics,
        },
        indent=2,
    )

    # When overwriting an existing checkpoint, drop the old completion marker
    # first so a crash while rewriting the weights cannot leave stale
    # metadata pointing at half-written files.
    metadata_path = ckpt_path / "metadata.json"
    metadata_path.unlink(missing_ok=True)

    model_state = model.state_dict() if hasattr(model, "state_dict") else {}
    torch.save(model_state, ckpt_path / "model.pt")

    optimizer_state = optimizer.state_dict() if hasattr(optimizer, "state_dict") else {}
    torch.save(optimizer_state, ckpt_path / "optimizer.pt")

    if scheduler is not None and hasattr(scheduler, "state_dict"):
        torch.save(scheduler.state_dict(), ckpt_path / "scheduler.pt")

    if scaler is not None and hasattr(scaler, "state_dict"):
        torch.save(scaler.state_dict(), ckpt_path / "scaler.pt")

    # Write-then-rename so readers never observe a truncated metadata.json.
    tmp_path = ckpt_path / "metadata.json.tmp"
    tmp_path.write_text(metadata_json)
    tmp_path.replace(metadata_path)

load_checkpoint(*, checkpoint_dir, model, optimizer, scheduler=None, scaler=None)

Restore model, optimizer, and training state from a checkpoint directory.

Returns a `TrainState` (`xaytune.trainer.callbacks.TrainState`) with the saved step, epoch, and metrics so training can resume.

Raises:

Type Description
FileNotFoundError

If checkpoint_dir doesn't exist.

Source code in xaytune/trainer/checkpointing.py
def load_checkpoint(
    *,
    checkpoint_dir: str,
    model: Any,
    optimizer: Any,
    scheduler: Any | None = None,
    scaler: Any | None = None,
    map_location: Any = "cpu",
) -> TrainState:
    """Restore model, optimizer, and training state from a checkpoint directory.

    Returns a :class:`~xaytune.trainer.callbacks.TrainState` with the
    saved step, epoch, and metrics so training can resume.

    Each component file (``model.pt``, ``optimizer.pt``, ``scheduler.pt``,
    ``scaler.pt``) is restored only if the file exists and the matching
    target is not ``None`` and has ``load_state_dict``. An empty saved
    state (``save_checkpoint`` writes ``{}`` for objects without
    ``state_dict``) is not applied. A missing ``metadata.json`` yields a
    state with zero counters and empty metrics.

    Args:
        checkpoint_dir: Directory previously written by ``save_checkpoint``.
        model: Receives ``model.pt`` via ``load_state_dict``.
        optimizer: Receives ``optimizer.pt`` via ``load_state_dict``.
        scheduler: Optional; receives ``scheduler.pt``.
        scaler: Optional; receives ``scaler.pt``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``
            so GPU-saved tensors are not all materialized on the device
            they were saved from (e.g. ``cuda:0`` on every rank) and can be
            read on CPU-only hosts. For ``torch.nn.Module`` and
            ``torch.optim.Optimizer`` targets, ``load_state_dict`` then
            places the values on the targets' own parameter devices.
            Pass ``None`` for the previous behavior.

    Raises:
        FileNotFoundError: If *checkpoint_dir* doesn't exist.
    """
    ckpt_path = Path(checkpoint_dir)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_dir}")

    # Restored in the same order as before: model, optimizer, scheduler, scaler.
    components = (
        ("model.pt", model),
        ("optimizer.pt", optimizer),
        ("scheduler.pt", scheduler),
        ("scaler.pt", scaler),
    )
    for filename, target in components:
        file_path = ckpt_path / filename
        if target is None or not hasattr(target, "load_state_dict") or not file_path.exists():
            continue
        saved_state = torch.load(file_path, map_location=map_location, weights_only=True)
        if saved_state:
            target.load_state_dict(saved_state)

    metadata_path = ckpt_path / "metadata.json"
    metadata = json.loads(metadata_path.read_text()) if metadata_path.exists() else {}

    return TrainState(
        global_step=metadata.get("global_step", 0),
        epoch=metadata.get("epoch", 0),
        step=metadata.get("step", 0),
        metrics=metadata.get("metrics", {}),
    )

find_latest_checkpoint(output_dir)

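Return the path of the subdirectory of output_dir whose metadata.json has the highest global_step, or None if output_dir does not exist or contains no such subdirectory.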

Source code in xaytune/trainer/checkpointing.py
def find_latest_checkpoint(output_dir: str) -> str | None:
    """Find the checkpoint with the highest global_step in *output_dir*."""
    base = Path(output_dir)
    if not base.exists():
        return None

    checkpoints = []
    for d in base.iterdir():
        if d.is_dir() and (d / "metadata.json").exists():
            meta = json.loads((d / "metadata.json").read_text())
            checkpoints.append((meta.get("global_step", 0), str(d)))

    if not checkpoints:
        return None

    checkpoints.sort(key=lambda x: x[0], reverse=True)
    return checkpoints[0][1]

AsyncCheckpointSaver()

Write checkpoints in a background thread to avoid blocking training.

Source code in xaytune/trainer/async_checkpoint.py
def __init__(self) -> None:
    self._thread: threading.Thread | None = None
    self._error: BaseException | None = None

Scheduling

create_scheduler(optimizer, scheduler_type, total_steps, warmup_steps)

Create an LR scheduler with optional linear warmup.

Parameters:

Name Type Description Default
optimizer Any

The optimizer to schedule.

required
scheduler_type str

"cosine", "linear", "constant", or "constant_with_warmup".

required
total_steps int

Total training steps (for decay calculation).

required
warmup_steps int

Number of linear warmup steps.

required

Raises:

Type Description
ValueError

If scheduler_type is not recognized.

Source code in xaytune/trainer/scheduler.py
def create_scheduler(
    optimizer: Any,
    scheduler_type: str,
    total_steps: int,
    warmup_steps: int,
) -> LambdaLR:
    """Create an LR scheduler with optional linear warmup.

    The returned ``LambdaLR`` multiplies each parameter group's initial LR
    by a factor that depends on the scheduler step:

    * warmup (every type except ``"constant"``): ``step / warmup_steps``
      while ``step < warmup_steps`` and ``warmup_steps > 0``;
    * ``"constant"`` / ``"constant_with_warmup"``: ``1.0`` otherwise;
    * ``"linear"``: decays linearly to ``0.0`` at *total_steps* and stays
      at ``0.0`` afterwards;
    * ``"cosine"``: follows a half cosine from ``1.0`` down to ``0.0`` at
      *total_steps* and stays at ``0.0`` afterwards.

    Args:
        optimizer: The optimizer to schedule.
        scheduler_type: ``"cosine"``, ``"linear"``, ``"constant"``,
            or ``"constant_with_warmup"``.
        total_steps: Total training steps (for decay calculation).
        warmup_steps: Number of linear warmup steps.

    Returns:
        A ``LambdaLR`` wrapping *optimizer*.

    Raises:
        ValueError: If *scheduler_type* is not recognized.
    """
    if scheduler_type not in _VALID_TYPES:
        raise ValueError(
            f"Unknown scheduler type '{scheduler_type}'. "
            f"Valid options: {', '.join(sorted(_VALID_TYPES))}"
        )

    if scheduler_type == "constant":

        def lr_lambda(current_step: int) -> float:
            return 1.0

    elif scheduler_type == "constant_with_warmup":

        def lr_lambda(current_step: int) -> float:
            if warmup_steps > 0 and current_step < warmup_steps:
                return current_step / warmup_steps
            return 1.0

    elif scheduler_type == "linear":

        def lr_lambda(current_step: int) -> float:
            if warmup_steps > 0 and current_step < warmup_steps:
                return current_step / warmup_steps
            # Guard against division by zero when warmup covers all steps.
            decay_steps = max(total_steps - warmup_steps, 1)
            return max(0.0, 1.0 - (current_step - warmup_steps) / decay_steps)

    elif scheduler_type == "cosine":

        def lr_lambda(current_step: int) -> float:
            if warmup_steps > 0 and current_step < warmup_steps:
                return current_step / warmup_steps
            decay_steps = max(total_steps - warmup_steps, 1)
            # Clamp so steps past ``total_steps`` stay at the minimum instead of
            # following the cosine back up (cos(pi * p) increases again for p > 1).
            progress = min(1.0, (current_step - warmup_steps) / decay_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)

resolve_warmup_steps(warmup_steps, warmup_ratio, total_steps)

Return the effective warmup step count from either an absolute count or ratio.

Source code in xaytune/trainer/scheduler.py
def resolve_warmup_steps(
    warmup_steps: int,
    warmup_ratio: float,
    total_steps: int,
) -> int:
    """Pick the warmup length; a positive explicit count wins over a ratio.

    Args:
        warmup_steps: Absolute number of warmup steps, used as-is when > 0.
        warmup_ratio: Fraction of *total_steps* used only when
            *warmup_steps* is not positive; the product is truncated to int.
        total_steps: Total training steps the ratio is applied to.

    Returns:
        The warmup step count, or ``0`` when neither input is positive.
    """
    if warmup_steps > 0:
        return warmup_steps
    use_ratio = warmup_ratio > 0.0
    return int(warmup_ratio * total_steps) if use_ratio else 0

LR Finder

lr_find(model, train_dataloader, *, start_lr=1e-07, end_lr=1.0, num_iterations=100, smoothing_factor=0.05, divergence_threshold=4.0, loss_fn=None)

Run an LR range test to find the optimal learning rate.

Trains with exponentially increasing LR from start_lr to end_lr, tracking loss. Stops early if loss diverges. Model weights are restored to their original state afterward.

Parameters:

Name Type Description Default
model Any

The model to test.

required
train_dataloader Any

Training data loader.

required
start_lr float

Starting learning rate.

1e-07
end_lr float

Maximum learning rate to test.

1.0
num_iterations int

Number of training iterations.

100
smoothing_factor float

Exponential smoothing for loss curve.

0.05
divergence_threshold float

Stop when smoothed loss exceeds this multiple of the best smoothed loss.

4.0
loss_fn Any | None

Optional custom loss function (model, batch, outputs) -> loss.

None

Returns:

Type Description
LRFinderResult

An `LRFinderResult` with the tested learning rates, the raw losses, and a suggested learning rate.

Source code in xaytune/trainer/lr_finder.py
def lr_find(
    model: Any,
    train_dataloader: Any,
    *,
    start_lr: float = 1e-7,
    end_lr: float = 1.0,
    num_iterations: int = 100,
    smoothing_factor: float = 0.05,
    divergence_threshold: float = 4.0,
    loss_fn: Any | None = None,
) -> LRFinderResult:
    """Run an LR range test to find the optimal learning rate.

    Trains with exponentially increasing LR from *start_lr* toward *end_lr*,
    tracking loss.  Stops early if the smoothed loss diverges or the raw
    loss becomes NaN/inf.  Model weights are restored to their original
    state afterward (also when the sweep raises), and parameter gradients
    are cleared so they cannot leak into subsequent training.

    Args:
        model: The model to test.
        train_dataloader: Training data loader.
        start_lr: Starting learning rate (must be > 0).
        end_lr: Maximum learning rate to test (must be > 0).
        num_iterations: Number of training iterations (must be >= 1).
        smoothing_factor: Exponential smoothing for loss curve.
        divergence_threshold: Stop when smoothed loss exceeds this
            multiple of the best smoothed loss.
        loss_fn: Optional custom loss function ``(model, batch, outputs) -> loss``.

    Returns:
        :class:`LRFinderResult` with tested LRs, losses, and a suggestion.

    Raises:
        ValueError: If *train_dataloader* is empty, *num_iterations* < 1,
            or *start_lr* / *end_lr* is not positive.
    """
    import math

    if not train_dataloader:
        raise ValueError("train_dataloader must not be empty")
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
    if start_lr <= 0 or end_lr <= 0:
        raise ValueError(
            f"start_lr and end_lr must be positive, got start_lr={start_lr}, end_lr={end_lr}"
        )

    # Snapshot the weights so the sweep leaves the model as we found it.
    saved_state = copy.deepcopy(model.state_dict())

    optimizer = torch.optim.SGD(model.parameters(), lr=start_lr)
    # Per-iteration multiplier: iteration i runs at start_lr * lr_mult**i, so
    # the last recorded LR is start_lr * lr_mult**(num_iterations - 1).
    lr_mult = (end_lr / start_lr) ** (1.0 / num_iterations)

    lrs: list[float] = []
    losses: list[float] = []
    smooth_losses: list[float] = []
    best_smooth = float("inf")

    try:
        for batch in islice(_cycle(train_dataloader), num_iterations):
            current_lr = optimizer.param_groups[0]["lr"]

            outputs = model(**batch) if isinstance(batch, dict) else model(batch)

            if loss_fn is not None:
                loss = loss_fn(model, batch, outputs)
            else:
                loss = outputs.loss if hasattr(outputs, "loss") else outputs

            raw_loss = loss.item()
            # NaN never satisfies the divergence comparison below, so a blown-up
            # loss must be caught explicitly; keep it out of the recorded curve.
            if not math.isfinite(raw_loss):
                break

            if smooth_losses:
                smooth = smoothing_factor * raw_loss + (1 - smoothing_factor) * smooth_losses[-1]
            else:
                smooth = raw_loss

            lrs.append(current_lr)
            losses.append(raw_loss)
            smooth_losses.append(smooth)

            best_smooth = min(best_smooth, smooth)

            # Divergence: the point is recorded, but no update is taken.
            if best_smooth > 0 and smooth > divergence_threshold * best_smooth:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for pg in optimizer.param_groups:
                pg["lr"] *= lr_mult
    finally:
        # Restore even if the sweep raised (e.g. out of memory), and drop the
        # gradients from the last backward pass so they cannot be accumulated
        # into the caller's next optimizer step.
        model.load_state_dict(saved_state)
        optimizer.zero_grad(set_to_none=True)

    suggested = _suggest_lr(lrs, smooth_losses)
    return LRFinderResult(lrs=lrs, losses=losses, suggested_lr=suggested)

LRFinderResult(lrs, losses, suggested_lr) dataclass

Result of an LR range test.

Attributes:

Name Type Description
lrs list[float]

Learning rates tested.

losses list[float]

Raw loss values at each LR.

suggested_lr float | None

Recommended LR (steepest descent point).

Distributed Training

DistributedContext(rank=0, world_size=1, local_rank=0) dataclass

Process-level distributed training state (rank, world size, device).

get_strategy(strategy, world_size=1)

Resolve the "auto" strategy to "fsdp" when world_size > 1 and to "none" otherwise; any other strategy is returned unchanged.

Source code in xaytune/trainer/distributed.py
def get_strategy(strategy: str, world_size: int = 1) -> str:
    """Map ``"auto"`` to a concrete strategy and pass anything else through.

    ``"auto"`` becomes ``"fsdp"`` when more than one process participates
    (``world_size > 1``) and ``"none"`` otherwise.
    """
    if strategy != "auto":
        return strategy
    multi_process = world_size > 1
    return "fsdp" if multi_process else "none"

wrap_model_distributed(model, *, strategy, ctx, fsdp_config=None, deepspeed_config=None, mixed_precision='bf16', **kwargs)

Wrap a model with the chosen distributed strategy (DDP, FSDP, or DeepSpeed).

Source code in xaytune/trainer/distributed.py
def wrap_model_distributed(
    model: Any,
    *,
    strategy: str,
    ctx: DistributedContext,
    fsdp_config: Any | None = None,
    deepspeed_config: Any | None = None,
    mixed_precision: str = "bf16",
    **kwargs: Any,
) -> Any:
    """Wrap a model with the chosen distributed strategy (DDP, FSDP, or DeepSpeed).

    Args:
        model: The model to wrap.
        strategy: ``"none"``, ``"ddp"``, ``"fsdp"``, or ``"deepspeed"``.
            ``"auto"`` is not accepted here; resolve it first with
            ``get_strategy``.
        ctx: Distributed context; ``ctx.local_rank`` selects the DDP device
            (``device_ids`` is ``None`` when it is negative).
        fsdp_config: Optional FSDP settings (sharding strategy, CPU offload,
            prefetching, mixed precision, auto-wrap threshold, ...). When
            ``None``, FSDP is built only from *kwargs*.
        deepspeed_config: DeepSpeed settings. When ``None`` with
            ``strategy="deepspeed"``, *model* is returned unwrapped.
        mixed_precision: ``"fp16"`` or ``"bf16"``; used for FSDP mixed
            precision when ``fsdp_config.mixed_precision`` is truthy. Any
            other value leaves FSDP mixed precision unset.
        **kwargs: Forwarded to ``FullyShardedDataParallel`` and take
            precedence over config-derived values. Ignored by the other
            strategies.

    Returns:
        The DDP/FSDP-wrapped module, the DeepSpeed engine, or *model* itself.

    Raises:
        ValueError: If *strategy*, ``fsdp_config.sharding_strategy`` or
            ``fsdp_config.backward_prefetch`` is not recognized.
    """
    if strategy == "none":
        return model

    if strategy == "ddp":
        from torch.nn.parallel import DistributedDataParallel

        return DistributedDataParallel(
            model,
            device_ids=[ctx.local_rank] if ctx.local_rank >= 0 else None,
            find_unused_parameters=False,
        )

    if strategy == "fsdp":
        from torch.distributed.fsdp import FullyShardedDataParallel

        fsdp_kwargs: dict[str, Any] = (
            _fsdp_kwargs_from_config(fsdp_config, mixed_precision)
            if fsdp_config is not None
            else {}
        )
        # Explicit keyword arguments override anything derived from the config.
        fsdp_kwargs.update(kwargs)
        return FullyShardedDataParallel(model, **fsdp_kwargs)

    if strategy == "deepspeed":
        if deepspeed_config is None:
            return model

        import deepspeed as ds

        config_dict = _deepspeed_config_dict(deepspeed_config)
        engine, _, _, _ = ds.initialize(model=model, config=config_dict)
        return engine

    raise ValueError(f"Unknown strategy: '{strategy}'. Valid options: none, ddp, fsdp, deepspeed.")


def _resolve_config_choice(field: str, value: Any, options: dict[str, Any]) -> Any:
    """Map a config string to its torch enum member, rejecting unknown values with ValueError."""
    try:
        return options[value]
    except KeyError:
        raise ValueError(
            f"Unknown {field} '{value}'. Valid options: {', '.join(sorted(options))}."
        ) from None


def _fsdp_kwargs_from_config(fsdp_config: Any, mixed_precision: str) -> dict[str, Any]:
    """Translate an FSDP config object into ``FullyShardedDataParallel`` keyword arguments."""
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy

    fsdp_kwargs: dict[str, Any] = {
        "sharding_strategy": _resolve_config_choice(
            "sharding_strategy",
            fsdp_config.sharding_strategy,
            {
                "full_shard": ShardingStrategy.FULL_SHARD,
                "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
                "no_shard": ShardingStrategy.NO_SHARD,
            },
        ),
    }

    if fsdp_config.cpu_offload:
        fsdp_kwargs["cpu_offload"] = CPUOffload(offload_params=True)

    if fsdp_config.backward_prefetch is not None:
        from torch.distributed.fsdp import BackwardPrefetch

        fsdp_kwargs["backward_prefetch"] = _resolve_config_choice(
            "backward_prefetch",
            fsdp_config.backward_prefetch,
            {
                "backward_pre": BackwardPrefetch.BACKWARD_PRE,
                "backward_post": BackwardPrefetch.BACKWARD_POST,
            },
        )

    if fsdp_config.mixed_precision:
        from torch.distributed.fsdp import MixedPrecision as FSDPMixedPrecision

        # Unrecognized precision strings silently leave mixed precision off.
        mp_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision)
        if mp_dtype is not None:
            fsdp_kwargs["mixed_precision"] = FSDPMixedPrecision(
                param_dtype=mp_dtype,
                reduce_dtype=mp_dtype,
                buffer_dtype=mp_dtype,
            )

    min_params = getattr(fsdp_config, "auto_wrap_min_params", 0)
    if min_params > 0:
        from functools import partial

        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

        fsdp_kwargs["auto_wrap_policy"] = partial(
            size_based_auto_wrap_policy,
            min_num_params=min_params,
        )

    # Optional attributes: fall back to these defaults when the config lacks them.
    fsdp_kwargs["forward_prefetch"] = getattr(fsdp_config, "forward_prefetch", False)
    fsdp_kwargs["sync_module_states"] = getattr(fsdp_config, "sync_module_states", True)
    fsdp_kwargs["limit_all_gathers"] = getattr(fsdp_config, "limit_all_gathers", True)
    return fsdp_kwargs


def _deepspeed_config_dict(deepspeed_config: Any) -> dict[str, Any]:
    """Load the DeepSpeed JSON config file, or build a ZeRO config from the config attributes."""
    if deepspeed_config.config_file is not None:
        import json

        with open(deepspeed_config.config_file, encoding="utf-8") as f:
            return json.load(f)

    zero_opt: dict[str, Any] = {
        "stage": deepspeed_config.zero_stage,
        "overlap_comm": getattr(deepspeed_config, "overlap_comm", True),
        "contiguous_gradients": getattr(deepspeed_config, "contiguous_gradients", True),
        "reduce_bucket_size": getattr(deepspeed_config, "reduce_bucket_size", 500_000_000),
    }

    if getattr(deepspeed_config, "offload_optimizer", False):
        zero_opt["offload_optimizer"] = {"device": "cpu", "pin_memory": True}

    if getattr(deepspeed_config, "offload_param", False):
        zero_opt["offload_param"] = {"device": "cpu", "pin_memory": True}

    if deepspeed_config.zero_stage == 3:
        zero_opt["stage3_prefetch_bucket_size"] = getattr(
            deepspeed_config,
            "stage3_prefetch_bucket_size",
            50_000_000,
        )
        zero_opt["stage3_param_persistence_threshold"] = getattr(
            deepspeed_config,
            "stage3_param_persistence_threshold",
            100_000,
        )

    return {
        "zero_optimization": zero_opt,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
    }

init_distributed()

Initialize distributed training from the environment variables RANK, WORLD_SIZE, and LOCAL_RANK.

Source code in xaytune/trainer/distributed.py
def init_distributed(*, backend: str | None = None) -> DistributedContext:
    """Initialize distributed training from environment variables.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` (defaulting to
    ``0``, ``1`` and ``0``). With a world size of 1 or less, nothing is
    initialized and a default :class:`DistributedContext` is returned.

    Args:
        backend: Process-group backend. ``None`` selects ``"nccl"`` when CUDA
            is available and ``"gloo"`` otherwise, so multi-process runs on
            CPU-only hosts no longer fail on the hard-coded NCCL backend.

    Returns:
        The distributed context for this process.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    if world_size <= 1:
        return DistributedContext()

    import torch.distributed as dist

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Bind this process to its GPU *before* creating the process group so
        # NCCL communicators and the initial barrier use this device instead
        # of defaulting to cuda:0.
        torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        dist.init_process_group(backend=backend or ("nccl" if use_cuda else "gloo"))

    return DistributedContext(rank=rank, world_size=world_size, local_rank=local_rank)

cleanup_distributed(ctx)

Destroy the process group if distributed training is active.

Source code in xaytune/trainer/distributed.py
def cleanup_distributed(ctx: DistributedContext) -> None:
    """Tear down the default process group, but only for distributed runs.

    Does nothing when *ctx* is not distributed or when no process group has
    been initialized.
    """
    if ctx.is_distributed:
        import torch.distributed as dist

        if dist.is_initialized():
            dist.destroy_process_group()

Device Utilities

get_device(local_rank=0, *, device_type=None)

Return a `torch.device` for the given rank and device type.

Source code in xaytune/trainer/device.py
def get_device(local_rank: int = 0, *, device_type: str | None = None) -> torch.device:
    """Return a :class:`torch.device` for the given rank and device type."""
    dt = device_type or get_device_type()
    if dt == "cuda":
        return torch.device(f"cuda:{local_rank}")
    return torch.device(dt)

get_device_type()

Detect the best available device type ("cuda", "mps", or "cpu").

Source code in xaytune/trainer/device.py
def get_device_type() -> str:
    """Report the preferred accelerator: ``"cuda"``, then ``"mps"``, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"

seed_all(seed)

Seed Python, PyTorch CPU, CUDA, and MPS random generators.

Source code in xaytune/trainer/device.py
def seed_all(seed: int) -> None:
    """Apply *seed* to Python's ``random`` and to PyTorch's CPU/CUDA/MPS generators.

    CUDA and MPS generators are seeded only when that backend is available.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.manual_seed(seed)