Extensibility¶
xaytune is designed to be extended. You can bring your own models, training logic, recipes, and data formats without modifying the library itself.
Custom Training Step¶
Subclass Trainer and override training_step() to implement custom forward/backward logic while keeping the epoch loop, callbacks, checkpointing, and scheduling:
```python
import torch
from xaytune.trainer import Trainer
from xaytune.config.schema import TrainerConfig


class DistillTrainer(Trainer):
    def __init__(self, config, teacher_model, **kwargs):
        super().__init__(config, **kwargs)
        # The teacher is frozen: eval mode disables dropout and similar layers.
        # It is assumed to already sit on the same device as the student.
        self.teacher = teacher_model
        self.teacher.eval()

    def training_step(self, model, batch, optimizer, state):
        # Custom forward pass: student with gradients, teacher without
        student_out = model(**batch)
        with torch.no_grad():
            teacher_out = self.teacher(**batch)

        # Custom loss: pull the student's logits toward the teacher's
        loss = torch.nn.functional.mse_loss(
            student_out.logits, teacher_out.logits
        )

        # Backward + optimizer step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item()


# Use it (teacher, student and dataloader are assumed to be defined elsewhere)
config = TrainerConfig(num_epochs=3, learning_rate=1e-4)
trainer = DistillTrainer(config=config, teacher_model=teacher)
state = trainer.train(model=student, train_dataloader=dataloader)
```
The base train() method handles epoch iteration, callback firing (train_start, step_end, epoch_end, etc.), early stopping, and the max_steps limit, so your override only needs to implement the logic for a single step.
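To make that division of responsibilities concrete, here is a toy, dependency-free sketch of a base class that owns the loop and delegates each batch to training_step(). It is not xaytune's implementation: ToyTrainer, TrainState, SumTrainer and the train_end event are hypothetical names, callbacks are plain functions keyed by event name, and early stopping is modeled as a callback setting a should_stop flag.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional


@dataclass
class TrainState:
    global_step: int = 0
    epoch: int = 0
    losses: List[float] = field(default_factory=list)
    should_stop: bool = False  # a callback sets this to request early stopping


Callback = Callable[[TrainState], None]


class ToyTrainer:
    """Owns the loop; subclasses only implement training_step()."""

    def __init__(self, num_epochs: int, max_steps: Optional[int] = None,
                 callbacks: Optional[Dict[str, List[Callback]]] = None) -> None:
        self.num_epochs = num_epochs
        self.max_steps = max_steps
        self.callbacks = callbacks or {}

    def fire(self, event: str, state: TrainState) -> None:
        for callback in self.callbacks.get(event, []):
            callback(state)

    def training_step(self, model: Any, batch: Any, optimizer: Any,
                      state: TrainState) -> float:
        raise NotImplementedError  # subclasses supply the per-batch logic

    def _done(self, state: TrainState) -> bool:
        hit_limit = self.max_steps is not None and state.global_step >= self.max_steps
        return state.should_stop or hit_limit

    def train(self, model: Any, batches: Iterable[Any], optimizer: Any = None) -> TrainState:
        state = TrainState()
        batches = list(batches)  # materialized so it can be re-iterated each epoch
        self.fire("train_start", state)
        for epoch in range(self.num_epochs):
            state.epoch = epoch
            for batch in batches:
                state.losses.append(self.training_step(model, batch, optimizer, state))
                state.global_step += 1
                self.fire("step_end", state)
                if self._done(state):
                    self.fire("train_end", state)
                    return state
            self.fire("epoch_end", state)
        self.fire("train_end", state)
        return state


class SumTrainer(ToyTrainer):
    # The only custom piece: one step's "loss" (here just the batch sum)
    def training_step(self, model, batch, optimizer, state):
        return float(sum(batch))


state = SumTrainer(num_epochs=3, max_steps=4).train(model=None, batches=[[1, 2], [3]])
print(state.global_step, state.losses)
```

The subclass never sees epochs, callbacks or stopping conditions; it only turns one batch into a loss, which mirrors the contract described for training_step().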
You can also override move_batch_to_device() if your batches need custom device handling.
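A common reason to override device handling is a batch with nested structure (dicts of lists, tuples, non-tensor metadata). The helper below is a minimal sketch, assuming that anything with a .to() method (such as torch.Tensor) should be moved and everything else left alone; move_to_device is a hypothetical name, not part of xaytune.

```python
from typing import Any


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively call .to(device) on every leaf that supports it.

    Dicts, lists, tuples and namedtuples are rebuilt with their leaves moved;
    leaves without a .to() method (ints, strings, None) pass through as-is.
    """
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(move_to_device(v, device) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    if hasattr(obj, "to"):
        return obj.to(device)  # e.g. torch.Tensor.to(device)
    return obj
```

Inside a Trainer subclass, an override could then simply return move_to_device(batch, device). The exact signature of move_batch_to_device() is not given here, so check the base class before relying on this shape.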
Custom Models¶
Passing a pre-built model¶
All recipe functions (finetune, pretrain, align) accept a model instance directly:
```python
import xaytune

state = xaytune.finetune(
    model=my_custom_model,   # any nn.Module
    tokenizer=my_tokenizer,  # required with raw models
    dataset="data/train.jsonl",
    format="alpaca",
    max_steps=100,
)
```
You can also pass a ModelResult for more control:
```python
from xaytune.models import ModelResult

model_result = ModelResult(
    model=my_model,
    tokenizer=my_tokenizer,
    name="my-transformer",
    metadata={"custom_key": "value"},
)

state = xaytune.finetune(
    model=model_result,
    dataset="data/train.jsonl",
)
```
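The two call styles above imply a simple contract: a bundled result already carries its tokenizer, while a raw module needs one passed alongside it. The sketch below illustrates that contract in plain Python. ModelBundle is a hypothetical stand-in for ModelResult, resolve_model is an invented helper, and how xaytune actually resolves a bundle passed together with a separate tokenizer is an assumption (here the bundle simply wins).

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ModelBundle:  # hypothetical stand-in for xaytune's ModelResult
    model: Any
    tokenizer: Any
    name: str = "custom-model"
    metadata: Dict[str, Any] = field(default_factory=dict)


def resolve_model(model: Any, tokenizer: Optional[Any] = None) -> ModelBundle:
    """Accept either a prepared bundle or a raw model plus its tokenizer."""
    if isinstance(model, ModelBundle):
        return model  # already carries tokenizer, name and metadata
    if tokenizer is None:
        raise ValueError("a tokenizer is required when passing a raw model")
    return ModelBundle(model=model, tokenizer=tokenizer)
```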
Registering a model loader¶
Register a custom model loader so it works with YAML config files:
```python
from xaytune.models import register_model, ModelResult


@register_model("my-transformer")
def load_my_transformer(name_or_path, *, dtype="auto", **kwargs):
    # MyTransformerArchitecture and MyTokenizer stand in for your own classes.
    # This example ignores dtype; a real loader would typically apply it.
    model = MyTransformerArchitecture(hidden_size=768, num_layers=12)
    tokenizer = MyTokenizer.from_pretrained("my-tokenizer")
    return ModelResult(model=model, tokenizer=tokenizer, name=name_or_path)
```
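Then reference it in your config by setting name: my-transformer under the model section, the same field that otherwise holds a Hugging Face model ID.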
When load_model("my-transformer") is called, it checks the model registry first and falls back to HuggingFace AutoModelForCausalLM if no match is found.
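The lookup order (registry first, then Hugging Face) can be sketched in a few lines of plain Python. This is a hypothetical reimplementation for illustration only: MODEL_REGISTRY, the fallback parameter and hf_fallback are invented names, and the fallback returns a tuple instead of calling AutoModelForCausalLM.from_pretrained().

```python
from typing import Any, Callable, Dict

# Hypothetical module-level registry: name -> loader function
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_model(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records a loader under name and returns it unchanged."""
    def decorator(loader: Callable[..., Any]) -> Callable[..., Any]:
        MODEL_REGISTRY[name] = loader
        return loader
    return decorator


def load_model(name_or_path: str, fallback: Callable[..., Any], **kwargs: Any) -> Any:
    """Use a registered loader if the name matches exactly, otherwise fallback."""
    loader = MODEL_REGISTRY.get(name_or_path)
    if loader is not None:
        return loader(name_or_path, **kwargs)
    return fallback(name_or_path, **kwargs)


@register_model("my-transformer")
def load_my_transformer(name_or_path: str, *, dtype: str = "auto", **kwargs: Any) -> Any:
    return ("custom", name_or_path, dtype)  # stand-in for a ModelResult


def hf_fallback(name_or_path: str, **kwargs: Any) -> Any:
    # A real fallback would call AutoModelForCausalLM.from_pretrained(name_or_path)
    return ("hf", name_or_path)


print(load_model("my-transformer", hf_fallback))  # registry hit
print(load_model("gpt2", hf_fallback))            # falls through to the fallback
```

Because the decorator returns the loader unchanged, the function stays directly callable and testable on its own.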
Custom Recipes¶
Register a recipe function with the recipe registry:
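```python
from xaytune.recipes import recipe_registry
from xaytune.recipes.base import setup_training


@recipe_registry.register("distill")
def distill(*, config, resume_from=None):
    # setup_training() provides the model, dataloader, trainer and resume state used below
    components = setup_training(config, resume_from=resume_from)

    # load_teacher_model() is a placeholder for your own loading code;
    # DistillTrainer is the subclass from "Custom Training Step" above
    teacher = load_teacher_model()
    trainer = DistillTrainer(
        config=config.trainer,
        teacher_model=teacher,
        callback_manager=components.trainer.callback_manager,  # reuse configured callbacks
    )
    return trainer.train(
        model=components.model,
        train_dataloader=components.train_dataloader,
        resume_state=components.resume_state,
    )
```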
Use it with the CLI:
```yaml
# distill.yaml
recipe: distill
model:
  name: meta-llama/Meta-Llama-3-8B
data:
  path: data/train.jsonl
  format: alpaca
trainer:
  num_epochs: 3
```
Custom recipe and method names are accepted by the config system — you are not limited to the built-in finetune, pretrain, and align.
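Conceptually, accepting custom recipe names just means dispatching on the recipe key of the parsed config through a name-to-function table. The sketch below assumes the YAML has already been parsed into a dict (for example with yaml.safe_load); Registry, run_from_config and the two toy recipes are hypothetical and only illustrate the dispatch, not xaytune's internals.

```python
from typing import Any, Callable, Dict


class Registry:
    """Minimal name -> callable registry (hypothetical, for illustration)."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._items: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._items[name] = fn
            return fn
        return decorator

    def get(self, name: str) -> Callable[..., Any]:
        if name not in self._items:
            known = ", ".join(sorted(self._items))
            raise KeyError(f"unknown {self.kind} {name!r}; registered: {known}")
        return self._items[name]


recipes = Registry("recipe")


@recipes.register("finetune")
def finetune(*, config: Dict[str, Any], resume_from: Any = None) -> str:
    return f"finetune x{config['trainer']['num_epochs']}"


@recipes.register("distill")  # a custom name is just another key
def distill(*, config: Dict[str, Any], resume_from: Any = None) -> str:
    return f"distill x{config['trainer']['num_epochs']}"


def run_from_config(config: Dict[str, Any], resume_from: Any = None) -> Any:
    """Dispatch on the top-level recipe key of an already-parsed config."""
    recipe_fn = recipes.get(config["recipe"])
    return recipe_fn(config=config, resume_from=resume_from)


print(run_from_config({"recipe": "distill", "trainer": {"num_epochs": 3}}))
```

Listing the registered names in the error message makes a typo in the recipe field easy to spot.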
Custom Data Formats¶
Register a data format function to handle custom dataset structures:
```python
from xaytune.data import format_registry


@format_registry.register("my_format")
def format_my_data(sample):
    return {"text": f"Question: {sample['q']}\nAnswer: {sample['a']}"}
```
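Before wiring a format into a config, it can help to run the function over a few raw records. The sketch below is a standalone check that does not use xaytune: it repeats the body of format_my_data without the decorator, and format_jsonl is a hypothetical helper that reads JSONL lines and reports which line failed. How xaytune itself reads and batches the file is not shown here.

```python
import io
import json
from typing import Any, Callable, Dict, Iterable, Iterator

Record = Dict[str, Any]


def format_my_data(sample: Record) -> Record:
    # Same body as the registered format function above
    return {"text": f"Question: {sample['q']}\nAnswer: {sample['a']}"}


def format_jsonl(lines: Iterable[str], fmt: Callable[[Record], Record]) -> Iterator[Record]:
    """Parse each non-blank JSONL line and pass the record through fmt."""
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines
        try:
            record = fmt(json.loads(line))
        except (KeyError, json.JSONDecodeError) as exc:
            raise ValueError(f"line {line_no}: cannot format record ({exc})") from exc
        yield record


sample_file = io.StringIO(
    '{"q": "What is 2 + 2?", "a": "4"}\n'
    "\n"
    '{"q": "Capital of France?", "a": "Paris"}\n'
)
for record in format_jsonl(sample_file, format_my_data):
    print(record["text"])
```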
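Then use it in your config by setting format: my_format under the data section, just as the distill.yaml example sets format: alpaca.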
Plugin Discovery¶
Third-party packages can register recipes, models, formats, and metrics automatically using Python entry points. xaytune discovers them at import time.
In your package's pyproject.toml:
```toml
[project.entry-points."xaytune.recipes"]
distill = "my_package.recipes:distill"

[project.entry-points."xaytune.models"]
my-transformer = "my_package.models:load_my_transformer"

[project.entry-points."xaytune.formats"]
my_format = "my_package.formats:format_my_data"

[project.entry-points."xaytune.metrics"]
my-metric = "my_package.metrics:my_metric_fn"
```
Once the package is installed, its extensions appear in xaytune automatically — no imports or registration code needed. The xaytune list command shows all registered items including plugins.
Available Registries¶
| Registry | Entry Point Group | List Command |
|---|---|---|
| Recipes | xaytune.recipes | xaytune list recipes |
| Models | xaytune.models | xaytune list models |
| Data Formats | xaytune.formats | xaytune list formats |
| Metrics | xaytune.metrics | xaytune list metrics |
| Rewards | (none) | xaytune list rewards |