Skip to content

Recipes

xaytune's three top-level recipe functions are the primary entry points for training. Each recipe calls setup_training() internally, builds a Trainer, and runs the loop.

import xaytune

# All recipe arguments are keyword-only; `config` must be a TrainConfig instance, not a file path.
state = xaytune.finetune(config=config)
state = xaytune.pretrain(config=config)
state = xaytune.align(config=config)

finetune

finetune(*, config=None, model=None, tokenizer=None, dataset=None, method='full', format='alpaca', num_epochs=3, learning_rate=0.0002, batch_size=4, resume_from=None, **kwargs)

Fine-tune a pretrained language model on a supervised dataset.

Accepts either a fully specified TrainConfig or individual arguments for quick one-liner usage. Extra **kwargs that match TrainerConfig fields (e.g. max_steps, mixed_precision) are forwarded automatically.

Parameters:

Name Type Description Default
config TrainConfig | None

Complete training configuration. When provided, all other arguments except resume_from are ignored.

None
model Any | None

HuggingFace model name, local path, or a pre-built nn.Module / ModelResult. When passing a raw module, tokenizer must also be provided.

None
tokenizer Any | None

Tokenizer instance — required when model is a raw nn.Module, ignored when model is a string or None.

None
dataset str | None

Path to a JSONL training file or HuggingFace dataset name.

None
method str

Fine-tuning method — "full", "lora", or "qlora".

'full'
format str

Data format — "alpaca", "sharegpt", "chat", or "text".

'alpaca'
num_epochs int

Number of training epochs.

3
learning_rate float

Peak learning rate.

0.0002
batch_size int

Per-device batch size.

4
resume_from str | None

Path to a checkpoint directory to resume from.

None
**kwargs Any

Additional TrainerConfig fields (max_steps, mixed_precision, scheduler, warmup_steps, etc.).

{}

Returns:

Type Description
TrainState

Final training state with loss, global step count, and other metrics.

Raises:

Type Description
ValueError

If neither config nor both model and dataset are provided.

Example::

state = xaytune.finetune(
    model="meta-llama/Llama-3-8B",
    dataset="data/train.jsonl",
    method="lora",
    num_epochs=3,
    max_steps=100,
)
print(f"Final loss: {state.metrics['loss']:.4f}")
Source code in xaytune/recipes/finetune.py
def finetune(
    *,
    config: TrainConfig | None = None,
    model: Any | None = None,
    tokenizer: Any | None = None,
    dataset: str | None = None,
    method: str = "full",
    format: str = "alpaca",
    num_epochs: int = 3,
    learning_rate: float = 2e-4,
    batch_size: int = 4,
    resume_from: str | None = None,
    **kwargs: Any,
) -> TrainState:
    """Fine-tune a pretrained language model on a supervised dataset.

    Accepts either a fully specified ``TrainConfig`` or individual arguments
    for quick one-liner usage.  Extra ``**kwargs`` that match
    ``TrainerConfig`` fields (e.g. ``max_steps``, ``mixed_precision``) are
    forwarded automatically; any other extra keyword is ignored with a
    warning.

    Args:
        config: Complete training configuration. When provided, all other
            arguments except ``resume_from`` are ignored.
        model: HuggingFace model name, local path, or a pre-built
            ``nn.Module`` / ``ModelResult``.  When passing a raw module,
            ``tokenizer`` must also be provided.
        tokenizer: Tokenizer instance — required when ``model`` is a raw
            ``nn.Module``, ignored when ``model`` is a string or ``None``.
        dataset: Path to a JSONL training file or HuggingFace dataset name.
        method: Fine-tuning method — ``"full"``, ``"lora"``, or ``"qlora"``.
        format: Data format — ``"alpaca"``, ``"sharegpt"``, ``"chat"``, or ``"text"``.
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate.
        batch_size: Per-device batch size.
        resume_from: Path to a checkpoint directory to resume from.
        **kwargs: Additional ``TrainerConfig`` fields (``max_steps``,
            ``mixed_precision``, ``scheduler``, ``warmup_steps``, etc.).

    Returns:
        Final training state with loss, global step count, and other metrics.

    Raises:
        ValueError: If neither ``config`` nor both ``model`` and ``dataset``
            are provided.

    Example::

        state = xaytune.finetune(
            model="meta-llama/Llama-3-8B",
            dataset="data/train.jsonl",
            method="lora",
            num_epochs=3,
            max_steps=100,
        )
        print(f"Final loss: {state.metrics['loss']:.4f}")
    """
    injected_model = None
    if config is None:
        # The quick path needs both a model and a dataset to build a config.
        if model is None or dataset is None:
            raise ValueError("Either 'config' or both 'model' and 'dataset' are required.")

        if isinstance(model, str):
            model_name = model
        else:
            # Pre-built objects are handed to setup_training() directly; the
            # config only records a placeholder name for them.
            model_name = "custom"
            injected_model = model

        # Move every keyword that names a TrainerConfig field out of kwargs.
        trainer_param_names = set(TrainerConfig.model_fields)
        trainer_fields = {
            key: kwargs.pop(key) for key in list(kwargs) if key in trainer_param_names
        }

        # Anything left matched no known field.  It used to be dropped
        # silently, which hid typos such as ``max_step=100``.
        if kwargs:
            import warnings

            warnings.warn(
                "finetune() ignored unrecognized keyword argument(s): "
                f"{', '.join(sorted(kwargs))}",
                stacklevel=2,
            )

        config = TrainConfig(
            recipe="finetune",
            method=method,
            model=ModelConfig(name=model_name),
            data=DataConfig(path=dataset, format=format),
            trainer=TrainerConfig(
                num_epochs=num_epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                **trainer_fields,
            ),
        )

    components = _base.setup_training(
        config,
        resume_from=resume_from,
        model=injected_model,
        tokenizer=tokenizer,
    )

    return components.trainer.train(
        model=components.model,
        train_dataloader=components.train_dataloader,
        resume_state=components.resume_state,
        resume_checkpoint_dir=resume_from,
    )

pretrain

pretrain(*, config=None, model=None, tokenizer=None, dataset=None, format='text', num_epochs=1, learning_rate=0.0003, batch_size=4, resume_from=None, **kwargs)

Pre-train a language model on raw text with a causal LM objective.

Accepts either a fully specified TrainConfig or individual arguments for quick one-liner usage. Extra **kwargs that match TrainerConfig fields are forwarded automatically.

Parameters:

Name Type Description Default
config TrainConfig | None

Complete training configuration. When provided, all other arguments except resume_from are ignored.

None
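model Any | None

HuggingFace model name, local path, or a pre-built nn.Module / ModelResult. When passing a raw module, tokenizer must also be provided.

None
tokenizer Any | None

Tokenizer instance — required when model is a raw nn.Module, ignored when model is a string or None.

None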
dataset str | None

Path to a JSONL corpus file (each line: {"text": "..."}) or a HuggingFace dataset name.

None
format str

Data format — typically "text" for pre-training.

'text'
num_epochs int

Number of training epochs.

1
learning_rate float

Peak learning rate.

0.0003
batch_size int

Per-device batch size.

4
resume_from str | None

Path to a checkpoint directory to resume from.

None
**kwargs Any

Additional TrainerConfig fields (max_steps, mixed_precision, scheduler, etc.).

{}

Returns:

Type Description
TrainState

Final training state with loss, global step count, and other metrics.

Raises:

Type Description
ValueError

If neither config nor both model and dataset are provided.

Example::

state = xaytune.pretrain(
    model="gpt2",
    dataset="data/corpus.jsonl",
    num_epochs=1,
    max_steps=1000,
)
Source code in xaytune/recipes/pretrain.py
def pretrain(
    *,
    config: TrainConfig | None = None,
    model: Any | None = None,
    tokenizer: Any | None = None,
    dataset: str | None = None,
    format: str = "text",
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    batch_size: int = 4,
    resume_from: str | None = None,
    **kwargs: Any,
) -> TrainState:
    """Pre-train a language model on raw text with a causal LM objective.

    Accepts either a fully specified ``TrainConfig`` or individual arguments
    for quick one-liner usage.  Extra ``**kwargs`` that match
    ``TrainerConfig`` fields are forwarded automatically; any other extra
    keyword is ignored with a warning.

    Args:
        config: Complete training configuration. When provided, all other
            arguments except ``resume_from`` are ignored.
        model: HuggingFace model name or local path.  A non-string object is
            passed through to ``setup_training()`` as a pre-built model, in
            which case the config records the name ``"custom"``.
        tokenizer: Tokenizer forwarded to ``setup_training()`` together with
            a pre-built ``model``.
        dataset: Path to a JSONL corpus file (each line: ``{"text": "..."}``)
            or a HuggingFace dataset name.
        format: Data format — typically ``"text"`` for pre-training.
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate.
        batch_size: Per-device batch size.
        resume_from: Path to a checkpoint directory to resume from.
        **kwargs: Additional ``TrainerConfig`` fields (``max_steps``,
            ``mixed_precision``, ``scheduler``, etc.).

    Returns:
        Final training state with loss, global step count, and other metrics.

    Raises:
        ValueError: If neither ``config`` nor both ``model`` and ``dataset``
            are provided.

    Example::

        state = xaytune.pretrain(
            model="gpt2",
            dataset="data/corpus.jsonl",
            num_epochs=1,
            max_steps=1000,
        )
    """
    injected_model = None
    if config is None:
        # The quick path needs both a model and a dataset to build a config.
        if model is None or dataset is None:
            raise ValueError("Either 'config' or both 'model' and 'dataset' are required.")

        if isinstance(model, str):
            model_name = model
        else:
            # Pre-built objects are handed to setup_training() directly; the
            # config only records a placeholder name for them.
            model_name = "custom"
            injected_model = model

        # Move every keyword that names a TrainerConfig field out of kwargs.
        trainer_param_names = set(TrainerConfig.model_fields)
        trainer_fields = {
            key: kwargs.pop(key) for key in list(kwargs) if key in trainer_param_names
        }

        # Anything left matched no known field.  It used to be dropped
        # silently, which hid typos such as ``max_step=1000``.
        if kwargs:
            import warnings

            warnings.warn(
                "pretrain() ignored unrecognized keyword argument(s): "
                f"{', '.join(sorted(kwargs))}",
                stacklevel=2,
            )

        config = TrainConfig(
            recipe="pretrain",
            method="full",  # this recipe always uses full-parameter training
            model=ModelConfig(name=model_name),
            data=DataConfig(path=dataset, format=format),
            trainer=TrainerConfig(
                num_epochs=num_epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                **trainer_fields,
            ),
        )

    components = _base.setup_training(
        config,
        resume_from=resume_from,
        model=injected_model,
        tokenizer=tokenizer,
    )

    return components.trainer.train(
        model=components.model,
        train_dataloader=components.train_dataloader,
        resume_state=components.resume_state,
        resume_checkpoint_dir=resume_from,
    )

align

align

align(*, config=None, model=None, tokenizer=None, dataset=None, method='dpo', format='preference', num_epochs=1, learning_rate=5e-06, batch_size=4, resume_from=None, **kwargs)

Align a language model using preference-based or RL methods.

Supports DPO, SimPO, ORPO, PPO, GRPO, and REINFORCE. A frozen reference model is created automatically for methods that need one. Method-specific hyperparameters (beta, kl_coeff, etc.) are extracted from **kwargs and forwarded to the loss function.

Parameters:

Name Type Description Default
config TrainConfig | None

Complete training configuration. When provided, all other arguments except resume_from are ignored.

None
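model Any | None

HuggingFace model name, local path, or a pre-built nn.Module / ModelResult. When passing a raw module, tokenizer must also be provided.

None
tokenizer Any | None

Tokenizer instance — required when model is a raw nn.Module, ignored when model is a string or None.

None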
dataset str | None

Path to a preference JSONL file (each line: {"prompt": "...", "chosen": "...", "rejected": "..."}).

None
method str

Alignment method — "dpo", "simpo", "orpo", "ppo", "grpo", or "reinforce".

'dpo'
format str

Data format — "preference" for paired data.

'preference'
num_epochs int

Number of training epochs.

1
learning_rate float

Peak learning rate.

5e-06
batch_size int

Per-device batch size.

4
resume_from str | None

Path to a checkpoint directory to resume from.

None
**kwargs Any

Method hyperparameters (beta, kl_coeff, lambda_weight, gamma, clip_eps) and any extra TrainerConfig fields.

{}

Returns:

Type Description
TrainState

Final training state with loss, global step count, and other metrics.

Raises:

Type Description
ValueError

If neither config nor both model and dataset are provided.

Example::

state = xaytune.align(
    model="meta-llama/Llama-3-8B",
    dataset="data/prefs.jsonl",
    method="dpo",
    beta=0.1,
    max_steps=200,
)
Source code in xaytune/recipes/align/align.py
def align(
    *,
    config: TrainConfig | None = None,
    model: Any | None = None,
    tokenizer: Any | None = None,
    dataset: str | None = None,
    method: str = "dpo",
    format: str = "preference",
    num_epochs: int = 1,
    learning_rate: float = 5e-6,
    batch_size: int = 4,
    resume_from: str | None = None,
    **kwargs: Any,
) -> TrainState:
    """Align a language model using preference-based or RL methods.

    Supports DPO, SimPO, ORPO, PPO, GRPO, and REINFORCE.  A frozen
    reference model is created automatically for methods that need one.
    Method-specific hyperparameters (``beta``, ``kl_coeff``, etc.) are
    extracted from ``**kwargs`` and forwarded to the loss function.  Extra
    keywords that are neither ``TrainerConfig`` fields nor method
    hyperparameters are ignored with a warning.

    Args:
        config: Complete training configuration. When provided, all other
            arguments except ``resume_from`` are ignored.
        model: HuggingFace model name or local path.  A non-string object is
            passed through to ``setup_training()`` as a pre-built model, in
            which case the config records the name ``"custom"``.
        tokenizer: Tokenizer forwarded to ``setup_training()`` together with
            a pre-built ``model``.
        dataset: Path to a preference JSONL file (each line:
            ``{"prompt": "...", "chosen": "...", "rejected": "..."}``).
        method: Alignment method — ``"dpo"``, ``"simpo"``, ``"orpo"``,
            ``"ppo"``, ``"grpo"``, or ``"reinforce"``.
        format: Data format — ``"preference"`` for paired data.
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate.
        batch_size: Per-device batch size.
        resume_from: Path to a checkpoint directory to resume from.
        **kwargs: Method hyperparameters (``beta``, ``kl_coeff``,
            ``lambda_weight``, ``gamma``, ``clip_eps``) and any extra
            ``TrainerConfig`` fields.

    Returns:
        Final training state with loss, global step count, and other metrics.

    Raises:
        ValueError: If neither ``config`` nor both ``model`` and ``dataset``
            are provided.

    Example::

        state = xaytune.align(
            model="meta-llama/Llama-3-8B",
            dataset="data/prefs.jsonl",
            method="dpo",
            beta=0.1,
            max_steps=200,
        )
    """
    injected_model = None
    if config is None:
        # The quick path needs both a model and a dataset to build a config.
        if model is None or dataset is None:
            raise ValueError("Either 'config' or both 'model' and 'dataset' are required.")

        if isinstance(model, str):
            model_name = model
        else:
            # Pre-built objects are handed to setup_training() directly; the
            # config only records a placeholder name for them.
            model_name = "custom"
            injected_model = model

        trainer_param_names = set(TrainerConfig.model_fields)
        method_param_names = {
            "beta",
            "kl_coeff",
            "lambda_weight",
            "gamma",
            "clip_eps",
        }

        # Route each extra keyword to the trainer config or to the loss
        # hyperparameters.  TrainerConfig fields win if a name is in both.
        trainer_fields = {}
        method_params = {}
        for key in list(kwargs):
            if key in trainer_param_names:
                trainer_fields[key] = kwargs.pop(key)
            elif key in method_param_names:
                method_params[key] = kwargs.pop(key)

        # Anything left matched neither set.  It used to be dropped silently,
        # which hid typos such as ``betta=0.1``.
        if kwargs:
            import warnings

            warnings.warn(
                "align() ignored unrecognized keyword argument(s): "
                f"{', '.join(sorted(kwargs))}",
                stacklevel=2,
            )

        config = TrainConfig(
            recipe="align",
            method=method,
            model=ModelConfig(name=model_name),
            data=DataConfig(path=dataset, format=format),
            trainer=TrainerConfig(
                num_epochs=num_epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                **trainer_fields,
            ),
            method_params=method_params,
        )

    components = _base.setup_training(
        config,
        resume_from=resume_from,
        model=injected_model,
        tokenizer=tokenizer,
    )

    loss_fn = None
    if is_alignment_method(config.method):
        # Frozen snapshot of the model as returned by setup_training().
        # NOTE(review): a reference copy is made for every method accepted by
        # is_alignment_method(); confirm reference-free methods are excluded.
        # NOTE(review): with resume_from set, the model may already hold
        # checkpoint weights (and may be distributed-wrapped), so the
        # reference would mirror the resumed policy — confirm this is intended.
        ref_model = copy.deepcopy(components.model)
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False

        loss_fn = create_alignment_loss_fn(
            method=config.method,
            ref_model=ref_model,
            **config.method_params,
        )

    return components.trainer.train(
        model=components.model,
        train_dataloader=components.train_dataloader,
        loss_fn=loss_fn,
        resume_state=components.resume_state,
        resume_checkpoint_dir=resume_from,
    )

setup_training

setup_training(config, callback_manager=None, resume_from=None, model=None, tokenizer=None)

Build the full training pipeline from a configuration object.

Handles model loading, LoRA/QLoRA application, tokenization, data packing, distributed setup, and callback registration (eval, checkpoints, early stopping, progress bar, logging). Returns a TrainingComponents tuple ready for components.trainer.train().

Parameters:

Name Type Description Default
config TrainConfig

Complete training configuration.

required
callback_manager CallbackManager | None

Optional pre-configured callback manager. A new one is created if not provided.

None
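resume_from str | None

Path to a checkpoint directory to resume from.

None
model Any | None

Pre-built ModelResult or raw model to use instead of loading config.model.name. When a raw model is passed, tokenizer must also be provided.

None
tokenizer Any | None

Tokenizer paired with a raw model; only used when model is a raw model rather than a ModelResult.

None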

Returns:

Type Description
TrainingComponents

A TrainingComponents with the model, data loaders, trainer, and optional resume state.

Source code in xaytune/recipes/base.py
def setup_training(
    config: TrainConfig,
    callback_manager: CallbackManager | None = None,
    resume_from: str | None = None,
    model: Any | None = None,
    tokenizer: Any | None = None,
) -> TrainingComponents:
    """Build the full training pipeline from a configuration object.

    Handles model loading, LoRA/QLoRA application, tokenization, data
    packing, distributed setup, and callback registration (eval, checkpoints,
    early stopping, progress bar, logging).  Returns a
    :class:`TrainingComponents` tuple ready for ``components.trainer.train()``.

    Args:
        config: Complete training configuration.
        callback_manager: Optional pre-configured callback manager.
            A new one is created if not provided.
        resume_from: Path to a checkpoint directory to resume from.
        model: Optional pre-built model.  A ``ModelResult`` is used as-is;
            any other object is treated as a raw model and wrapped in a
            ``ModelResult`` named ``"custom"``.  When ``None``, the model is
            loaded from ``config.model.name``.
        tokenizer: Tokenizer paired with a raw ``model``.  Only consulted
            when ``model`` is given and is not a ``ModelResult``.

    Returns:
        A :class:`TrainingComponents` with the model, data loaders, trainer,
        and optional resume state.

    Raises:
        ValueError: If a raw ``model`` is passed without a ``tokenizer``.
    """
    # Set random seeds for reproducibility
    seed_all(config.trainer.seed)

    # Initialize distributed context
    ctx = init_distributed()
    strategy = get_strategy(config.trainer.strategy, ctx.world_size)

    # Resolve the model: either use the injected object or load by name.
    if model is not None:
        from xaytune.models.loader import ModelResult

        if isinstance(model, ModelResult):
            model_result = model
        else:
            # A raw model carries no tokenizer of its own, so one must be supplied.
            if tokenizer is None:
                raise ValueError(
                    "tokenizer is required when passing a raw model to setup_training()"
                )
            model_result = ModelResult(
                model=model,
                tokenizer=tokenizer,
                name="custom",
            )
    else:
        # QLoRA always forces 4-bit quantization, overriding config.model.quantization.
        quantization = None
        if config.method == "qlora":
            quantization = "4bit"
        elif config.model.quantization:
            quantization = config.model.quantization

        model_result = load_model(
            config.model.name,
            quantization=quantization,
            dtype=config.model.dtype,
            trust_remote_code=config.model.trust_remote_code,
        )

    # Adapters are applied to injected models as well as loaded ones.
    if config.method in ("lora", "qlora"):
        model_result = apply_lora(
            model_result,
            rank=config.lora.rank,
            alpha=config.lora.alpha,
            dropout=config.lora.dropout,
            target_modules=config.lora.target_modules,
        )

    # From here on `model` is the underlying module, not the ModelResult wrapper.
    model = model_result.model

    # Activation checkpointing is skipped silently when the model does not
    # expose gradient_checkpointing_enable().
    if config.trainer.activation_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )

    # Move model to correct device
    model.to(ctx.device)

    # Wrap model for distributed training
    if strategy != "none":
        model = wrap_model_distributed(
            model,
            strategy=strategy,
            ctx=ctx,
            fsdp_config=config.fsdp,
            deepspeed_config=config.deepspeed_config,
            mixed_precision=config.trainer.mixed_precision,
        )

    dataset = load_dataset(
        config.data.path,
        format=config.data.format,
        source=config.data.source,
        streaming=config.data.streaming,
        eval_split=config.data.eval_split,
        tokenizer=model_result.tokenizer,
    )

    # With an eval split requested, the loader result is unpacked as a
    # (train, eval) pair; otherwise it is the training data alone.
    if config.data.eval_split > 0:
        train_data, eval_data = dataset  # type: ignore[misc]
    else:
        train_data = dataset  # type: ignore[assignment]
        eval_data = None

    max_seq = config.data.max_seq_length

    # Detect streaming (HuggingFace IterableDataset or torch IterableDataset)
    # Anything that is not a plain list is treated as streaming.
    is_streaming = not isinstance(train_data, list)

    if is_streaming:
        # Streaming data is tokenized lazily; eval_data is left untouched on this path.
        train_data = StreamingTokenizedDataset(  # type: ignore[assignment]
            train_data,
            model_result.tokenizer,
            max_seq,
        )
    else:
        samples: list[dict[str, Any]] = train_data  # type: ignore[assignment]
        # Preference data is recognised by "prompt" and "chosen" keys on the first sample.
        is_preference = samples and "prompt" in samples[0] and "chosen" in samples[0]

        if is_preference:
            train_data = tokenize_preference_dataset(
                samples,
                model_result.tokenizer,
                max_seq,
            )
            if eval_data is not None and isinstance(eval_data, list) and eval_data:
                eval_data = tokenize_preference_dataset(
                    eval_data,
                    model_result.tokenizer,
                    max_seq,
                )
        else:
            # Only samples carrying a "text" field are tokenized here; other
            # samples pass through unchanged (presumably already tokenized — TODO confirm).
            if samples and "text" in samples[0]:
                train_data = tokenize_dataset(
                    samples,
                    model_result.tokenizer,
                    max_seq,
                )
            if (
                eval_data is not None
                and isinstance(eval_data, list)
                and eval_data
                and "text" in eval_data[0]
            ):
                eval_data = tokenize_dataset(
                    eval_data,
                    model_result.tokenizer,
                    max_seq,
                )

        # Pack only non-preference, list-backed data whose samples hold
        # "input_ids" as plain lists, and only when packing is enabled.
        if (
            not is_preference
            and config.data.packing
            and config.data.max_seq_length > 0
            and isinstance(train_data, list)
            and train_data
            and isinstance(train_data[0], dict)
            and "input_ids" in train_data[0]
            and isinstance(train_data[0]["input_ids"], list)
        ):
            # Falls back to 0 when the tokenizer has no pad_token_id or it is None.
            pad_id = getattr(model_result.tokenizer, "pad_token_id", 0) or 0
            train_data = pack_sequences(  # type: ignore[arg-type]
                train_data,
                max_seq_length=config.data.max_seq_length,
                pad_token_id=pad_id,
            )
            # NOTE(review): eval_data is packed without the shape checks applied
            # to train_data above — confirm it is always tokenized at this point.
            if eval_data is not None:
                eval_data = pack_sequences(
                    eval_data,  # type: ignore[arg-type]
                    max_seq_length=config.data.max_seq_length,
                    pad_token_id=pad_id,
                )

    # Create collate function for tokenized data
    pad_id = getattr(model_result.tokenizer, "pad_token_id", 0) or 0

    # Re-derive the preference flag from the tokenized samples, which carry
    # "chosen_input_ids" when they went through preference tokenization.
    is_preference = (
        not is_streaming
        and isinstance(train_data, list)
        and train_data
        and "chosen_input_ids" in train_data[0]
    )

    # `pid` is a default argument so the pad id is bound when the function is defined.
    if is_preference:

        def collate_fn(batch: list, pid: int = pad_id) -> dict:
            return collate_preference(batch, pad_token_id=pid)
    else:

        def collate_fn(batch: list, pid: int = pad_id) -> dict:
            return collate_tokenized(batch, pad_token_id=pid)

    # Create DataLoaders — streaming datasets don't support shuffle/sampler
    sampler: Any = None
    shuffle: bool | None = True if not is_streaming else None

    if ctx.is_distributed and not is_streaming:
        from torch.utils.data.distributed import DistributedSampler

        sampler = DistributedSampler(
            train_data,  # type: ignore[arg-type]
            num_replicas=ctx.world_size,
            rank=ctx.rank,
            shuffle=True,
        )
        # The sampler handles shuffling, so `shuffle` is not passed to the DataLoader.
        shuffle = None

    # Only pass shuffle/sampler when set, so the DataLoader never receives both.
    dl_kwargs: dict[str, Any] = {
        "batch_size": config.trainer.batch_size,
        "collate_fn": collate_fn,
    }
    if shuffle is not None:
        dl_kwargs["shuffle"] = shuffle
    if sampler is not None:
        dl_kwargs["sampler"] = sampler

    train_dataloader: Any = DataLoader(
        train_data,  # type: ignore[arg-type]
        **dl_kwargs,
    )

    # NOTE(review): unlike the train sampler, this is not guarded by
    # `not is_streaming` — confirm eval_data is always map-style here.
    eval_sampler: Any = None
    if ctx.is_distributed and eval_data is not None:
        from torch.utils.data.distributed import DistributedSampler

        eval_sampler = DistributedSampler(
            eval_data,  # type: ignore[arg-type]
            num_replicas=ctx.world_size,
            rank=ctx.rank,
            shuffle=False,
        )

    eval_dataloader: Any = None
    if eval_data is not None:
        eval_dataloader = DataLoader(
            eval_data,  # type: ignore[arg-type]
            batch_size=config.trainer.batch_size,
            shuffle=False,
            sampler=eval_sampler,
            collate_fn=collate_fn,
        )

    # Validate a sample batch before training (skip for streaming)
    if not is_streaming:
        validate_dataset_sample(
            train_dataloader,
            max_seq_length=config.data.max_seq_length,
        )

    cb_manager = callback_manager or CallbackManager()

    # Register distributed cleanup callback
    # NOTE(review): this "train_end" handler is registered before the async-saver
    # wait and LoRA merge handlers below — confirm callback execution order is safe.
    if ctx.is_distributed:

        @cb_manager.on("train_end")
        def _cleanup_distributed(state: Any) -> None:
            cleanup_distributed(ctx)

    trainer = Trainer(
        config=config.trainer,
        callback_manager=cb_manager,
    )

    # Set up logging
    from xaytune.logging import setup_logging

    logging_manager = setup_logging(
        config.logging,
        cb_manager,
        output_dir=config.output.dir,
        rank=ctx.rank,
    )

    # Record the full configuration when training starts.
    @cb_manager.on("train_start")
    def _log_config(state: Any) -> None:
        logging_manager.log_config(config.model_dump())

    # Set up async checkpoint saver if requested
    async_saver = None
    if config.trainer.async_checkpoint:
        from xaytune.trainer.async_checkpoint import AsyncCheckpointSaver

        async_saver = AsyncCheckpointSaver()

        # Wait on the saver at the end of training.
        @cb_manager.on("train_end")
        def _wait_async_saver(state: Any) -> None:
            async_saver.wait()

    # Register checkpoint callbacks
    register_checkpoint_callbacks(
        callback_manager=cb_manager,
        trainer=trainer,
        model=model,
        output_dir=config.output.dir,
        checkpoint_every_n_steps=config.trainer.checkpoint_every_n_steps,
        save_last=config.trainer.save_last,
        is_main_process=ctx.is_main_process,
        async_saver=async_saver,
    )

    # Register eval callbacks if eval data is available
    if eval_dataloader is not None and config.eval.every_n_steps > 0:
        register_eval_callbacks(
            callback_manager=cb_manager,
            model=model,
            eval_dataloader=eval_dataloader,
            every_n_steps=config.eval.every_n_steps,
            metrics=config.eval.metrics,
            is_main_process=ctx.is_main_process,
        )

    # Register early stopping if configured
    if config.eval.early_stopping_patience > 0 and eval_dataloader is not None:
        register_early_stopping_callbacks(
            callback_manager=cb_manager,
            patience=config.eval.early_stopping_patience,
            metric=config.eval.early_stopping_metric,
            min_delta=config.eval.early_stopping_min_delta,
        )

    # Auto-merge LoRA adapters on training completion
    if config.output.merge_on_complete and config.method in ("lora", "qlora"):

        @cb_manager.on("train_end")
        def _merge_on_complete(state: Any) -> None:
            # Only the main process writes the merged model.
            if not ctx.is_main_process:
                return
            # NOTE(review): `model` may be the distributed wrapper here; if it lacks
            # merge_and_unload the merge is skipped without any message — confirm.
            if hasattr(model, "merge_and_unload"):
                merged = model.merge_and_unload()
                save_dir = f"{config.output.dir}/merged"
                from pathlib import Path

                Path(save_dir).mkdir(parents=True, exist_ok=True)
                merged.save_pretrained(save_dir)
                model_result.tokenizer.save_pretrained(save_dir)

    # Register progress bar
    # Total optimizer steps = batches per epoch // accumulation * epochs, capped
    # by max_steps.  A dataloader without a length raises TypeError from len(),
    # in which case max_steps (or 0) is used instead.
    try:
        total_steps = len(train_dataloader)
        if config.trainer.gradient_accumulation > 1:
            total_steps = total_steps // config.trainer.gradient_accumulation
        total_steps *= config.trainer.num_epochs
        if config.trainer.max_steps > 0:
            total_steps = min(total_steps, config.trainer.max_steps)
    except TypeError:
        total_steps = config.trainer.max_steps if config.trainer.max_steps > 0 else 0

    register_progress_callbacks(
        callback_manager=cb_manager,
        total_steps=total_steps,
        is_main_process=ctx.is_main_process,
    )

    # Resume from checkpoint if requested
    resume_state = None
    if resume_from is not None:

        # Stand-in optimizer whose load_state_dict() discards its argument, so
        # no real optimizer state is restored by this call.
        # NOTE(review): presumably the Trainer restores optimizer state itself later — confirm.
        class _NoopOptimizer:
            def load_state_dict(self, state: Any) -> None:
                pass

        resume_state = load_checkpoint(
            checkpoint_dir=resume_from,
            model=model,
            optimizer=_NoopOptimizer(),
        )

    return TrainingComponents(
        model=model,
        tokenizer=model_result.tokenizer,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        trainer=trainer,
        distributed_ctx=ctx,
        resume_state=resume_state,
    )

TrainingComponents

TrainingComponents

Bases: NamedTuple

Container for all objects created by setup_training().

Attributes:

Name Type Description
model Any

The model, potentially wrapped for distributed training.

tokenizer Any

The associated tokenizer.

train_dataloader DataLoader

DataLoader for training data.

eval_dataloader DataLoader | None

DataLoader for evaluation data, or None.

trainer Trainer

Configured xaytune.trainer.Trainer instance.

distributed_ctx Any

Distributed context (rank, world size, device).

resume_state Any

Restored xaytune.trainer.callbacks.TrainState when resuming from a checkpoint, otherwise None.