Skip to content

llmcompressor.datasets.utils

Dataset utility functions for LLM compression workflows.

Provides helper functions for loading, processing, and formatting datasets used in model compression pipelines. Handles dataset splitting, tokenization, calibration data preparation, and dataloader creation for both training and one-shot calibration workflows.

Classes:

  • LengthAwareSampler

    Sample data in order of descending sequence length. Relies on an input_ids or decoder_input_ids column existing in the dataset.

Functions:

  • get_calibration_dataloader
  • get_processed_dataset
  • get_rank_partition
  • make_dataset_splits

LengthAwareSampler

LengthAwareSampler(
    data_source: Dataset,
    num_samples: Optional[int] = None,
    batch_size: int = 1,
)

Bases: Sampler[int]

Sample data in order of descending sequence length. Relies on an input_ids or decoder_input_ids column existing in the dataset.

Parameters:

  • data_source

    (Dataset) –

    A dataset containing an input_ids or decoder_input_ids column.

  • num_samples

    (Optional[int], default: None ) –

    Maximum number of samples to draw. The shortest sequences are truncated first.

Source code in llmcompressor/datasets/utils.py
def __init__(
    self,
    data_source: Dataset,
    num_samples: Optional[int] = None,
    batch_size: int = 1,
) -> None:
    """
    Prepare an index ordering that visits samples from longest to shortest.

    :param data_source: dataset expected to expose an ``input_ids`` or
        ``decoder_input_ids`` column (``input_ids`` wins if both exist)
    :param num_samples: maximum number of samples to draw; ``None`` (or any
        other falsy value) falls back to ``len(data_source)``
    :param batch_size: batch size stored on the sampler
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self._num_samples = num_samples or len(data_source)

    available = data_source.column_names
    length_column = next(
        (
            candidate
            for candidate in ("input_ids", "decoder_input_ids")
            if candidate in available
        ),
        None,
    )

    if length_column is None:
        # nothing to measure sequence length with: keep the dataset's own order
        logger.warning(f"Could not find input ids in {data_source.column_names}")
        self.order = range(len(data_source))
        return

    # rank sample indices by token count, longest first
    token_counts = [len(row) for row in data_source[length_column]]
    ranking = torch.argsort(torch.tensor(token_counts), descending=True)
    self.order = ranking.tolist()
    self._calculate_and_log_batch_stats(token_counts)

get_calibration_dataloader

get_calibration_dataloader(
    dataset_args: DatasetArguments, processor: Processor
) -> torch.utils.data.DataLoader

Get the dataloader used for oneshot calibration.

Parameters:

  • dataset_args

    (DatasetArguments) –

    DatasetArguments that contains the dataset parameters.

  • processor

    (Processor) –

    Processor or the tokenizer of the model.

Returns:

  • DataLoader

    PyTorch dataloader object that contains the calibration dataset, or None if no dataset is configured (weight-only or dynamic quantization).

Source code in llmcompressor/datasets/utils.py
def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> torch.utils.data.DataLoader | None:
    """
    Get the dataloader used for oneshot calibration.

    :param dataset_args: DatasetArguments that contains the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: PyTorch dataloader object that contains the calibration dataset,
        or None when ``dataset_args.dataset`` is None (weight-only or dynamic
        quantization, which need no calibration data).
    """
    if dataset_args.dataset is None:
        # weight-only quantization or dynamic quantization: nothing to calibrate on
        return None

    # named to avoid shadowing the `datasets` library
    split_datasets = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
        do_oneshot=True,
        do_train=False,
    )
    calibration_dataset = split_datasets.get("calibration")

    return format_calibration_data(dataset_args, calibration_dataset, processor)

get_processed_dataset

get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor | None = None,
    do_oneshot: bool = False,
    do_train: bool = True,
) -> dict[str, Dataset] | None

Loads datasets for each flow based on dataset_args and returns a dict holding a Dataset for each enabled flow.

Parameters:

  • dataset_args

    (DatasetArguments) –

    DatasetArguments that contain dataset loading and processing params

  • processor

    (Processor | None, default: None ) –

    processor or tokenizer to use for dataset tokenization

  • do_oneshot

    (bool, default: False ) –

    True for oneshot pathway

  • do_train

    (bool, default: True ) –

    True for train pathway

Returns:

  • dict[str, Dataset] | None

    A dict mapping "train" and "calibration" to their datasets (None for a disabled flow), or None if no dataset was provided.

Source code in llmcompressor/datasets/utils.py
def get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor | None = None,
    do_oneshot: bool = False,
    do_train: bool = True,
) -> dict[str, Dataset] | None:
    """
    Load and tokenize the configured dataset, then group it by flow.

    :param dataset_args: DatasetArguments that contain dataset loading and
        processing params (``dataset`` and ``splits`` are read here)
    :param processor: processor or tokenizer to use for dataset tokenization
    :param do_oneshot: True for oneshot pathway
    :param do_train: True for train pathway; also controls whether labels are
        added during tokenization
    :return: the dict produced by ``make_dataset_splits``, or None when no
        dataset is configured
    :raises ValueError: if ``dataset_args.splits`` is not None, str, list or dict
    """
    if dataset_args.dataset is None:
        logger.warning(
            "Running oneshot without calibration data. This is expected for "
            "weight-only and dynamic quantization"
        )
        return

    def _strip_slice(split_spec):
        # "train[60%:]" -> "train"; specs without a slice suffix are kept whole
        found = re.match(r"(\w*)\[.*\]", split_spec)
        return split_spec if found is None else found.group(1)

    # normalize every accepted `splits` form into {split_name: split_spec}
    raw_splits = dataset_args.splits
    if raw_splits is None:
        split_map = {"all": None}
    elif isinstance(raw_splits, str):
        split_map = {_strip_slice(raw_splits): raw_splits}
    elif isinstance(raw_splits, list):
        split_map = {_strip_slice(spec): spec for spec in raw_splits}
    elif isinstance(raw_splits, dict):
        split_map = raw_splits
    else:
        raise ValueError(f"Invalid splits type: {type(raw_splits)}")

    # non-string datasets are routed through the "custom" registry entry
    source_dataset = dataset_args.dataset
    registry_id = source_dataset if isinstance(source_dataset, str) else "custom"

    tokenized = {}
    for name, spec in split_map.items():
        candidate = dataset_args.dataset
        already_tokenized = (
            hasattr(candidate, "column_names")
            and "input_ids" in candidate.column_names
        )
        if already_tokenized:
            # dataset already carries input_ids: use it as-is
            tokenized[name] = candidate
            continue

        manager = TextGenerationDataset.load_from_registry(
            registry_id,
            dataset_args=dataset_args,
            split=spec,
            processor=processor,
        )
        tokenized[name] = manager(add_labels=do_train)

    return make_dataset_splits(
        tokenized,
        do_oneshot=do_oneshot,
        do_train=do_train,
    )

get_rank_partition

get_rank_partition(split: str, num_samples: int) -> str

Utility for splitting data in a distributed setting

Parameters:

  • split

    (str) –

    the split string to partition, e.g. "train"

  • num_samples

    (int) –

    the total number of samples in the dataset to partition

Returns:

  • str

    a partitioned split string

    Usage example:

    DATASET_ID = "HuggingFaceH4/ultrachat_200k"
    DATASET_SPLIT = "train_sft"
    NUM_CALIBRATION_SAMPLES = 256

    split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
    ds = load_dataset(DATASET_ID, split=split)

    For S samples and D devices, when S is not perfectly divisible by D, each device receives at least S//D samples and the remaining samples are distributed as evenly as possible across all devices.

Source code in llmcompressor/datasets/utils.py
def get_rank_partition(split: str, num_samples: int) -> str:
    """
    Utility for splitting data in a distributed setting.

    Builds a sliced split string covering the share of ``num_samples`` that
    belongs to the current ``torch.distributed`` rank.

    :param split: the split string to partition, e.g. "train"; must not
        already contain slicing brackets
    :param num_samples: the total number of samples in the dataset to partition
    :return: a partitioned split string of the form ``"{split}[start:end]"``
    :raises AssertionError: if ``split`` already contains a ``[``

    Usage example:

    DATASET_ID = "HuggingFaceH4/ultrachat_200k"
    DATASET_SPLIT = "train_sft"
    NUM_CALIBRATION_SAMPLES = 256

    split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
    ds = load_dataset(
        DATASET_ID,
        split=split,
    )

    for S samples and D devices, when S is not perfectly divisible by D,
    we give each device at least S//D samples and distribute
    the remaining samples as evenly as possible across all devices
    """
    # Checked explicitly rather than with `assert`, which is stripped under
    # `python -O` and would let a malformed "train[..][a:b]" string through.
    # AssertionError is kept so existing callers' error handling still applies.
    if "[" in split:
        raise AssertionError(
            "Split string should not already contain partitioning brackets"
        )

    start, end = _get_partition_start_end(
        num_samples, dist.get_rank(), dist.get_world_size()
    )
    return f"{split}[{start}:{end}]"

make_dataset_splits

make_dataset_splits(
    tokenized_datasets: dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> dict[str, Dataset]

Restructures the datasets dictionary based on which tasks will be run (training and/or oneshot calibration).

Parameters:

  • tokenized_datasets

    (dict[str, Any]) –

    dictionary of processed datasets

  • do_oneshot

    (bool, default: True ) –

    Whether to store the calibration dataset

  • do_train

    (bool, default: False ) –

    Whether to store the train dataset

Returns:

  • dict[str, Dataset]

    A dict with "train" and "calibration" keys; the entry for a disabled flow is None.

Source code in llmcompressor/datasets/utils.py
def make_dataset_splits(
    tokenized_datasets: dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
) -> dict[str, Dataset]:
    """
    Map processed datasets onto the splits required by the enabled flows.

    :param tokenized_datasets: dictionary of processed datasets keyed by split
        name. When its only key is ``"all"``, that entry is unwrapped; if it is
        a single Dataset it is treated as the train split
    :param do_oneshot: whether to populate the ``"calibration"`` entry; falls
        back to the train split when no ``"calibration"`` split exists
    :param do_train: whether to populate the ``"train"`` entry
    :return: dict with ``"train"`` and ``"calibration"`` keys; the entry for a
        disabled flow is None
    :raises ValueError: if a split required by an enabled flow is missing
    """
    # a lone "all" entry means every split lives in one dataset
    if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
        combined = tokenized_datasets.get("all")
        tokenized_datasets = (
            {"train": combined} if isinstance(combined, Dataset) else combined
        )

    train_split = None
    if do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_split = tokenized_datasets["train"]

    calib_split = None
    if do_oneshot:
        calib_split = tokenized_datasets.get("calibration")
        if calib_split is None:
            # no dedicated calibration split: calibrate on the train split
            if "train" not in tokenized_datasets:
                raise ValueError("--do_oneshot requires a calibration dataset")
            calib_split = tokenized_datasets["train"]

    return {"train": train_split, "calibration": calib_split}