Skip to content

Evaluation

xaytune provides two evaluation paths: custom dataset evaluation with evaluate() and benchmark evaluation with benchmark_evaluate().

evaluate()

Evaluate a model on a custom dataset with configurable metrics.

from xaytune.eval import evaluate

results = evaluate(
    model="output/my-finetune",
    dataset=[{"input_ids": ..., "labels": ...}],
    metrics=["loss", "perplexity"],
)

print(results)
# {'loss': 1.234, 'perplexity': 3.435}

Function Signature

def evaluate(
    *,
    model: Any,
    dataset: list[dict[str, Any]],
    metrics: list[str] | None = None,
) -> dict[str, float]:
Parameter Type Default Description
model model object or str required A model instance, or a string (local path or Hugging Face model name) that is loaded with xaytune.models.load_model()
dataset list[dict] required List of data batches to evaluate on
metrics list[str] | None None Metric names to compute (must be in metric_registry); None selects ["loss", "perplexity"]

Returns: dict[str, float] mapping metric names to their computed values.

Note

When model is a string path, xaytune automatically loads the model and tokenizer using xaytune.models.load_model().


benchmark_evaluate()

Run standard benchmarks using lm-eval.

from xaytune.eval.benchmarks import benchmark_evaluate

results = benchmark_evaluate(
    model="meta-llama/Llama-3.1-8B",
    benchmarks=["mmlu", "gsm8k", "hellaswag"],
    num_fewshot=5,
)

for task, metrics in results.items():
    print(f"{task}: {metrics}")

Function Signature

def benchmark_evaluate(
    *,
    model: str,
    benchmarks: list[str],
    num_fewshot: int | None = None,
) -> dict[str, dict[str, Any]]:
Parameter Type Default Description
model str required Model path or Hugging Face Hub name
benchmarks list[str] required List of benchmark task names
num_fewshot int | None None Number of few-shot examples (benchmark default if None)

Returns: Nested dict {task_name: {metric_name: value}}.

Requires lm-eval

Install the eval extra to use benchmarks:

pip install xaytune[eval]


Built-in Metrics

xaytune ships three metrics, registered in xaytune.eval.metrics.metric_registry:

Metric Function Description
loss compute_loss Average cross-entropy loss
perplexity compute_perplexity Exponentiated average loss: exp(mean_loss)
token_accuracy compute_token_accuracy Fraction of correctly predicted tokens

Custom Metrics

Register your own metrics with the @register_metric decorator:

from xaytune.eval.metrics import register_metric

@register_metric("bleu")
def compute_bleu(predictions, references, **kwargs):
    # Your BLEU implementation here
    ...
    return score

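Note that evaluate() currently calls every metric other than loss and perplexity with empty predictions and references lists, so a metric that needs real model predictions does not receive them from evaluate(). Once registered, custom metrics can be used anywhere metrics are accepted: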

results = evaluate(model=model, dataset=data, metrics=["loss", "bleu"])

Or in YAML config:

eval:
  metrics: [loss, perplexity, bleu]

CLI Usage

Benchmark Evaluation

xaytune eval --model output/my-finetune --benchmarks mmlu,gsm8k --num-fewshot 5

Dataset Evaluation

xaytune eval --model output/my-finetune --dataset data/eval.jsonl --metrics loss,perplexity

Model Comparison

Compare two models side-by-side on the same benchmarks:

xaytune compare model-a model-b --benchmarks mmlu,gsm8k

This prints a table showing each model's score on every benchmark metric.


Full API Reference

evaluate(*, model, dataset, metrics=None)

Evaluate a model on a list of batches and compute metrics.

Parameters:

Name Type Description Default
model Any

A model instance or HuggingFace model name string.

required
dataset list[dict[str, Any]]

List of batch dicts (each passable to model(**batch)).

required
metrics list[str] | None

Metric names to compute (default: ["loss", "perplexity"]).

None

Returns:

Type Description
dict[str, float]

Dict mapping metric names to computed values.

Source code in xaytune/eval/evaluate.py
def evaluate(
    *,
    model: Any,
    dataset: list[dict[str, Any]],
    metrics: list[str] | None = None,
) -> dict[str, float]:
    """Evaluate a model on a list of batches and compute metrics.

    Each batch is run through the model under ``torch.no_grad()`` and the
    scalar ``loss`` attribute of every output (when present and not ``None``)
    is collected. The requested metrics are then computed from those losses.

    Args:
        model: A model instance, or a string that is handed to
            ``xaytune.models.load_model`` to obtain one.
        dataset: List of batch dicts (each passable to ``model(**batch)``).
            A batch that is not a dict is passed positionally instead.
        metrics: Metric names to compute (default: ``["loss", "perplexity"]``).
            Every name must be registered in ``metric_registry``.

    Returns:
        Dict mapping metric names to computed values.

    Raises:
        ValueError: If any requested metric name is not registered.
    """
    if metrics is None:
        metrics = ["loss", "perplexity"]

    # Resolve every metric before doing any expensive work so that a typo in
    # a metric name fails fast with a clear message, instead of surfacing as
    # "'NoneType' object is not callable" after the whole dataset has run.
    metric_fns: dict[str, Any] = {}
    unknown: list[str] = []
    for metric_name in metrics:
        compute_fn = metric_registry.get(metric_name)
        if compute_fn is None:
            unknown.append(metric_name)
        else:
            metric_fns[metric_name] = compute_fn
    if unknown:
        raise ValueError(
            "Unknown metric(s): " + ", ".join(repr(name) for name in unknown)
        )

    if isinstance(model, str):
        # Imported lazily so the loader is only pulled in when a name/path
        # (rather than a ready model object) is supplied.
        from xaytune.models import load_model

        model_result = load_model(model)
        model = model_result.model

    # Objects without an ``eval`` method are accepted as-is.
    if hasattr(model, "eval"):
        model.eval()

    losses: list[float] = []
    with torch.no_grad():
        for batch in dataset:
            if isinstance(batch, dict):
                outputs = model(**batch)
            else:
                outputs = model(batch)

            # Outputs that carry no loss are skipped rather than treated as
            # an error.
            loss = getattr(outputs, "loss", None)
            if loss is not None:
                losses.append(loss.item())

    results: dict[str, float] = {}
    for metric_name, compute_fn in metric_fns.items():
        if metric_name in ("loss", "perplexity"):
            results[metric_name] = compute_fn(losses)
        else:
            # NOTE(review): predictions and references are not collected by
            # the loop above, so every other metric is invoked with two empty
            # lists.
            results[metric_name] = compute_fn([], [])

    return results

benchmark_evaluate(*, model, benchmarks, num_fewshot=None)

Run lm-eval-harness benchmarks against a HuggingFace model.

Parameters:

Name Type Description Default
model str

HuggingFace model name or local path.

required
benchmarks list[str]

Benchmark task names (e.g. ["mmlu", "gsm8k"]).

required
num_fewshot int | None

Number of few-shot examples. None uses each benchmark's default.

None

Returns:

Type Description
dict[str, dict[str, Any]]

Dict mapping benchmark names to their result dicts.

Raises:

Type Description
ImportError

If lm-eval is not installed.

Source code in xaytune/eval/benchmarks.py
def benchmark_evaluate(
    *,
    model: str,
    benchmarks: list[str],
    num_fewshot: int | None = None,
) -> dict[str, dict[str, Any]]:
    """Run lm-eval benchmarks for a model and return the per-task results.

    The model is evaluated through ``lm_eval.simple_evaluate`` using the
    ``"hf"`` model type, with ``model`` supplied as the ``pretrained``
    model argument.

    Args:
        model: HuggingFace model name or local path.
        benchmarks: Benchmark task names (e.g. ``["mmlu", "gsm8k"]``).
        num_fewshot: Number of few-shot examples. When ``None`` the argument
            is not forwarded at all, leaving the choice to lm-eval.

    Returns:
        The ``"results"`` entry of the lm-eval output, or an empty dict when
        that entry is absent.

    Raises:
        ImportError: If ``lm-eval`` is not installed.
    """
    if lm_eval is None:
        raise ImportError(
            "lm-eval is required for benchmark evaluation. "
            "Install it with: pip install xaytune[eval]"
        )

    # Only pass ``num_fewshot`` through when the caller chose a value.
    fewshot_args: dict[str, Any] = (
        {} if num_fewshot is None else {"num_fewshot": num_fewshot}
    )

    report = lm_eval.simple_evaluate(
        model="hf",
        model_args=f"pretrained={model}",
        tasks=benchmarks,
        **fewshot_args,
    )

    return report.get("results", {})  # type: ignore[no-any-return]

compute_loss(losses, *args, **kwargs)

Compute mean loss across batches.

Source code in xaytune/eval/metrics.py
@register_metric("loss")
def compute_loss(losses: list[float], *args: Any, **kwargs: Any) -> float:
    """Return the arithmetic mean of the per-batch losses.

    An empty ``losses`` list yields ``0.0``. Extra positional and keyword
    arguments are accepted and ignored.
    """
    return sum(losses) / len(losses) if losses else 0.0

compute_perplexity(losses, *args, **kwargs)

Compute perplexity as exp(mean_loss).

Source code in xaytune/eval/metrics.py
@register_metric("perplexity")
def compute_perplexity(losses: list[float], *args: Any, **kwargs: Any) -> float:
    """Compute perplexity as ``exp(mean_loss)``.

    Args:
        losses: Per-batch loss values.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``exp`` of the mean loss; ``0.0`` when ``losses`` is empty; ``inf``
        when the mean loss is too large for ``exp`` to fit in a float.
    """
    if not losses:
        return 0.0
    mean_loss = sum(losses) / len(losses)
    try:
        return math.exp(mean_loss)
    except OverflowError:
        # math.exp raises once its argument exceeds roughly 709.78. Report
        # an unbounded perplexity instead of aborting the evaluation.
        return math.inf

compute_token_accuracy(predictions, references, *args, **kwargs)

Compute fraction of tokens where prediction matches reference.

Source code in xaytune/eval/metrics.py
@register_metric("token_accuracy")
def compute_token_accuracy(
    predictions: list[int],
    references: list[int],
    *args: Any,
    **kwargs: Any,
) -> float:
    """Return the share of predicted tokens that equal their reference token.

    Tokens are compared pairwise by position. The denominator is always
    ``len(predictions)``; an empty ``predictions`` list yields ``0.0``.
    Extra positional and keyword arguments are accepted and ignored.
    """
    if not predictions:
        return 0.0

    matches = 0
    # zip stops at the shorter sequence, so positions present in only one of
    # the two lists are never counted as matches.
    for predicted, expected in zip(predictions, references):
        if predicted == expected:
            matches += 1

    return matches / len(predictions)