Config Schema Reference

xaytune uses Pydantic models for configuration. The root model is TrainConfig, which nests several sub-models for model, data, training, evaluation, logging, and output settings.

All config models live in xaytune.config.schema.

TrainConfig

The top-level configuration object. Every training run is driven by a TrainConfig.

from xaytune.config.schema import DataConfig, ModelConfig, TrainConfig

config = TrainConfig(
    recipe="finetune",
    method="lora",
    model=ModelConfig(name="meta-llama/Llama-3.1-8B"),
    data=DataConfig(path="data/train.jsonl", format="alpaca"),
)

Field    Type                               Default          Description
recipe   "finetune" | "pretrain" | "align"  required         Which training recipe to use
method   str                                "full"           Training method (see below)
base     str | None                         None             Base config to inherit from
model    ModelConfig                        required         Model configuration
data     DataConfig                         required         Data configuration
lora     LoraConfig                         LoraConfig()     LoRA adapter settings
trainer  TrainerConfig                      TrainerConfig()  Training hyperparameters
eval     EvalConfig                         EvalConfig()     Evaluation settings
logging  LoggingConfig                      LoggingConfig()  Logging backend configuration
output   OutputConfig                       OutputConfig()   Output directory settings

Valid methods by recipe:

  • finetune: full, lora, qlora
  • pretrain: full
  • align: dpo, grpo, ppo, orpo, simpo
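
The recipe-to-method list above can be read as a lookup table. The following is a minimal sketch of that check, not xaytune's own code: `VALID_METHODS` and `is_valid_combination` are hypothetical names introduced here for illustration.

```python
# Recipe-to-method table, copied from the list above.
VALID_METHODS: dict[str, set[str]] = {
    "finetune": {"full", "lora", "qlora"},
    "pretrain": {"full"},
    "align": {"dpo", "grpo", "ppo", "orpo", "simpo"},
}


def is_valid_combination(recipe: str, method: str) -> bool:
    """Return True if `method` is listed for `recipe` (hypothetical helper)."""
    return method in VALID_METHODS.get(recipe, set())


print(is_valid_combination("finetune", "lora"))  # True
print(is_valid_combination("pretrain", "lora"))  # False
print(is_valid_combination("align", "dpo"))      # True
```

An unknown recipe falls back to an empty set, so the helper returns False instead of raising.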

ModelConfig

from xaytune.config.schema import ModelConfig

model = ModelConfig(
    name="meta-llama/Llama-3.1-8B",
    quantization="4bit",
    dtype="auto",
    trust_remote_code=False,
)

Field              Type                    Default   Description
name               str                     required  Model name (HF Hub ID) or local path
quantization       "4bit" | "8bit" | None  None      Quantization mode for bitsandbytes
dtype              str                     "auto"    Model dtype ("auto", "float16", "bfloat16", etc.)
trust_remote_code  bool                    False     Whether to trust remote code from HF Hub

DataConfig

from xaytune.config.schema import DataConfig

data = DataConfig(
    path="data/train.jsonl",
    format="alpaca",
    source="local",
    eval_split=0.05,
    packing=True,
    max_seq_length=2048,
)

Field           Type                     Default   Description
path            str                      required  Path to dataset file or HF Hub dataset name
format          str                      required  Data format key (must be in format_registry)
source          "local" | "huggingface"  "local"   Where to load data from
eval_split      float                    0.0       Fraction of data to hold out for evaluation
eval_path       str | None               None      Explicit path to evaluation dataset
packing         bool                     True      Pack multiple sequences into one training example
max_seq_length  int                      2048      Maximum sequence length
streaming       bool                     False     Stream data instead of loading into memory
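
To make the packing and max_seq_length fields more concrete, here is a rough sketch of one common way to pack sequences: greedily grouping consecutive examples until the next one would exceed the length limit. This page does not describe xaytune's actual packing algorithm, so `pack_sequences` is a hypothetical helper and the greedy strategy is an assumption.

```python
def pack_sequences(lengths: list[int], max_seq_length: int = 2048) -> list[list[int]]:
    """Greedily group sequence indices so each group's total length fits.

    Hypothetical illustration only; not xaytune's implementation.
    """
    packs: list[list[int]] = []
    current: list[int] = []
    used = 0
    for idx, length in enumerate(lengths):
        # Assumption: over-long sequences are truncated to the limit.
        length = min(length, max_seq_length)
        if current and used + length > max_seq_length:
            packs.append(current)
            current, used = [], 0
        current.append(idx)
        used += length
    if current:
        packs.append(current)
    return packs


# Token counts of five tokenized examples (made-up numbers).
print(pack_sequences([900, 700, 600, 2048, 100], max_seq_length=2048))
# [[0, 1], [2], [3], [4]]
```

With packing disabled, each of the five examples would occupy its own padded row. Here the first two share one row.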

LoraConfig

from xaytune.config.schema import LoraConfig

lora = LoraConfig(
    rank=16,
    alpha=32,
    dropout=0.05,
    target_modules="auto",
)

Field           Type             Default  Description
rank            int              16       LoRA rank (r). Higher = more parameters, more capacity
alpha           int              32       LoRA alpha scaling factor. Common rule: alpha = 2 * rank
dropout         float            0.05     Dropout probability for LoRA layers
target_modules  str | list[str]  "auto"   Which modules to apply LoRA to. "auto" selects standard attention layers
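
As a quick illustration of how rank and alpha interact, the sketch below computes the effective scale (alpha / rank) and the number of parameters one adapter adds to a single weight matrix. The helper names and the 4096 x 4096 layer size are assumptions chosen for the example, not values taken from xaytune.

```python
def lora_scale(rank: int, alpha: int) -> float:
    """Effective scaling applied to the low-rank update: alpha / rank."""
    return alpha / rank


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


print(lora_scale(rank=16, alpha=32))      # 2.0
print(lora_param_count(4096, 4096, 16))   # 131072
print(lora_param_count(4096, 4096, 64))   # 524288
```

Keeping alpha = 2 * rank holds the effective scale at 2.0 as rank changes, while the adapter's parameter count grows linearly with rank.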

TrainerConfig

from xaytune.config.schema import TrainerConfig

trainer = TrainerConfig(
    strategy="auto",
    mixed_precision="bf16",
    batch_size=4,
    gradient_accumulation=4,
    learning_rate=2e-4,
    num_epochs=3,
)

Field                     Type                                   Default  Description
strategy                  "auto" | "ddp" | "fsdp" | "deepspeed"  "auto"   Distributed training strategy
mixed_precision           "fp16" | "bf16" | "fp32"               "bf16"   Mixed precision mode
batch_size                int                                    4        Per-device batch size
gradient_accumulation     int                                    1        Gradient accumulation steps
learning_rate             float                                  2e-4     Optimizer learning rate
num_epochs                int                                    3        Number of training epochs
max_steps                 int                                    -1       Maximum training steps (-1 = unlimited)
warmup_steps              int                                    0        Number of warmup steps
warmup_ratio              float                                  0.0      Warmup as a fraction of total steps
weight_decay              float                                  0.01     Weight decay for optimizer
max_grad_norm             float                                  1.0      Maximum gradient norm for clipping
seed                      int                                    42       Random seed
checkpoint_every_n_steps  int                                    500      Save a checkpoint every N steps
save_last                 bool                                   True     Always save the final checkpoint
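
Because batch_size is per device, the number of examples consumed per optimizer step also depends on gradient_accumulation and the device count. The sketch below works through that arithmetic. `num_devices` is not a config field, both helpers are hypothetical, and counting a partial final batch as a full step is an assumption.

```python
import math


def effective_batch_size(batch_size: int, gradient_accumulation: int, num_devices: int = 1) -> int:
    """Examples consumed per optimizer step across all devices."""
    return batch_size * gradient_accumulation * num_devices


def total_optimizer_steps(
    num_examples: int,
    num_epochs: int,
    batch_size: int,
    gradient_accumulation: int,
    num_devices: int = 1,
    max_steps: int = -1,
) -> int:
    """Estimate optimizer steps; max_steps caps the total unless it is -1."""
    per_step = effective_batch_size(batch_size, gradient_accumulation, num_devices)
    # Assumption: a partial final batch still counts as one step.
    steps = math.ceil(num_examples / per_step) * num_epochs
    return steps if max_steps < 0 else min(steps, max_steps)


print(effective_batch_size(4, 4, num_devices=2))                 # 32
print(total_optimizer_steps(10_000, 3, 4, 4, num_devices=2))     # 939
print(total_optimizer_steps(10_000, 3, 4, 4, 2, max_steps=500))  # 500
```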

EvalConfig

Field Type Default Description
every_n_steps int 500 Run evaluation every N steps
metrics list[str] ["loss", "perplexity"] Metrics to compute during evaluation
benchmarks list[str] [] lm-eval benchmarks to run

LoggingConfig

Field Type Default Description
backends list[str] ["console"] Logging backends to enable
project str | None None Project name for wandb/mlflow
run_name str | None None Run name for wandb/mlflow
log_every_n_steps int 10 Log metrics every N steps

Available backends: console, tensorboard, wandb, mlflow


OutputConfig

Field Type Default Description
dir str "output" Output directory for checkpoints and artifacts
merge_on_complete bool False Automatically merge LoRA adapters after training

Loading Configs from YAML

from xaytune.config import load_config, validate_config

# Load from YAML file
config = load_config("configs/examples/lora_finetune.yaml")

# Load with overrides
config = load_config(
    "configs/examples/lora_finetune.yaml",
    overrides=["model.name=mistralai/Mistral-7B-v0.3", "trainer.num_epochs=5"],
)

# Validate
validate_config(config)

Full API Reference

Schema Classes

TrainConfig

Bases: BaseModel

Top-level training configuration combining all sub-configs.

This is the single object that drives setup_training() and the recipe one-liners (finetune, pretrain, align).

Attributes:

Name Type Description
recipe str

Training recipe — "finetune", "pretrain", or "align".

method str

Training method (e.g. "full", "lora", "dpo").

base str | None

Optional path to a base YAML config for inheritance.

model ModelConfig

Model loading settings.

data DataConfig

Dataset settings.

lora LoraConfig

LoRA adapter settings (used when method is "lora"/"qlora").

trainer TrainerConfig

Training loop settings.

eval EvalConfig

Evaluation and early stopping settings.

logging LoggingConfig

Logging backend settings.

output OutputConfig

Output directory and artifact settings.

method_params dict[str, Any]

Extra hyperparameters passed to the alignment loss function (e.g. {"beta": 0.1} for DPO).

fsdp FSDPConfig

FSDP settings.

deepspeed_config DeepSpeedConfig

DeepSpeed settings.

ModelConfig

Bases: BaseModel

Model loading configuration.

Attributes:

Name Type Description
name str

HuggingFace model name or local path.

quantization Literal['4bit', '8bit'] | None

Optional quantization level ("4bit" or "8bit").

dtype str

Model dtype — "auto", "float16", "bfloat16", etc.

trust_remote_code bool

Allow execution of custom model code from the Hub.

DataConfig

Bases: BaseModel

Dataset configuration.

Attributes:

Name Type Description
path str

Path to a local JSONL file or HuggingFace dataset name.

format str

Data format — "alpaca", "sharegpt", "chat", "text", or "preference".

source Literal['local', 'huggingface']

"local" for files on disk, "huggingface" for Hub datasets.

eval_split float

Fraction of training data to hold out for evaluation.

eval_path str | None

Optional separate evaluation dataset path.

packing bool

Pack short sequences together to reduce padding waste.

max_seq_length int

Maximum sequence length after tokenization.

streaming bool

Stream data instead of loading into memory.

LoraConfig

Bases: BaseModel

LoRA adapter configuration.

Attributes:

Name Type Description
rank int

Rank of the low-rank matrices.

alpha int

LoRA scaling factor (effective scale = alpha / rank).

dropout float

Dropout probability applied to LoRA layers.

target_modules str | list[str]

Modules to apply LoRA to — "auto" for framework defaults, or a list of module name patterns.

TrainerConfig

Bases: BaseModel

Training loop configuration.

Attributes:

Name Type Description
strategy Literal['auto', 'ddp', 'fsdp', 'deepspeed']

Distributed strategy — "auto", "ddp", "fsdp", or "deepspeed".

mixed_precision Literal['fp16', 'bf16', 'fp32']

AMP dtype — "fp16", "bf16", or "fp32".

batch_size int

Per-device batch size.

gradient_accumulation int

Accumulate gradients over N micro-batches.

learning_rate float

Peak learning rate.

num_epochs int

Number of training epochs.

max_steps int

Stop after this many optimizer steps (-1 = unlimited).

warmup_steps int

Linear warmup steps (mutually exclusive with warmup_ratio).

warmup_ratio float

Warmup as a fraction of total steps.

scheduler Literal['cosine', 'linear', 'constant', 'constant_with_warmup']

LR schedule — "cosine", "linear", "constant", or "constant_with_warmup".

weight_decay float

AdamW weight decay coefficient.

max_grad_norm float

Gradient clipping norm (0 = disabled).

seed int

Random seed for reproducibility.

checkpoint_every_n_steps int

Save a checkpoint every N steps.

save_last bool

Save a final checkpoint at training end.

activation_checkpointing bool

Trade compute for memory by recomputing activations during backward.

async_checkpoint bool

Write checkpoints in a background thread.
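
The warmup_steps, warmup_ratio, and scheduler attributes above together determine the learning rate at each step. The following sketch shows one conventional reading: linear warmup to the peak learning_rate followed by cosine decay to zero. The exact formulas xaytune uses are not given on this page, so both functions are hypothetical and the rounding of warmup_ratio is an assumption.

```python
import math


def resolve_warmup_steps(total_steps: int, warmup_steps: int = 0, warmup_ratio: float = 0.0) -> int:
    """Turn the two mutually exclusive warmup settings into a step count."""
    if warmup_steps > 0 and warmup_ratio > 0.0:
        raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
    if warmup_ratio > 0.0:
        return int(warmup_ratio * total_steps)  # assumption: round down
    return warmup_steps


def cosine_lr(step: int, total_steps: int, peak_lr: float, warmup: int) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


warmup = resolve_warmup_steps(total_steps=1000, warmup_ratio=0.1)
for step in (0, 99, 100, 550, 1000):
    print(step, cosine_lr(step, total_steps=1000, peak_lr=2e-4, warmup=warmup))
```

In this sketch the rate reaches the peak at the last warmup step, is halved at the midpoint of the decay phase, and reaches zero at the final step.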

EvalConfig

Bases: BaseModel

Evaluation and early stopping configuration.

Attributes:

Name Type Description
every_n_steps int

Run evaluation every N training steps.

metrics list[str]

Metrics to compute — "loss", "perplexity", "token_accuracy".

benchmarks list[str]

Optional benchmark names for lm-eval-harness.

early_stopping_patience int

Stop if no improvement for this many evaluations (0 = disabled).

early_stopping_metric str

Metric to monitor for early stopping.

early_stopping_min_delta float

Minimum improvement to count as progress.
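
The three early-stopping attributes above describe a patience counter. Here is a minimal sketch of that logic, assuming the monitored metric is one where lower is better (such as loss). `EarlyStopper` is a hypothetical class written for illustration and is not part of xaytune.

```python
class EarlyStopper:
    """Patience-based early stopping for a lower-is-better metric (hypothetical)."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best: float | None = None
        self.bad_evals = 0

    def should_stop(self, value: float) -> bool:
        if self.patience <= 0:
            return False  # 0 = disabled, matching the attribute description
        # An evaluation counts as progress only if it beats the best by min_delta.
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.bad_evals = 0
            return False
        self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopper(patience=2, min_delta=0.01)
eval_losses = [1.00, 0.90, 0.895, 0.899, 0.80]  # made-up values
stopped_at = None
for i, loss in enumerate(eval_losses):
    if stopper.should_stop(loss):
        stopped_at = i
        break
print(stopped_at)  # 3
```

The third and fourth evaluations improve on 0.90 by less than min_delta, so they count as two consecutive evaluations without progress and training stops before the final value is seen.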

LoggingConfig

Bases: BaseModel

Logging backend configuration.

Attributes:

Name Type Description
backends list[str]

Active backends: "console", "tensorboard", "wandb", or "mlflow".

project str | None

Project name for experiment trackers such as W&B or MLflow.

run_name str | None

Optional run name for experiment tracking.

log_every_n_steps int

Log metrics every N steps.

OutputConfig

Bases: BaseModel

Output and artifact configuration.

Attributes:

Name Type Description
dir str

Directory for checkpoints, logs, and exported models.

merge_on_complete bool

Auto-merge LoRA adapters at training end.

FSDPConfig

Bases: BaseModel

Fully Sharded Data Parallel (FSDP) configuration.

Attributes:

Name Type Description
sharding_strategy Literal['full_shard', 'shard_grad_op', 'no_shard']

How to shard parameters across ranks. "full_shard" shards params, grads, and optimizer states. "shard_grad_op" only shards grads and optimizer states. "no_shard" disables sharding (equivalent to DDP).

cpu_offload bool

Offload parameters and gradients to CPU RAM. Reduces GPU memory at the cost of slower training.

backward_prefetch Literal['backward_pre', 'backward_post'] | None

Prefetch next layer's params during backward. "backward_pre" is faster, "backward_post" uses less memory.

mixed_precision bool

Use FSDP-native mixed precision (dtype from TrainerConfig.mixed_precision).

auto_wrap_min_params int

Minimum parameter count for automatic FSDP wrapping. Layers with fewer parameters than this are grouped together. Set to 0 to disable auto-wrapping.

forward_prefetch bool

Prefetch next layer's params during forward pass.

sync_module_states bool

Broadcast module states from rank 0 on init. Useful when only rank 0 loads the checkpoint.

limit_all_gathers bool

Rate-limit all-gathers to reduce memory spikes.

activation_checkpointing bool

Apply activation checkpointing to auto-wrapped layers (trades compute for memory).

DeepSpeedConfig

Bases: BaseModel

DeepSpeed integration configuration.

Attributes:

Name Type Description
config_file str | None

Path to a DeepSpeed JSON config file. When provided, all other fields are ignored and the JSON file is used directly.

zero_stage Literal[0, 1, 2, 3]

ZeRO optimization stage. 0 = disabled, 1 = optimizer state partitioning, 2 = gradient + optimizer partitioning, 3 = full parameter partitioning.

offload_optimizer bool

Offload optimizer states to CPU (ZeRO stage 2/3).

offload_param bool

Offload parameters to CPU (ZeRO stage 3 only).

overlap_comm bool

Overlap gradient communication with backward pass.

contiguous_gradients bool

Use contiguous memory for gradients.

reduce_bucket_size int

Size of gradient reduction buckets in bytes.

stage3_prefetch_bucket_size int

Prefetch buffer size for ZeRO-3.

stage3_param_persistence_threshold int

Params smaller than this stay on GPU even in ZeRO-3 (reduces communication overhead).
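
For readers more familiar with raw DeepSpeed JSON, the sketch below shows how fields like these could map onto DeepSpeed's `zero_optimization` section. The mapping is an assumption, not xaytune's actual translation: `build_deepspeed_dict` is a hypothetical helper and the default values are placeholders. Only the JSON key names come from DeepSpeed's config format.

```python
import json
from typing import Any


def build_deepspeed_dict(
    zero_stage: int = 2,
    offload_optimizer: bool = False,
    offload_param: bool = False,
    overlap_comm: bool = True,
    contiguous_gradients: bool = True,
    reduce_bucket_size: int = 500_000_000,  # placeholder default
) -> dict[str, Any]:
    """Hypothetical translation of DeepSpeedConfig-style fields to DeepSpeed JSON."""
    zero: dict[str, Any] = {
        "stage": zero_stage,
        "overlap_comm": overlap_comm,
        "contiguous_gradients": contiguous_gradients,
        "reduce_bucket_size": reduce_bucket_size,
    }
    # Follow the field notes above: optimizer offload applies to stage 2/3,
    # parameter offload to stage 3 only.
    if offload_optimizer and zero_stage >= 2:
        zero["offload_optimizer"] = {"device": "cpu"}
    if offload_param and zero_stage == 3:
        zero["offload_param"] = {"device": "cpu"}
    return {"zero_optimization": zero}


config = build_deepspeed_dict(zero_stage=3, offload_optimizer=True, offload_param=True)
print(json.dumps(config, indent=2))
```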

Parser

load_config(path, overrides=None)

Load a YAML config file, resolve inheritance, and apply CLI overrides.

Parameters:

Name       Type              Description                                   Default
path       str               Path to a YAML configuration file.            required
overrides  list[str] | None  Optional list of "key.subkey=value" strings.  None

Returns:

Type Description
TrainConfig

A validated TrainConfig (xaytune.config.schema.TrainConfig).

Raises:

Type Description
FileNotFoundError

If the config file (or a base it inherits from) does not exist.

Source code in xaytune/config/parser.py
def load_config(
    path: str,
    overrides: list[str] | None = None,
) -> TrainConfig:
    """Load a YAML config file, resolve inheritance, and apply CLI overrides.

    Args:
        path: Path to a YAML configuration file.
        overrides: Optional list of ``"key.subkey=value"`` strings.

    Returns:
        A validated :class:`~xaytune.config.schema.TrainConfig`.

    Raises:
        FileNotFoundError: If the config file (or a base it inherits from)
            does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    with open(config_path) as f:
        data = yaml.safe_load(f)

    data = _resolve_inheritance(data, config_path.parent)

    if overrides:
        data = apply_overrides(data, overrides)

    return TrainConfig(**data)

merge_dicts(base, override)

Deep-merge override into base, returning a new dict.

Source code in xaytune/config/parser.py
def merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Deep-merge *override* into *base*, returning a new dict."""
    result = copy.deepcopy(base)
    for key, value in override.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_dicts(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result
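
A short usage sketch may help show what "deep merge" means here. The function body is copied from the listing above so the example runs on its own. The two config dicts are made up, and layering a child config over a base is shown only as a plausible use, not as a description of how xaytune resolves inheritance internally.

```python
import copy
from typing import Any


def merge_dicts(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Deep-merge *override* into *base*, returning a new dict (copied from above)."""
    result = copy.deepcopy(base)
    for key, value in override.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_dicts(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


base = {"method": "lora", "trainer": {"batch_size": 4, "learning_rate": 2e-4}, "lora": {"rank": 16}}
child = {"trainer": {"learning_rate": 1e-4}, "lora": {"rank": 64, "alpha": 128}}

merged = merge_dicts(base, child)
print(merged["trainer"])  # {'batch_size': 4, 'learning_rate': 0.0001}
print(merged["lora"])     # {'rank': 64, 'alpha': 128}
print(base["lora"])       # {'rank': 16}
```

Nested dicts are merged key by key rather than replaced wholesale, and the inputs are left unmodified because the function works on deep copies.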

apply_overrides(data, overrides)

Apply dot-notation CLI overrides (e.g. "trainer.learning_rate=1e-4") to a config dict.

Source code in xaytune/config/parser.py
def apply_overrides(data: dict[str, Any], overrides: list[str]) -> dict[str, Any]:
    """Apply dot-notation CLI overrides (e.g. ``"trainer.learning_rate=1e-4"``) to a config dict."""
    result = copy.deepcopy(data)
    for override in overrides:
        key, _, value = override.partition("=")
        parts = key.split(".")
        target = result
        for part in parts[:-1]:
            if part not in target:
                target[part] = {}
            target = target[part]
        target[parts[-1]] = _parse_value(value)
    return result

Validation

validate_config(config)

Validate cross-field constraints on a training configuration.

Checks recipe/method compatibility, mutual exclusivity of warmup settings, quantization requirements, and method_params validity.

Raises:

Type Description
ConfigValidationError

With a list of all detected issues.

Source code in xaytune/config/validation.py
def validate_config(config: TrainConfig) -> None:
    """Validate cross-field constraints on a training configuration.

    Checks recipe/method compatibility, mutual exclusivity of warmup
    settings, quantization requirements, and method_params validity.

    Raises:
        ConfigValidationError: With a list of all detected issues.
    """
    errors: list[str] = []

    if config.method == "qlora" and config.model.quantization != "4bit":
        errors.append(
            "QLoRA requires 4bit quantization, but model.quantization="
            f"'{config.model.quantization}'. Suggestion: set model.quantization='4bit'."
        )

    if not 0.0 <= config.data.eval_split <= 1.0:
        errors.append(
            f"data.eval_split must be between 0.0 and 1.0, got {config.data.eval_split}. "
            "Suggestion: set eval_split to a value like 0.05 for a 5% eval split."
        )

    if config.trainer.batch_size < 1:
        errors.append(
            f"trainer.batch_size must be >= 1, got {config.trainer.batch_size}. "
            "Suggestion: set batch_size to at least 1."
        )

    if config.trainer.learning_rate <= 0:
        errors.append(
            f"trainer.learning_rate must be positive, got {config.trainer.learning_rate}. "
            "Suggestion: typical values are 1e-5 to 5e-4."
        )

    if config.trainer.warmup_steps > 0 and config.trainer.warmup_ratio > 0.0:
        errors.append(
            "trainer.warmup_steps and trainer.warmup_ratio are mutually exclusive — "
            "set one to 0. Suggestion: use warmup_steps for an exact count, "
            "or warmup_ratio for a fraction of total steps."
        )

    if config.recipe == "align" and config.method not in _ALIGN_METHODS:
        errors.append(
            f"Recipe 'align' requires an alignment method "
            f"({', '.join(sorted(_ALIGN_METHODS))}), got '{config.method}'. "
            "Suggestion: set method='dpo' or method='grpo'."
        )

    if config.recipe == "finetune" and config.method not in _FINETUNE_METHODS:
        errors.append(
            f"Recipe 'finetune' requires a fine-tuning method "
            f"({', '.join(sorted(_FINETUNE_METHODS))}), got '{config.method}'. "
            "Suggestion: set method='lora' or method='full'."
        )

    if config.method_params:
        known = _KNOWN_METHOD_PARAMS.get(config.method, set())
        if not known and config.method not in _ALIGN_METHODS:
            errors.append(
                f"method_params is only supported for alignment methods "
                f"({', '.join(sorted(_ALIGN_METHODS))}), "
                f"but recipe/method is '{config.recipe}/{config.method}'."
            )
        else:
            unknown = set(config.method_params) - known
            if unknown:
                errors.append(
                    f"Unknown method_params for '{config.method}': "
                    f"{', '.join(sorted(unknown))}. "
                    f"Valid params: {', '.join(sorted(known)) if known else 'none'}."
                )

    if errors:
        raise ConfigValidationError(
            f"Config validation failed with {len(errors)} error(s):\n"
            + "\n".join(f"  - {e}" for e in errors)
        )

preflight_check(config)

Run environment-aware checks before training starts.

Verifies GPU availability for quantization and mixed precision, checks that data paths exist, and validates output directory write permissions.

Returns:

Type Description
list[str]

List of warning/issue strings (empty if everything looks good).

Source code in xaytune/config/validation.py
def preflight_check(config: TrainConfig) -> list[str]:
    """Run environment-aware checks before training starts.

    Verifies GPU availability for quantization and mixed precision,
    checks that data paths exist, and validates output directory
    write permissions.

    Returns:
        List of warning/issue strings (empty if everything looks good).
    """
    issues: list[str] = []

    try:
        import torch

        has_cuda = torch.cuda.is_available()
        has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    except ImportError:
        has_cuda = False
        has_mps = False

    if config.model.quantization and not has_cuda:
        issues.append(
            f"Quantization ({config.model.quantization}) requires CUDA, "
            "but no CUDA GPU was detected."
        )

    if config.trainer.mixed_precision != "fp32" and not has_cuda and not has_mps:
        warnings.warn(
            f"mixed_precision='{config.trainer.mixed_precision}' selected "
            "but no GPU detected. Training will fall back to CPU (fp32).",
            stacklevel=2,
        )

    if config.data.source == "local":
        data_path = Path(config.data.path)
        if not data_path.exists():
            issues.append(f"Data path not found: {config.data.path}")

    output_parent = Path(config.output.dir).parent
    if output_parent.exists() and not os.access(str(output_parent), os.W_OK):
        issues.append(f"Output directory parent is not writable: {output_parent}")

    return issues

ConfigValidationError

Bases: Exception

Raised when a TrainConfig (xaytune.config.schema.TrainConfig) has invalid field combinations.