Trainer
The trainer module contains the training loop, checkpointing, scheduling, distributed strategies, and the LR finder.
Trainer
Trainer(config, callback_manager=None)
Core training loop with mixed precision, gradient accumulation, and callbacks.
Handles optimizer creation, learning rate scheduling, AMP autocast/scaler, gradient clipping, and checkpoint resume. Fires callback events at each lifecycle point so evaluation, logging, and checkpointing are pluggable.
Source code in xaytune/trainer/loop.py
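The control flow described above (gradient accumulation plus lifecycle callbacks) can be illustrated without PyTorch. This is a framework-free toy, not the Trainer's implementation: the `ToyCallbackManager` class, the event names `on_train_begin`, `on_step_end`, and `on_train_end`, and the `accumulation_steps` parameter are hypothetical, and "gradients" are plain floats.

```python
from typing import Callable


class ToyCallbackManager:
    """Hypothetical dispatcher that maps event names to lists of handlers."""

    def __init__(self) -> None:
        self.handlers: dict[str, list[Callable[..., None]]] = {}

    def on(self, event: str, handler: Callable[..., None]) -> None:
        self.handlers.setdefault(event, []).append(handler)

    def fire(self, event: str, **kwargs: object) -> None:
        for handler in self.handlers.get(event, []):
            handler(**kwargs)


def toy_train(losses: list[float], accumulation_steps: int, callbacks: ToyCallbackManager) -> int:
    """Accumulate scaled per-micro-batch losses and take one 'optimizer step'
    every `accumulation_steps` micro-batches. Returns the optimizer step count."""
    callbacks.fire("on_train_begin")
    global_step = 0
    accumulated = 0.0
    for i, loss in enumerate(losses, start=1):
        # Scale each micro-batch so the accumulated sum equals the window mean.
        accumulated += loss / accumulation_steps
        if i % accumulation_steps == 0:
            global_step += 1  # stand-in for optimizer.step(); optimizer.zero_grad()
            callbacks.fire("on_step_end", step=global_step, loss=accumulated)
            accumulated = 0.0
    # Trailing micro-batches that don't fill a full window are dropped in this toy.
    callbacks.fire("on_train_end", step=global_step)
    return global_step


# Example: five micro-batches with accumulation over two gives two optimizer steps.
events: list[tuple[int, float]] = []
manager = ToyCallbackManager()
manager.on("on_step_end", lambda step, loss: events.append((step, loss)))
steps = toy_train([1.0, 3.0, 2.0, 4.0, 5.0], accumulation_steps=2, callbacks=manager)
```

Keeping evaluation, logging, and checkpointing behind events like these means the loop itself never needs to know which of them are enabled.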
Checkpointing
save_checkpoint(*, output_dir, model, optimizer, state, scheduler=None, scaler=None)
Save model, optimizer, scheduler, and scaler state to output_dir.
Writes model.pt and optimizer.pt, the optional scheduler.pt and scaler.pt, and a metadata.json with the step, epoch, and metrics.
Source code in xaytune/trainer/checkpointing.py
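A minimal standard-library sketch of the metadata.json part of a checkpoint. The key names (`global_step`, `epoch`, `metrics`), the `checkpoint-500` directory name, the helper names, and the write-to-temp-then-rename pattern are all assumptions, not necessarily what save_checkpoint does; the .pt files would be written separately (for example with torch.save).

```python
import json
import os
import tempfile
from pathlib import Path


def write_metadata(output_dir: str, *, global_step: int, epoch: int, metrics: dict) -> Path:
    """Hypothetical helper: write metadata.json so that a crash mid-write
    never leaves a half-written file behind."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    payload = {"global_step": global_step, "epoch": epoch, "metrics": metrics}
    # Write to a temp file in the same directory, then atomically rename it.
    fd, tmp_path = tempfile.mkstemp(dir=str(out), suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f, indent=2)
    target = out / "metadata.json"
    os.replace(tmp_path, target)
    return target


def read_metadata(checkpoint_dir: str) -> dict:
    """Hypothetical helper: load metadata.json from a checkpoint directory."""
    with open(Path(checkpoint_dir) / "metadata.json") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as root:
    ckpt = os.path.join(root, "checkpoint-500")
    write_metadata(ckpt, global_step=500, epoch=1, metrics={"eval_loss": 0.42})
    meta = read_metadata(ckpt)
```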
load_checkpoint(*, checkpoint_dir, model, optimizer, scheduler=None, scaler=None)
Restore model, optimizer, and training state from a checkpoint directory.
Returns a TrainState (xaytune.trainer.callbacks.TrainState) with the saved step, epoch, and metrics so training can resume.
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If checkpoint_dir doesn't exist. |
Source code in xaytune/trainer/checkpointing.py
find_latest_checkpoint(output_dir)
Find the checkpoint with the highest global_step in output_dir.
Source code in xaytune/trainer/checkpointing.py
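One way to implement a highest-global_step search, sketched with the standard library. It assumes each checkpoint is a subdirectory of output_dir containing a metadata.json with a `global_step` key; that layout, the key name, and the `find_latest_checkpoint_sketch` name are assumptions, and the real function may parse directory names instead.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def find_latest_checkpoint_sketch(output_dir: str) -> Optional[Path]:
    """Return the subdirectory whose metadata.json has the largest
    global_step, or None if no readable checkpoint exists."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_dir = -1, None
    for meta_path in root.glob("*/metadata.json"):
        try:
            step = int(json.loads(meta_path.read_text())["global_step"])
        except (OSError, ValueError, KeyError, TypeError):
            continue  # skip partial or malformed checkpoints
        if step > best_step:
            best_step, best_dir = step, meta_path.parent
    return best_dir


with tempfile.TemporaryDirectory() as root:
    for step in (100, 250, 1000):
        ckpt = Path(root) / f"checkpoint-{step}"
        ckpt.mkdir()
        (ckpt / "metadata.json").write_text(json.dumps({"global_step": step}))
    latest_name = find_latest_checkpoint_sketch(root).name
    missing = find_latest_checkpoint_sketch(str(Path(root) / "nope"))
```

Comparing the stored step as an integer avoids the lexical trap where "checkpoint-250" sorts after "checkpoint-1000".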
AsyncCheckpointSaver()
Scheduling
create_scheduler(optimizer, scheduler_type, total_steps, warmup_steps)
Create an LR scheduler with optional linear warmup.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Any | The optimizer to schedule. | required |
| scheduler_type | str | Name of the schedule to use; unrecognized names raise ValueError. | required |
| total_steps | int | Total training steps (used for the decay calculation). | required |
| warmup_steps | int | Number of linear warmup steps. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If scheduler_type is not recognized. |
Source code in xaytune/trainer/scheduler.py
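A warmup-then-decay schedule is commonly expressed as a multiplier on the base LR, which torch.optim.lr_scheduler.LambdaLR can consume. The pure-Python sketch below assumes the scheduler names "constant", "linear", and "cosine"; those names and the `lr_multiplier` helper are assumptions, not the documented set accepted by scheduler_type.

```python
import math

_KNOWN_SCHEDULES = ("constant", "linear", "cosine")  # assumed names


def lr_multiplier(step: int, *, scheduler_type: str, total_steps: int, warmup_steps: int) -> float:
    """Multiplier applied to the base LR at `step` (0-indexed)."""
    if scheduler_type not in _KNOWN_SCHEDULES:
        raise ValueError(f"Unknown scheduler_type: {scheduler_type!r}")
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from 0 toward 1.0.
        return step / warmup_steps
    # Fraction of the post-warmup decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    if scheduler_type == "linear":
        return 1.0 - progress
    if scheduler_type == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return 1.0  # "constant"
```

With PyTorch available, this could be wired up as `LambdaLR(optimizer, lr_lambda=lambda s: lr_multiplier(s, scheduler_type="cosine", total_steps=1000, warmup_steps=100))`.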
resolve_warmup_steps(warmup_steps, warmup_ratio, total_steps)
Return the effective warmup step count from either an absolute count or ratio.
Source code in xaytune/trainer/scheduler.py
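A sketch of how an absolute count and a ratio might be reconciled. The precedence rule (a positive warmup_steps wins over warmup_ratio), the range check, and rounding to the nearest integer are assumptions; other code bases use ceil or floor, and the source is the authority on the real behavior.

```python
def resolve_warmup_steps_sketch(warmup_steps: int, warmup_ratio: float, total_steps: int) -> int:
    """Return an explicit warmup count if given, else derive one from the ratio."""
    if warmup_steps > 0:
        return warmup_steps  # explicit count takes precedence (assumed)
    if not 0.0 <= warmup_ratio <= 1.0:
        raise ValueError(f"warmup_ratio must be in [0, 1], got {warmup_ratio}")
    # round() absorbs float noise such as 0.07 * 100 == 7.000000000000001.
    return round(warmup_ratio * total_steps)
```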
LR Finder
lr_find(model, train_dataloader, *, start_lr=1e-07, end_lr=1.0, num_iterations=100, smoothing_factor=0.05, divergence_threshold=4.0, loss_fn=None)
Run an LR range test to suggest a good learning rate.
Trains with exponentially increasing LR from start_lr to end_lr, tracking loss. Stops early if loss diverges. Model weights are restored to their original state afterward.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | The model to test. | required |
| train_dataloader | Any | Training data loader. | required |
| start_lr | float | Starting learning rate. | 1e-07 |
| end_lr | float | Maximum learning rate to test. | 1.0 |
| num_iterations | int | Number of training iterations. | 100 |
| smoothing_factor | float | Exponential smoothing factor for the loss curve. | 0.05 |
| divergence_threshold | float | Stop when the smoothed loss exceeds this multiple of the best smoothed loss. | 4.0 |
| loss_fn | Any \| None | Optional custom loss function. | None |
Returns:
| Type | Description |
|---|---|
| LRFinderResult | The tested learning rates, their losses, and the suggested LR. |
Source code in xaytune/trainer/lr_finder.py
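The range test can be sketched in pure Python against a synthetic loss, which makes the LR spacing, smoothing, and early-stop rule explicit. This is an illustration under assumptions: `loss_at` is a hypothetical stand-in for one training step, the smoothing update (an exponential moving average weighted by smoothing_factor) and geometric LR spacing are assumed, and the real lr_find also trains and then restores the model weights.

```python
import math
from typing import Callable


def lr_range_test_sketch(
    loss_at: Callable[[float], float],
    *,
    start_lr: float = 1e-7,
    end_lr: float = 1.0,
    num_iterations: int = 100,
    smoothing_factor: float = 0.05,
    divergence_threshold: float = 4.0,
) -> tuple[list[float], list[float]]:
    lrs: list[float] = []
    losses: list[float] = []
    ratio = end_lr / start_lr
    smoothed = None
    best = math.inf
    for i in range(num_iterations):
        # Geometric spacing from start_lr (i = 0) to end_lr (i = num_iterations - 1).
        lr = start_lr * ratio ** (i / max(1, num_iterations - 1))
        loss = loss_at(lr)
        smoothed = loss if smoothed is None else smoothing_factor * loss + (1 - smoothing_factor) * smoothed
        lrs.append(lr)
        losses.append(loss)
        best = min(best, smoothed)
        if smoothed > divergence_threshold * best:
            break  # smoothed loss has diverged; stop early
    return lrs, losses


# A toy loss that explodes once the LR reaches 1e-2 triggers the early stop.
lrs, losses = lr_range_test_sketch(lambda lr: 1.0 if lr < 1e-2 else 1000.0)
# A flat loss never diverges, so the full sweep runs.
full_lrs, _ = lr_range_test_sketch(lambda lr: 1.0)
```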
LRFinderResult(lrs, losses, suggested_lr) (dataclass)
Result of an LR range test.
Attributes:
| Name | Type | Description |
|---|---|---|
| lrs | list[float] | Learning rates tested. |
| losses | list[float] | Raw loss values at each LR. |
| suggested_lr | float \| None | Recommended LR (steepest-descent point). |
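A common way to pick a steepest-descent point, shown here as an assumption rather than xaytune's exact rule, is to take the LR where the loss falls fastest with respect to log(LR), using finite differences between consecutive points. The helper name is hypothetical, and returning the start of the steepest segment is a design choice.

```python
import math
from typing import Optional


def steepest_descent_lr(lrs: list[float], losses: list[float]) -> Optional[float]:
    """Return the LR at the start of the segment where d(loss)/d(log lr) is
    most negative, or None if there are too few points or loss never falls."""
    if len(lrs) < 2 or len(lrs) != len(losses):
        return None
    best_slope, best_lr = 0.0, None
    for i in range(1, len(lrs)):
        dlog = math.log(lrs[i]) - math.log(lrs[i - 1])
        if dlog == 0:
            continue  # duplicate LR values carry no slope information
        slope = (losses[i] - losses[i - 1]) / dlog
        if slope < best_slope:
            best_slope, best_lr = slope, lrs[i - 1]
    return best_lr
```

Returning None when the loss never decreases mirrors the `float | None` type of suggested_lr.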
Distributed Training
DistributedContext(rank=0, world_size=1, local_rank=0) (dataclass)
Process-level distributed training state (rank, world size, device).
get_strategy(strategy, world_size=1)
Resolve "auto" strategy to "fsdp" (multi-GPU) or "none" (single).
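A sketch of the "auto" resolution rule. Treating world_size > 1 as the multi-GPU case, passing any explicit strategy through unchanged, and the lowercase strategy strings are assumptions; the function name is hypothetical.

```python
def get_strategy_sketch(strategy: str, world_size: int = 1) -> str:
    """Resolve "auto" to "fsdp" for multi-process runs, "none" otherwise."""
    if strategy == "auto":
        return "fsdp" if world_size > 1 else "none"
    return strategy  # explicit choices such as "ddp" pass through unchanged (assumed)
```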
wrap_model_distributed(model, *, strategy, ctx, fsdp_config=None, deepspeed_config=None, mixed_precision='bf16', **kwargs)
Wrap a model with the chosen distributed strategy (DDP, FSDP, or DeepSpeed).
Source code in xaytune/trainer/distributed.py
init_distributed()
Initialize distributed training from environment variables (RANK, WORLD_SIZE).
Source code in xaytune/trainer/distributed.py
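The environment-variable step of initialization can be sketched as below. RANK and WORLD_SIZE come from the description above, and LOCAL_RANK is also set by torchrun. `ContextSketch` mirrors DistributedContext's fields but is hypothetical, as are its `is_distributed` property and the defaults used when variables are unset; the real function presumably also creates the process group (for example via torch.distributed.init_process_group) when world_size > 1.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class ContextSketch:
    rank: int = 0
    world_size: int = 1
    local_rank: int = 0

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1


def context_from_env(env: Optional[Mapping[str, str]] = None) -> ContextSketch:
    """Read rank info from the environment; unset variables fall back to a
    single-process default (assumed behavior)."""
    env = os.environ if env is None else env
    return ContextSketch(
        rank=int(env.get("RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
    )
```

Accepting an explicit mapping keeps the parsing testable without mutating os.environ.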
cleanup_distributed(ctx)
Destroy the process group if distributed training is active.
Source code in xaytune/trainer/distributed.py
Device Utilities
get_device(local_rank=0, *, device_type=None)
Return a torch.device for the given rank and device type.
Source code in xaytune/trainer/device.py
get_device_type()
Detect the best available device type ("cuda", "mps", or "cpu").
Source code in xaytune/trainer/device.py
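A sketch of the "cuda", then "mps", then "cpu" preference order using PyTorch's availability checks. The function name is hypothetical, and returning "cpu" when PyTorch is not installed is a convenience of this sketch only.

```python
def detect_device_type() -> str:
    """Prefer CUDA, then Apple MPS, then CPU."""
    try:
        import torch
    except ImportError:
        return "cpu"  # sketch-only fallback when PyTorch is absent
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch releases.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```

A per-rank device can then be built as `torch.device("cuda", local_rank)` on CUDA, or `torch.device(device_type)` otherwise.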
seed_all(seed)
Seed Python, PyTorch CPU, CUDA, and MPS random generators.
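A sketch of seeding the generators listed above. The function name is hypothetical, and skipping the PyTorch calls when PyTorch is not installed is a convenience of this sketch. In recent PyTorch releases torch.manual_seed already seeds all devices, so the explicit CUDA and MPS calls are redundant there but harmless.

```python
import random

try:
    import torch
except ImportError:
    torch = None  # sketch-only: seed Python's RNG alone without PyTorch


def seed_all_sketch(seed: int) -> None:
    """Seed Python's RNG and, if available, PyTorch's CPU, CUDA, and MPS RNGs."""
    random.seed(seed)
    if torch is None:
        return
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and hasattr(torch, "mps"):
        torch.mps.manual_seed(seed)
```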