Pre-training
The pretrain recipe trains a language model from scratch (or continues pre-training an existing model) on a large text corpus using a standard causal language modeling objective.
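As a rough illustration of the causal language modeling objective, the loss is the mean negative log-likelihood of each next token given the tokens before it, and perplexity is the exponential of that loss. The sketch below uses plain Python and made-up probabilities. The helper names `causal_lm_loss` and `perplexity` are hypothetical and are not part of xaytune.

```python
import math


def causal_lm_loss(token_probs):
    """Mean negative log-likelihood over the predicted tokens.

    token_probs[t] is the probability the model assigned to the token that
    actually appeared at position t + 1, given tokens 0..t.
    """
    if not token_probs:
        raise ValueError("need at least one predicted token")
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def perplexity(loss):
    """Perplexity is the exponential of the mean cross-entropy loss (in nats)."""
    return math.exp(loss)


# Made-up probabilities for three predicted tokens (illustrative only).
probs = [0.5, 0.25, 0.125]
loss = causal_lm_loss(probs)
ppl = perplexity(loss)
print(f"loss={loss:.4f} perplexity={ppl:.2f}")
```

Lower loss means the model assigns higher probability to the observed next tokens, so loss and perplexity move in the same direction.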
Python API
import xaytune

state = xaytune.pretrain(
    model="meta-llama/Llama-3.1-8B",
    dataset="data/corpus.jsonl",
    format="text",
    num_epochs=1,
    learning_rate=3e-4,
    batch_size=4,
)
Function Signature
def pretrain(
    *,
    config: TrainConfig | None = None,
    model: str | None = None,
    dataset: str | None = None,
    format: str = "text",
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    batch_size: int = 4,
    **kwargs,
) -> TrainState:
- config -- A full TrainConfig object. If provided, all other arguments are ignored.
- model -- Model name or path.
- dataset -- Path to training data.
- format -- Data format (default: "text").
- num_epochs -- Number of training epochs (default: 1).
- learning_rate -- Learning rate (default: 3e-4).
- batch_size -- Per-device batch size (default: 4).
- **kwargs -- Additional TrainerConfig fields.
Note
Pre-training always uses method="full". LoRA and QLoRA do not apply: when training from scratch there are no pre-trained weights to adapt.
YAML Config
recipe: pretrain
method: full
model:
  name: meta-llama/Llama-3.1-8B
data:
  path: data/corpus.jsonl
  format: text
  packing: true
  max_seq_length: 2048
  streaming: true
trainer:
  batch_size: 4
  gradient_accumulation: 8
  learning_rate: 3e-4
  num_epochs: 1
  warmup_steps: 1000
  weight_decay: 0.01
  mixed_precision: bf16
  strategy: fsdp
logging:
  backends: [console, tensorboard, wandb]
  project: pretrain-run
  log_every_n_steps: 10
eval:
  every_n_steps: 500
  metrics: [loss, perplexity]
output:
  dir: output/pretrain
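To get a feel for what the trainer values above imply, the sketch below computes the number of sequences and tokens consumed per optimizer step. It assumes the usual convention that the effective batch is the per-device batch size times the gradient accumulation steps times the number of devices. The device count of 8 is an assumption for illustration, and the helper names are hypothetical, not xaytune functions.

```python
def effective_batch_size(batch_size, gradient_accumulation, num_devices):
    """Sequences contributing to one optimizer step across all devices."""
    return batch_size * gradient_accumulation * num_devices


def tokens_per_step(batch_size, gradient_accumulation, num_devices, max_seq_length):
    """Upper bound on tokens per optimizer step (exact when every sequence is full)."""
    sequences = effective_batch_size(batch_size, gradient_accumulation, num_devices)
    return sequences * max_seq_length


# batch_size=4, gradient_accumulation=8, max_seq_length=2048 come from the config;
# num_devices=8 is an assumed GPU count.
sequences = effective_batch_size(4, 8, 8)
tokens = tokens_per_step(4, 8, 8, 2048)
print(sequences, tokens)
```

With packing enabled most sequences are close to max_seq_length, so the token figure is a reasonable estimate rather than a loose bound.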
Data Preparation
Pre-training typically uses the text format, where each sample has a text (or content) field containing raw text:
{"text": "The quick brown fox jumps over the lazy dog."}
{"text": "Machine learning is a subset of artificial intelligence..."}
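A corpus in this shape can be produced with the standard library alone. The following is a minimal sketch, assuming the documents are already available as Python strings. The helper `write_text_jsonl` is hypothetical and is not part of xaytune.

```python
import json
import tempfile
from pathlib import Path


def write_text_jsonl(documents, out_path):
    """Write one {"text": ...} JSON object per line, skipping empty documents."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with out_path.open("w", encoding="utf-8") as f:
        for doc in documents:
            text = doc.strip()
            if not text:
                continue  # blank samples add nothing to a pre-training corpus
            f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            count += 1
    return count


# Usage: write a tiny corpus to a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    corpus_path = Path(tmp) / "data" / "corpus.jsonl"
    written = write_text_jsonl(
        ["The quick brown fox jumps over the lazy dog.", "   ", "Second document."],
        corpus_path,
    )
    lines = corpus_path.read_text(encoding="utf-8").splitlines()
```

Using `json.dumps` rather than string formatting keeps quotes and newlines inside a document correctly escaped, so each record stays on a single line.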
Sequence Packing
By default, xaytune packs multiple short sequences into a single training example up to max_seq_length to improve GPU utilization. Disable this with packing: false if your data already contains full-length sequences.
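The idea behind packing can be sketched with one common approach: concatenate tokenized documents, separate them with an end-of-sequence token, and cut the stream into blocks of max_seq_length. This is an illustration only and may differ from how xaytune packs internally. The function `pack_sequences` and the `eos_id` value are assumptions.

```python
def pack_sequences(token_seqs, max_seq_length, eos_id):
    """Concatenate sequences (each followed by eos_id) and cut fixed-size blocks."""
    if max_seq_length <= 0:
        raise ValueError("max_seq_length must be positive")
    buffer = []
    packed = []
    for seq in token_seqs:
        buffer.extend(seq)
        buffer.append(eos_id)  # mark the document boundary
        while len(buffer) >= max_seq_length:
            packed.append(buffer[:max_seq_length])
            buffer = buffer[max_seq_length:]
    if buffer:
        packed.append(buffer)  # final block may be shorter than max_seq_length
    return packed


# Three short "documents" of token ids packed into blocks of length 4.
docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
blocks = pack_sequences(docs, max_seq_length=4, eos_id=0)
print(blocks)
```

In this toy run no block needs padding, which is the utilization gain packing is meant to provide.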
Streaming
For large datasets that do not fit in memory, enable streaming: true in the data config. This loads data on the fly rather than materializing the full dataset.
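Conceptually, streaming amounts to reading records lazily instead of loading the whole file. The generator below is a minimal standard-library sketch of that behavior for a JSONL corpus. It does not reflect xaytune's actual loader, and `stream_texts` is a hypothetical name.

```python
import json
import tempfile
from pathlib import Path


def stream_texts(path):
    """Yield the text of each JSONL record, reading one line at a time."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Accept either a "text" or a "content" field, as the text format allows.
            yield record["text"] if "text" in record else record["content"]


# Usage: only one record is held in memory at a time while iterating.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "corpus.jsonl"
    path.write_text(
        '{"text": "first sample"}\n{"content": "second sample"}\n',
        encoding="utf-8",
    )
    texts = list(stream_texts(path))
```

Because the generator never builds a list itself, memory use stays flat regardless of corpus size. The `list(...)` call here is only for demonstration.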
Distributed Training
Pre-training typically requires multiple GPUs. Configure the distributed strategy in the trainer config:
| Strategy | Description |
|---|---|
| auto | Automatically selects a strategy based on available hardware |
| ddp | Distributed Data Parallel -- replicates the model on each GPU |
| fsdp | Fully Sharded Data Parallel -- shards the model across GPUs |
| deepspeed | DeepSpeed ZeRO optimization (requires pip install xaytune[deepspeed]) |
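The practical difference between replicating and sharding can be seen with a back-of-the-envelope estimate of weight memory per GPU. The sketch below counts model weights only and ignores gradients, optimizer state, activations, and the temporary gathering of full layers that sharded training performs during compute, so real usage is higher. The parameter count of 8 billion and the GPU count of 8 are assumptions, and `weight_bytes_per_gpu` is a hypothetical helper.

```python
def weight_bytes_per_gpu(num_params, bytes_per_param, num_gpus, strategy):
    """Rough weight-only memory per GPU for replicated vs. sharded training."""
    total = num_params * bytes_per_param
    if strategy == "ddp":
        return total  # every GPU holds a full copy of the weights
    if strategy == "fsdp":
        return total / num_gpus  # weights are split evenly across GPUs
    raise ValueError(f"unsupported strategy for this estimate: {strategy}")


# Assumed: roughly 8e9 parameters, bf16 weights (2 bytes each), 8 GPUs.
ddp_bytes = weight_bytes_per_gpu(8_000_000_000, 2, 8, "ddp")
fsdp_bytes = weight_bytes_per_gpu(8_000_000_000, 2, 8, "fsdp")
print(f"ddp: {ddp_bytes / 1e9:.1f} GB, fsdp: {fsdp_bytes / 1e9:.1f} GB")
```

An estimate like this suggests why sharding is the usual choice once the model no longer fits comfortably on a single device.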