Skip to content

Data Pipeline

xaytune's data pipeline handles loading, formatting, tokenizing, packing, and validating training data. The typical flow is:

load_dataset → format (automatic) → tokenize_dataset → pack_sequences (optional) → DataLoader(collate_fn)

For preference/alignment data, use the preference-specific functions instead:

load_preference_dataset → tokenize_preference_dataset → DataLoader(collate_preference)

Loading

load_dataset(path, *, format, source='local', streaming=False, eval_split=0.0, tokenizer=None, **kwargs)

Load and format a dataset from a local JSONL file or HuggingFace Hub.

Each sample is run through the registered format function ("alpaca", "sharegpt", "chat", "text", "preference"), converting raw fields into a {"text": "..."} dict ready for tokenization.

Parameters:

Name Type Description Default
path str

Local file path or HuggingFace dataset name.

required
format str

Format name registered in the format registry.

required
source str

"local" or "huggingface".

'local'
streaming bool

Stream from HuggingFace instead of downloading.

False
eval_split float

Fraction to hold out for evaluation (0 = no split).

0.0
tokenizer Any | None

Optional tokenizer for chat template application.

None

Returns:

Type Description
list[dict] | tuple[list[dict], list[dict]]

A list of formatted samples, or a (train, eval) tuple when eval_split > 0.

Raises:

Type Description
FileNotFoundError

If source is "local" and path doesn't exist.

Source code in xaytune/data/loader.py
def load_dataset(
    path: str,
    *,
    format: str,
    source: str = "local",
    streaming: bool = False,
    eval_split: float = 0.0,
    tokenizer: Any | None = None,
    **kwargs: Any,
) -> list[dict] | tuple[list[dict], list[dict]]:
    """Read a dataset and convert every sample with a registered formatter.

    Data comes either from a local JSONL file or from the HuggingFace Hub.
    Each raw record is passed through the format function registered under
    *format* (``"alpaca"``, ``"sharegpt"``, ``"chat"``, ``"text"``,
    ``"preference"``), yielding a ``{"text": "..."}`` dict ready for
    tokenization.

    Args:
        path: Path to a local JSONL file, or a HuggingFace dataset name.
        format: Name of a formatter in the format registry.
        source: ``"huggingface"`` loads from the Hub; any other value
            (default ``"local"``) reads *path* from disk.
        streaming: For HuggingFace sources, stream rather than download.
        eval_split: Fraction of samples held out for evaluation; ``0``
            disables splitting.
        tokenizer: Optional tokenizer used for chat-template formatting.
        **kwargs: Accepted but currently unused.

    Returns:
        The formatted samples, or a ``(train, eval)`` pair when
        ``eval_split > 0``.

    Raises:
        FileNotFoundError: When loading from disk and *path* does not exist.
    """
    if source != "huggingface":
        if not Path(path).exists():
            raise FileNotFoundError(f"Dataset not found: {path}")
        formatter = _make_format_fn(format, tokenizer)
        samples = list(map(formatter, _load_jsonl(path)))
        return _split_dataset(samples, eval_split) if eval_split > 0 else samples

    return _load_huggingface(  # type: ignore[no-any-return]
        path,
        format=format,
        streaming=streaming,
        eval_split=eval_split,
        tokenizer=tokenizer,
    )

Formats

Built-in format functions registered in format_registry:

format_alpaca(sample)

Format an Alpaca-style sample (instruction/input/output) into {"text": ...}.

Source code in xaytune/data/formats.py
@format_registry.register("alpaca")
def format_alpaca(sample: dict[str, Any]) -> dict[str, str]:
    """Render an Alpaca record (``instruction``/``input``/``output``) as ``{"text": ...}``.

    Missing fields default to empty strings. The ``### Input:`` section is
    emitted only when ``input`` is truthy; sections are separated by a blank
    line.
    """
    sections = [f"### Instruction:\n{sample.get('instruction', '')}"]
    extra_input = sample.get("input", "")
    if extra_input:
        sections.append(f"### Input:\n{extra_input}")
    sections.append(f"### Response:\n{sample.get('output', '')}")
    return {"text": "\n\n".join(sections)}

format_sharegpt(sample)

Format a ShareGPT-style multi-turn conversation into {"text": ...}.

Source code in xaytune/data/formats.py
@format_registry.register("sharegpt")
def format_sharegpt(sample: dict[str, Any]) -> dict[str, str]:
    """Render a ShareGPT multi-turn conversation as ``{"text": ...}``.

    Each turn's speaker is read from ``from`` (falling back to ``role``) and
    its message from ``value`` (falling back to ``content``). Speakers
    ``human``/``user``, ``gpt``/``assistant`` and ``system`` get a
    ``### User:`` / ``### Assistant:`` / ``### System:`` header; turns from
    any other speaker are dropped. Turns are separated by a blank line.
    """
    speaker_headers = (
        (("human", "user"), "User"),
        (("gpt", "assistant"), "Assistant"),
        (("system",), "System"),
    )
    rendered = []
    for turn in sample.get("conversations", []):
        speaker = turn.get("from", turn.get("role", ""))
        message = turn.get("value", turn.get("content", ""))
        header = next(
            (label for aliases, label in speaker_headers if speaker in aliases),
            None,
        )
        if header is not None:
            rendered.append(f"### {header}:\n{message}")
    return {"text": "\n\n".join(rendered)}

format_chat(sample)

Format an OpenAI-style chat messages list into {"text": ...}.

Source code in xaytune/data/formats.py
@format_registry.register("chat")
def format_chat(sample: dict[str, Any]) -> dict[str, str]:
    """Render an OpenAI-style ``messages`` list as ``{"text": ...}``.

    Every message becomes a ``### <Role>:`` header (the role passed through
    ``str.capitalize``) followed by its content; messages are separated by a
    blank line. Missing ``role``/``content`` default to empty strings.
    """
    blocks = (
        f"### {message.get('role', '').capitalize()}:\n{message.get('content', '')}"
        for message in sample.get("messages", [])
    )
    return {"text": "\n\n".join(blocks)}

format_text(sample)

Pass through a raw text sample as {"text": ...}.

Source code in xaytune/data/formats.py
@format_registry.register("text")
def format_text(sample: dict[str, Any]) -> dict[str, str]:
    """Wrap a raw text record as ``{"text": ...}``.

    Uses the ``text`` field, falling back to ``content`` and finally to an
    empty string.
    """
    fallback = sample.get("content", "")
    return {"text": sample.get("text", fallback)}

apply_chat_template(sample, tokenizer, *, format='chat')

Apply the tokenizer's chat template to a conversation sample.

Source code in xaytune/data/formats.py
def apply_chat_template(
    sample: dict[str, Any],
    tokenizer: Any,
    *,
    format: str = "chat",
) -> dict[str, str]:
    """Render a conversation with the tokenizer's own chat template.

    Args:
        sample: A conversation record. For ``format="sharegpt"`` the turns
            are read from ``conversations`` (speaker in ``from``/``role``,
            text in ``value``/``content``), and ``human``/``gpt`` speakers are
            renamed to ``user``/``assistant``. For any other format the
            record's ``messages`` list is passed through as-is.
        tokenizer: Object exposing ``apply_chat_template``.
        format: Source format of *sample*.

    Returns:
        ``{"text": ...}`` holding the untokenized template output, rendered
        without a trailing generation prompt.
    """
    if format != "sharegpt":
        messages = sample.get("messages", [])
    else:
        sharegpt_roles = {"human": "user", "gpt": "assistant"}
        messages = []
        for turn in sample.get("conversations", []):
            speaker = turn.get("from", turn.get("role", ""))
            messages.append(
                {
                    "role": sharegpt_roles.get(speaker, speaker),
                    "content": turn.get("value", turn.get("content", "")),
                }
            )

    rendered = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return {"text": rendered}

Tokenization

tokenize_dataset(data, tokenizer, max_seq_length=0)

Tokenize formatted samples into input_ids/labels/attention_mask dicts.

If the first sample already contains "input_ids", the data is treated as pre-tokenized and returned unchanged. Empty texts and empty encodings are filtered out.

Parameters:

Name Type Description Default
data list[dict[str, Any]]

Formatted samples, each with a "text" key.

required
tokenizer Any

A HuggingFace tokenizer.

required
max_seq_length int

Maximum sequence length (0 = use the tokenizer's model_max_length, or 1024 if it has none).

0

Returns:

Type Description
list[dict[str, list[int]]]

List of dicts with input_ids, labels, and attention_mask (all list[int]).

Source code in xaytune/data/tokenizer.py
def tokenize_dataset(
    data: list[dict[str, Any]],
    tokenizer: Any,
    max_seq_length: int = 0,
) -> list[dict[str, list[int]]]:
    """Encode ``{"text": ...}`` samples into causal-LM training features.

    Data whose first sample already has an ``"input_ids"`` key is treated as
    pre-tokenized and returned as-is (the same list object). Otherwise each
    sample's ``"text"`` is encoded with truncation and no padding. Samples
    whose text is missing or empty, or whose encoding is empty, are dropped.

    Args:
        data: Formatted samples, each expected to carry a ``"text"`` key.
        tokenizer: A HuggingFace-style tokenizer (callable returning
            ``input_ids`` and ``attention_mask``).
        max_seq_length: Truncation length. ``0`` (or any non-positive value)
            falls back to ``tokenizer.model_max_length``, or 1024 when the
            tokenizer has no such attribute.

    Returns:
        One dict per kept sample with ``input_ids``, ``labels`` (a copy of
        ``input_ids``) and ``attention_mask``.
    """
    if not data:
        return []
    if "input_ids" in data[0]:
        return data

    if max_seq_length > 0:
        limit = max_seq_length
    else:
        limit = getattr(tokenizer, "model_max_length", 1024)
    encode_options = {
        "truncation": True,
        "max_length": limit,
        "padding": False,
        "return_attention_mask": True,
    }

    features: list[dict[str, list[int]]] = []
    for text in (sample.get("text", "") for sample in data):
        if not text:
            continue
        encoding = tokenizer(text, **encode_options)
        token_ids = encoding["input_ids"]
        if token_ids:
            features.append(
                {
                    "input_ids": token_ids,
                    # Copy, so later in-place edits to labels don't touch input_ids.
                    "labels": list(token_ids),
                    "attention_mask": encoding["attention_mask"],
                }
            )
    return features

collate_tokenized(batch, pad_token_id=0)

Collate tokenized samples into padded tensors for model input.

Pads all sequences to the longest in the batch. Labels are padded with -100 (cross-entropy ignore index).

Parameters:

Name Type Description Default
batch list[dict[str, Any]]

List of tokenized dicts with input_ids keys.

required
pad_token_id int

Token id for input padding (masks use 0).

0

Returns:

Type Description
dict[str, Tensor]

Dict with input_ids, labels, and attention_mask tensors.

Source code in xaytune/data/tokenizer.py
def collate_tokenized(
    batch: list[dict[str, Any]],
    pad_token_id: int = 0,
) -> dict[str, torch.Tensor]:
    """Right-pad a batch of tokenized samples into equal-width long tensors.

    The target width is the longest ``input_ids`` in the batch, and each
    row's padding amount comes from its own ``input_ids`` length. Ids are
    padded with *pad_token_id*, labels with ``IGNORE_INDEX`` (the
    cross-entropy ignore index, ``-100``), and attention masks with ``0``.
    A sample without ``labels`` reuses its ``input_ids``; a sample without
    ``attention_mask`` gets all ones.

    Args:
        batch: Tokenized samples, each with an ``input_ids`` key. Must be
            non-empty (``max`` over an empty batch raises ``ValueError``).
        pad_token_id: Token id for input padding (masks use 0).

    Returns:
        Dict with ``input_ids``, ``labels``, and ``attention_mask`` as
        ``torch.long`` tensors.
    """
    longest = max(len(sample["input_ids"]) for sample in batch)

    rows_ids: list[list[int]] = []
    rows_labels: list[list[int]] = []
    rows_mask: list[list[int]] = []
    for sample in batch:
        token_ids = _to_list(sample["input_ids"])
        n_pad = longest - len(token_ids)
        rows_ids.append(token_ids + [pad_token_id] * n_pad)
        label_source = sample.get("labels", sample["input_ids"])
        rows_labels.append(_to_list(label_source) + [IGNORE_INDEX] * n_pad)
        mask_source = sample.get("attention_mask", [1] * len(token_ids))
        rows_mask.append(_to_list(mask_source) + [0] * n_pad)

    named_rows = (
        ("input_ids", rows_ids),
        ("labels", rows_labels),
        ("attention_mask", rows_mask),
    )
    return {key: torch.tensor(rows, dtype=torch.long) for key, rows in named_rows}

Preference Data

load_preference_dataset(path, *, eval_split=0.0)

Load a preference JSONL file with prompt/chosen/rejected fields.

Parameters:

Name Type Description Default
path str

Path to a JSONL file where each line has prompt, chosen, and rejected fields.

required
eval_split float

Fraction to hold out for evaluation (taken from the end of the file, without shuffling).

0.0

Returns:

Type Description
list[dict] | tuple[list[dict], list[dict]]

Formatted samples, or a (train, eval) tuple when eval_split > 0.

Raises:

Type Description
FileNotFoundError

If path doesn't exist.

ValueError

If any row is missing required fields.

Source code in xaytune/data/preferences.py
def load_preference_dataset(
    path: str,
    *,
    eval_split: float = 0.0,
) -> list[dict] | tuple[list[dict], list[dict]]:
    """Load a preference JSONL file with prompt/chosen/rejected fields.

    The file is read as UTF-8. Blank lines are skipped; every other line must
    be a JSON object containing all required preference fields, and each is
    passed through ``format_preference``.

    Args:
        path: Path to a JSONL file where each line has ``prompt``,
            ``chosen``, and ``rejected`` fields.
        eval_split: Fraction to hold out for evaluation. The last
            ``int(len(items) * eval_split)`` samples, in file order and
            without shuffling, form the eval split.

    Returns:
        Formatted samples, or a ``(train, eval)`` tuple when
        ``eval_split > 0``.

    Raises:
        FileNotFoundError: If *path* doesn't exist.
        ValueError: If any row is not a JSON object or is missing required
            fields. Malformed JSON raises ``json.JSONDecodeError``, a
            ``ValueError`` subclass.
    """
    file_path = Path(path)
    if not file_path.exists():
        raise FileNotFoundError(f"Preference dataset not found: {path}")

    items = []
    # JSONL is UTF-8; don't depend on the platform's locale encoding.
    with open(file_path, encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            sample = json.loads(line)
            # A non-object row (list, string, number) has no fields at all;
            # report it as a ValueError rather than crashing on ``.keys()``.
            if not isinstance(sample, dict):
                raise ValueError(
                    f"Row {i}: expected a JSON object, got {type(sample).__name__}. "
                    f"Preference data must have: prompt, chosen, rejected."
                )
            missing = _REQUIRED_FIELDS - set(sample.keys())
            if missing:
                raise ValueError(
                    f"Row {i}: missing required fields: {', '.join(sorted(missing))}. "
                    f"Preference data must have: prompt, chosen, rejected."
                )
            items.append(format_preference(sample))

    if eval_split > 0:
        split_idx = len(items) - int(len(items) * eval_split)
        return items[:split_idx], items[split_idx:]

    return items

tokenize_preference_dataset(data, tokenizer, max_seq_length=0)

Tokenize preference pairs into chosen/rejected input_ids and masks.

Concatenates prompt + chosen and prompt + rejected before tokenizing. If the first sample already contains "chosen_input_ids", the data is returned unchanged. Pairs with empty chosen or rejected text are skipped.

Parameters:

Name Type Description Default
data list[dict[str, Any]]

Preference samples with prompt, chosen, rejected.

required
tokenizer Any

A HuggingFace tokenizer.

required
max_seq_length int

Maximum sequence length (0 = use the tokenizer's model_max_length, or 1024 if it has none).

0

Returns:

Type Description
list[dict[str, list[int]]]

List of dicts with chosen_input_ids, chosen_attention_mask, rejected_input_ids, and rejected_attention_mask.

Source code in xaytune/data/tokenizer.py
def tokenize_preference_dataset(
    data: list[dict[str, Any]],
    tokenizer: Any,
    max_seq_length: int = 0,
) -> list[dict[str, list[int]]]:
    """Encode preference pairs into chosen/rejected ids and attention masks.

    A non-empty prompt is prepended directly (no separator) to each
    completion, and ``prompt + chosen`` / ``prompt + rejected`` are tokenized
    as single texts with truncation and no padding. Data whose first sample
    already has ``"chosen_input_ids"`` is returned as-is. Pairs with an empty
    ``chosen`` or ``rejected`` text, or with either encoding empty, are
    skipped.

    Args:
        data: Preference samples with ``prompt``, ``chosen``, ``rejected``.
        tokenizer: A HuggingFace-style tokenizer.
        max_seq_length: Truncation length. ``0`` falls back to
            ``tokenizer.model_max_length``, or 1024 when it is absent.

    Returns:
        One dict per kept pair with ``chosen_input_ids``,
        ``chosen_attention_mask``, ``rejected_input_ids`` and
        ``rejected_attention_mask``.
    """
    if not data:
        return []
    if "chosen_input_ids" in data[0]:
        return data

    if max_seq_length > 0:
        limit = max_seq_length
    else:
        limit = getattr(tokenizer, "model_max_length", 1024)

    def encode(text: Any) -> Any:
        return tokenizer(
            text,
            truncation=True,
            max_length=limit,
            padding=False,
            return_attention_mask=True,
        )

    pairs: list[dict[str, list[int]]] = []
    for sample in data:
        prompt = sample.get("prompt", "")
        chosen = sample.get("chosen", "")
        rejected = sample.get("rejected", "")
        if not (chosen and rejected):
            continue

        chosen_enc = encode(f"{prompt}{chosen}" if prompt else chosen)
        rejected_enc = encode(f"{prompt}{rejected}" if prompt else rejected)
        if chosen_enc["input_ids"] and rejected_enc["input_ids"]:
            pairs.append(
                {
                    "chosen_input_ids": chosen_enc["input_ids"],
                    "chosen_attention_mask": chosen_enc["attention_mask"],
                    "rejected_input_ids": rejected_enc["input_ids"],
                    "rejected_attention_mask": rejected_enc["attention_mask"],
                }
            )
    return pairs

collate_preference(batch, pad_token_id=0)

Collate tokenized preference pairs into padded tensors.

Pads chosen and rejected sequences independently to their respective max lengths within the batch.

Parameters:

Name Type Description Default
batch list[dict[str, Any]]

List of tokenized preference dicts.

required
pad_token_id int

Token id for input padding (masks use 0).

0

Returns:

Type Description
dict[str, Tensor]

Dict with chosen_input_ids, chosen_attention_mask, rejected_input_ids, and rejected_attention_mask tensors.

Source code in xaytune/data/tokenizer.py
def collate_preference(
    batch: list[dict[str, Any]],
    pad_token_id: int = 0,
) -> dict[str, torch.Tensor]:
    """Right-pad tokenized preference pairs into ``torch.long`` tensors.

    The chosen side and the rejected side are padded independently, each to
    the longest sequence on that side within the batch. Ids are padded with
    *pad_token_id* and masks with ``0``. A sample missing a
    ``<side>_attention_mask`` gets an all-ones mask.

    Args:
        batch: Tokenized preference dicts with ``chosen_input_ids`` and
            ``rejected_input_ids`` (must be non-empty).
        pad_token_id: Token id for input padding (masks use 0).

    Returns:
        Dict with ``chosen_input_ids``, ``chosen_attention_mask``,
        ``rejected_input_ids``, and ``rejected_attention_mask`` tensors.
    """
    collated: dict[str, torch.Tensor] = {}

    for side in ("chosen", "rejected"):
        ids_field = f"{side}_input_ids"
        mask_field = f"{side}_attention_mask"

        token_rows = [_to_list(sample[ids_field]) for sample in batch]
        width = max(len(tokens) for tokens in token_rows)

        padded_ids = []
        padded_masks = []
        for sample, tokens in zip(batch, token_rows):
            gap = width - len(tokens)
            padded_ids.append(tokens + [pad_token_id] * gap)
            raw_mask = sample.get(mask_field, [1] * len(tokens))
            padded_masks.append(_to_list(raw_mask) + [0] * gap)

        collated[ids_field] = torch.tensor(padded_ids, dtype=torch.long)
        collated[mask_field] = torch.tensor(padded_masks, dtype=torch.long)

    return collated

Packing

pack_sequences(sequences, *, max_seq_length, pad_token_id)

Pack multiple short sequences into fixed-length blocks to reduce padding.

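Greedily fills each block with whole tokenized samples, in order: a sample that would overflow the current block starts a new block instead of being split (samples longer than max_seq_length are truncated to it). Remaining space in each block is padded with pad_token_id (labels use -100).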

Parameters:

Name Type Description Default
sequences list[dict[str, list[int]]]

Tokenized samples, each with "input_ids" keys.

required
max_seq_length int

Target sequence length for packed blocks.

required
pad_token_id int

Token id used for input padding.

required

Returns:

Type Description
list[dict[str, list[int]]]

Packed samples with input_ids, attention_mask, and labels.

Source code in xaytune/data/packing.py
def pack_sequences(
    sequences: list[dict[str, list[int]]],
    *,
    max_seq_length: int,
    pad_token_id: int,
) -> list[dict[str, list[int]]]:
    """Pack multiple short sequences into fixed-length blocks to reduce padding.

    Samples are placed greedily, in order. Each one is appended to the
    current block if it fits; otherwise the current block is finalized and
    the sample starts a new block, so samples are never split across blocks.
    Samples longer than *max_seq_length* are truncated to it. Each block is
    finalized by ``_pad_and_finalize``, which pads remaining space with
    *pad_token_id* (labels use ``-100``).

    A sample's own ``"labels"`` are carried into the packed block when
    present, so upstream label masking (e.g. ``-100`` on prompt tokens)
    survives packing. Samples without ``"labels"`` use their ``input_ids``.

    Args:
        sequences: Tokenized samples, each with an ``"input_ids"`` key and
            optionally a ``"labels"`` key of the same length.
        max_seq_length: Target sequence length for packed blocks.
        pad_token_id: Token id used for input padding.

    Returns:
        Packed samples with ``input_ids``, ``attention_mask``, and ``labels``.

    Raises:
        ValueError: If a sample's ``labels`` length differs from its
            ``input_ids`` length.
    """
    if not sequences:
        return []

    packed: list[dict[str, list[int]]] = []
    current_ids: list[int] = []
    current_labels: list[int] = []

    for seq in sequences:
        raw_ids = seq["input_ids"]
        raw_labels = seq.get("labels", raw_ids)
        if len(raw_labels) != len(raw_ids):
            raise ValueError(
                f"Sample labels length {len(raw_labels)} doesn't match "
                f"input_ids length {len(raw_ids)}"
            )
        # list() copies, so the packed blocks never alias the caller's data.
        ids = list(raw_ids[:max_seq_length])
        labels = list(raw_labels[:max_seq_length])

        if len(current_ids) + len(ids) > max_seq_length:
            if current_ids:
                packed.append(
                    _pad_and_finalize(current_ids, current_labels, max_seq_length, pad_token_id)
                )
            current_ids = ids
            current_labels = labels
        else:
            current_ids.extend(ids)
            current_labels.extend(labels)

    if current_ids:
        packed.append(_pad_and_finalize(current_ids, current_labels, max_seq_length, pad_token_id))

    return packed

Validation

validate_dataset_sample(dataloader, *, max_seq_length=0)

Draw one batch from a dataloader and validate it.

Raises:

Type Description
DataValidationError

If the dataset is empty or the batch has issues.

Source code in xaytune/data/validation.py
def validate_dataset_sample(
    dataloader: Any,
    *,
    max_seq_length: int = 0,
) -> None:
    """Draw one batch from a dataloader and validate it.

    Only the first batch is inspected, using ``validate_batch``.

    Args:
        dataloader: Any iterable that yields batches (for example a
            ``DataLoader``).
        max_seq_length: Forwarded to ``validate_batch``; ``0`` disables the
            sequence-length check.

    Raises:
        DataValidationError: If the dataloader yields no batches, or if the
            first batch has issues (all issues are listed in the message).
    """
    try:
        batch = next(iter(dataloader))
    except StopIteration:
        # The StopIteration context adds nothing for the user; drop it.
        raise DataValidationError("Dataset is empty") from None

    issues = validate_batch(batch, max_seq_length=max_seq_length)
    if issues:
        raise DataValidationError(
            "Data validation failed:\n" + "\n".join(f"  - {i}" for i in issues)
        )

validate_batch(batch, *, max_seq_length=0)

Check a single batch dict for common data issues.

Returns a list of human-readable issue strings (empty = valid).

Source code in xaytune/data/validation.py
def validate_batch(
    batch: dict[str, Any],
    *,
    max_seq_length: int = 0,
) -> list[str]:
    """Check a single batch dict for common data issues.

    Standard batches must contain ``input_ids``. Preference batches (those
    with ``chosen_input_ids``) must also contain ``rejected_input_ids``.
    Every id entry present (``input_ids``, ``chosen_input_ids``,
    ``rejected_input_ids``) is checked as follows:

    * it must be a ``torch.Tensor`` with dtype int64 or int32;
    * when *max_seq_length* > 0, its last dimension must not exceed it;
    * its companion tensors (``labels``/``attention_mask``, or
      ``<side>_attention_mask``) must have the same shape.

    Args:
        batch: A collated batch, e.g. the output of a collate function.
        max_seq_length: Upper bound on sequence length; ``0`` disables the
            length check.

    Returns:
        Human-readable issue strings; an empty list means the batch is valid.
    """
    issues: list[str] = []

    if not isinstance(batch, dict):
        issues.append(f"Batch must be a dict, got {type(batch).__name__}")
        return issues

    is_preference = "chosen_input_ids" in batch
    if "input_ids" not in batch and not is_preference:
        issues.append("Batch missing required field: 'input_ids'")
    if is_preference and "rejected_input_ids" not in batch:
        issues.append("Batch missing required field: 'rejected_input_ids'")

    # Each id tensor and the tensors whose shape must match it.
    id_groups: list[tuple[str, tuple[str, ...]]] = [
        ("input_ids", ("labels", "attention_mask")),
        ("chosen_input_ids", ("chosen_attention_mask",)),
        ("rejected_input_ids", ("rejected_attention_mask",)),
    ]
    for ids_key, companion_keys in id_groups:
        if ids_key not in batch:
            continue
        ids = batch[ids_key]
        if not isinstance(ids, torch.Tensor):
            issues.append(f"'{ids_key}' should be a Tensor, got {type(ids).__name__}")
            continue

        # torch.int is an alias of torch.int32, so these are the accepted dtypes.
        if ids.dtype not in (torch.long, torch.int32):
            issues.append(f"'{ids_key}' dtype should be integer, got {ids.dtype}")

        if max_seq_length > 0 and ids.ndim >= 1:
            seq_len = ids.shape[-1]
            if seq_len > max_seq_length:
                if ids_key == "input_ids":
                    issues.append(
                        f"Sequence length {seq_len} exceeds max_seq_length {max_seq_length}"
                    )
                else:
                    issues.append(
                        f"'{ids_key}' sequence length {seq_len} exceeds "
                        f"max_seq_length {max_seq_length}"
                    )

        for key in companion_keys:
            other = batch.get(key)
            if isinstance(other, torch.Tensor) and other.shape != ids.shape:
                issues.append(
                    f"'{key}' shape {other.shape} doesn't match "
                    f"'{ids_key}' shape {ids.shape}"
                )

    return issues

DataValidationError

Bases: ValueError

Raised when a data batch fails validation checks.