
llmcompressor.entrypoints.oneshot

Oneshot compression entrypoint for post-training model optimization.

Provides the main oneshot compression entry point for applying quantization, pruning, and other compression techniques to pre-trained models without additional training. Supports calibration-based compression with various pipeline configurations for efficient model optimization.

Classes:

  • Oneshot

    Class responsible for carrying out one-shot calibration on a pretrained model.

Functions:

  • oneshot

    Performs oneshot calibration on a model.

Oneshot

Oneshot(log_dir: str | None = None, **kwargs)

Class responsible for carrying out one-shot calibration on a pretrained model.

This class handles the entire lifecycle of one-shot calibration, including preprocessing (model and tokenizer/processor initialization), model optimization (quantization or sparsification), and postprocessing (saving outputs). The instructions for model optimization can be specified by using a recipe.

  • Input Keyword Arguments: kwargs are parsed into:

    • model_args: Arguments for loading and configuring a pretrained model (e.g., AutoModelForCausalLM).
    • dataset_args: Arguments for dataset-related configurations, such as calibration dataloaders.
    • recipe_args: Arguments for defining and configuring recipes that specify optimization actions.

    Parsers are defined in src/llmcompressor/args/.

  • Lifecycle Overview: The oneshot calibration lifecycle consists of three steps:

    1. Preprocessing:
      • Instantiates a pretrained model and tokenizer/processor.
      • Ensures input and output embedding layers are untied if they share tensors.
      • Patches the model to include additional functionality for saving with quantization configurations.
    2. Oneshot Calibration:
      • Optimizes the model using a global CompressionSession and applies recipe-defined modifiers (e.g., GPTQModifier, SparseGPTModifier).
    3. Postprocessing:
      • Saves the model, tokenizer/processor, and configuration to the specified output_dir.
  • Usage:

    oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset)
    oneshot()
    
    # Access the processed components
    model = oneshot.model
    processor = oneshot.processor
    recipe = oneshot.recipe
    

Methods: __init__(**kwargs): Initializes the Oneshot object by parsing input arguments, performing preprocessing, and setting instance attributes.

__call__(**kwargs):
    Performs the one-shot calibration process by preparing a calibration
    dataloader, applying recipe modifiers to the model, and executing
    postprocessing steps.

save():
    Saves the calibrated model and tokenizer/processor to the specified
    `output_dir`. Supports saving in compressed formats based on model
    arguments.

apply_recipe_modifiers(calibration_dataloader, **kwargs):
    Applies lifecycle actions (e.g., `initialize`, `finalize`) using modifiers
    defined in the recipe. Each action is executed via the global
    `CompressionSession`.

__init__

Initializes the Oneshot class with provided arguments.

Parses the input keyword arguments into model_args, dataset_args, and recipe_args. Performs preprocessing to initialize the model and tokenizer/processor.
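The routing of one flat `**kwargs` dict into per-group argument objects can be sketched as follows. This is an illustrative sketch only: the real parsers live in `src/llmcompressor/args/` and recognize far more fields; `ModelArgs` and `DatasetArgs` here are hypothetical stand-ins.

```python
from dataclasses import dataclass, fields

# Hypothetical, trimmed-down stand-ins for ModelArguments / DatasetArguments.
@dataclass
class ModelArgs:
    model: str = ""
    precision: str = "auto"

@dataclass
class DatasetArgs:
    dataset: str = ""
    num_calibration_samples: int = 512

def split_kwargs(kwargs: dict) -> list:
    # Route each keyword argument to the dataclass that declares a field
    # of the same name; unmatched fields keep their defaults.
    groups = []
    for cls in (ModelArgs, DatasetArgs):
        names = {f.name for f in fields(cls)}
        groups.append(cls(**{k: v for k, v in kwargs.items() if k in names}))
    return groups

model_args, dataset_args = split_kwargs(
    {"model": "facebook/opt-125m", "dataset": "open_platypus"}
)
```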

Parameters:

  • model_args

    ModelArguments parameters, responsible for controlling model loading and saving logic

  • dataset_args

    DatasetArguments parameters, responsible for controlling dataset loading, preprocessing and dataloader loading

  • recipe_args

    RecipeArguments parameters, responsible for containing recipe-related parameters

  • output_dir

    Path to save the output model after carrying out oneshot

  • log_dir

    (str | None, default: None ) –

    Path to save logs during oneshot run. Nothing is logged to file if None.


Source code in llmcompressor/entrypoints/oneshot.py
def __init__(
    self,
    log_dir: str | None = None,
    **kwargs,
):
    """
    Initializes the `Oneshot` class with provided arguments.

    Parses the input keyword arguments into `model_args`, `dataset_args`, and
    `recipe_args`. Performs preprocessing to initialize the model and
    tokenizer/processor.

    :param model_args: ModelArguments parameters, responsible for controlling
        model loading and saving logic
    :param dataset_args: DatasetArguments parameters, responsible for controlling
        dataset loading, preprocessing and dataloader loading
    :param recipe_args: RecipeArguments parameters, responsible for containing
        recipe-related parameters
    :param output_dir: Path to save the output model after carrying out oneshot
    :param log_dir: Path to save logs during oneshot run.
        Nothing is logged to file if None.
    """
    # Disable tokenizer parallelism to prevent warning when using
    # multiprocessing for dataset preprocessing. The warning occurs because
    # FastTokenizer's internal threading conflicts with dataset.map's num_proc.
    # See: https://github.com/vllm-project/llm-compressor/issues/2007
    if TOKENIZERS_PARALLELISM_ENV not in os.environ:
        os.environ[TOKENIZERS_PARALLELISM_ENV] = "false"
        logger.warning(
            "Disabling tokenizer parallelism due to threading conflict between "
            "FastTokenizer and Datasets. Set "
            f"{TOKENIZERS_PARALLELISM_ENV}=false to "
            "suppress this warning."
        )

    # Set up file logging (no default files):
    # 1) If LLM_COMPRESSOR_LOG_FILE is set, log to that file.
    # 2) Else, if an explicit log_dir is provided, create a timestamped file there.
    log_file = os.environ.get("LLM_COMPRESSOR_LOG_FILE", "").strip()
    if log_file:
        p = Path(log_file).expanduser()
        p.parent.mkdir(parents=True, exist_ok=True)
        logger.add(
            str(p),
            level="DEBUG",
        )
    elif log_dir:
        os.makedirs(log_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        logger.add(
            f"{log_dir}/oneshot_{date_str}.log",
            level="DEBUG",
        )

    model_args, dataset_args, recipe_args, output_dir = parse_args(**kwargs)

    self.model_args = model_args
    self.dataset_args = dataset_args
    self.recipe_args = recipe_args
    self.output_dir = output_dir

    # initialize the model and processor
    pre_process(model_args, dataset_args, output_dir)

    # Set instance attributes
    self.model = self.model_args.model
    self.processor = self.model_args.processor
    self.recipe = self.recipe_args.recipe

apply_recipe_modifiers

apply_recipe_modifiers(
    calibration_dataloader: DataLoader | None,
    recipe_stage: str | None = None,
)

Applies recipe modifiers to the model during the lifecycle.

The modifiers are defined in the recipe and executed via lifecycle actions (initialize, finalize) through the global CompressionSession.

Source code in llmcompressor/entrypoints/oneshot.py
def apply_recipe_modifiers(
    self,
    calibration_dataloader: DataLoader | None,
    recipe_stage: str | None = None,
):
    """
    Applies recipe modifiers to the model during the lifecycle.

    The modifiers are defined in the recipe and executed via lifecycle actions
    (`initialize`, `finalize`) through the global `CompressionSession`.


    :param calibration_dataloader: Dataloader for calibration data.

    Raises:
        RuntimeError: If any modifier fails during execution.
    """

    session = active_session()
    session.reset()

    # TODO (Helen INFERENG-661): validate recipe modifiers before initialization
    # Apply MoE calibration context for the entire calibration process
    with moe_calibration_context(
        self.model,
        calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
    ):
        session.initialize(
            model=self.model,
            start=-1,
            recipe=self.recipe,
            recipe_stage=recipe_stage,
            recipe_args=self.recipe_args.recipe_args,
            calib_data=calibration_dataloader,
            sequential_targets=self.dataset_args.sequential_targets,
        )
        user_pipeline = self.dataset_args.pipeline
        pipeline = CalibrationPipeline.from_modifiers(
            session.lifecycle.recipe.modifiers, user=user_pipeline
        )

        pipeline(
            self.model,
            calibration_dataloader,
            self.dataset_args,
        )

    session.finalize()

oneshot

oneshot(
    model: str | PreTrainedModel,
    config_name: str | None = None,
    tokenizer: str | PreTrainedTokenizerBase | None = None,
    processor: str | ProcessorMixin | None = None,
    use_auth_token: bool = False,
    precision: str = "auto",
    tie_word_embeddings: bool = True,
    trust_remote_code_model: bool = False,
    save_compressed: bool = True,
    model_revision: str = "main",
    recipe: str | list[str] | None = None,
    recipe_args: list[str] | None = None,
    clear_sparse_session: bool = False,
    stage: str | None = None,
    dataset: str | Dataset | DatasetDict | None = None,
    dataset_config_name: str | None = None,
    dataset_path: str | None = None,
    splits: str | list[str] | dict[str, str] | None = None,
    batch_size: int = 1,
    data_collator: str | Callable = "truncation",
    num_calibration_samples: int = 512,
    shuffle_calibration_samples: bool = True,
    max_seq_length: int = 384,
    pad_to_max_length: bool = True,
    text_column: str = "text",
    concatenate_data: bool = False,
    streaming: bool = False,
    overwrite_cache: bool = False,
    preprocessing_num_workers: int | None = None,
    dataloader_num_workers: int = 0,
    min_tokens_per_module: float | None = None,
    moe_calibrate_all_experts: bool = True,
    pipeline: str | None = "independent",
    tracing_ignore: list[str] = [
        "_update_causal_mask",
        "create_causal_mask",
        "_update_mamba_mask",
        "make_causal_mask",
        "get_causal_mask",
        "mask_interface",
        "mask_function",
        "_prepare_4d_causal_attention_mask",
        "_prepare_fsmt_decoder_inputs",
        "_prepare_4d_causal_attention_mask_with_cache_position",
        "_update_linear_attn_mask",
        "project_per_layer_inputs",
    ],
    sequential_targets: list[str] | None = None,
    sequential_offload_device: str = "cpu",
    quantization_aware_calibration: bool = True,
    sequential_prefetch: bool = False,
    output_dir: str | None = None,
    log_dir: str | None = None,
    **kwargs,
) -> PreTrainedModel

Performs oneshot calibration on a model.

Model arguments

Parameters:

  • model

    (str | PreTrainedModel) –

    A pretrained model identifier from huggingface.co/models or a path to a local model. Required parameter.

  • distill_teacher

    Teacher model (a trained text generation model) for distillation.

  • config_name

    (str | None, default: None ) –

    Pretrained config name or path if not the same as model_name.

  • tokenizer

    (str | PreTrainedTokenizerBase | None, default: None ) –

    Pretrained tokenizer name or path if not the same as model_name.

  • processor

    (str | ProcessorMixin | None, default: None ) –

    Pretrained processor name or path if not the same as model_name.

  • use_auth_token

    (bool, default: False ) –

    Whether to use Hugging Face auth token for private models.

  • precision

    (str, default: 'auto' ) –

    Precision to cast model weights to, default to auto.

  • tie_word_embeddings

    (bool, default: True ) –

    Whether the model's input and output word embeddings should be left tied if possible. False means always untie.

  • trust_remote_code_model

    (bool, default: False ) –

    Whether to allow for custom models to execute their own modeling files.

  • save_compressed

    (bool, default: True ) –

    Whether to compress sparse models during save.

  • model_revision

    (str, default: 'main' ) –

    The specific model version to use (can be branch name, tag, or commit id).

Recipe arguments

  • recipe

    (str | list[str] | None, default: None ) –

    Path to a LLM Compressor recipe, or a list of paths to multiple LLM Compressor recipes.

  • recipe_args

    (list[str] | None, default: None ) –

    List of recipe arguments to evaluate, in the format "key1=value1", "key2=value2".

  • clear_sparse_session

    (bool, default: False ) –

    Whether to clear CompressionSession/CompressionLifecycle data between runs.

  • stage

    (str | None, default: None ) –

    The stage of the recipe to use for oneshot.
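A recipe is typically a YAML file that names stages and the modifiers they apply. As a minimal sketch (the stage and group names here are arbitrary, and the `targets`/`scheme`/`ignore` values are illustrative, following the GPTQModifier quantization examples):

```yaml
quant_stage:
  quant_modifiers:
    GPTQModifier:
      targets: ["Linear"]
      scheme: "W4A16"
      ignore: ["lm_head"]
```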

Dataset arguments

  • dataset

    (str | Dataset | DatasetDict | None, default: None ) –

    The name of the dataset to use (via the datasets library).

  • dataset_config_name

    (str | None, default: None ) –

    The configuration name of the dataset to use.

  • dataset_path

    (str | None, default: None ) –

    Path to a custom dataset. Supports json, csv, dvc.

  • splits

    (str | list[str] | dict[str, str] | None, default: None ) –

    Optional percentages of each split to download.

  • batch_size

    (int, default: 1 ) –

    Calibration dataset batch size. During calibration, LLM Compressor disables lm_head output computations to reduce memory usage from large calibration batch sizes. Large batch sizes may result in excess padding or truncation, depending on the data_collator.

  • data_collator

    (str | Callable, default: 'truncation' ) –

    The function to use to form a batch from the dataset. Can also specify 'truncation' or 'padding' to truncate or pad non-uniform sequence lengths in a batch. Defaults to 'truncation'.

  • num_calibration_samples

    (int, default: 512 ) –

    Number of samples to use for one-shot calibration.

  • shuffle_calibration_samples

    (bool, default: True ) –

    Whether to shuffle the dataset before calibration.

  • max_seq_length

    (int, default: 384 ) –

    Maximum total input sequence length after tokenization.

  • pad_to_max_length

    (bool, default: True ) –

    Whether to pad all samples to max_seq_length.

  • text_column

    (str, default: 'text' ) –

    Key to use as the text input to tokenizer/processor.

  • concatenate_data

    (bool, default: False ) –

    Whether to concatenate datapoints to fill max_seq_length.

  • streaming

    (bool, default: False ) –

    True to stream data from a cloud dataset.

  • overwrite_cache

    (bool, default: False ) –

    Whether to overwrite the cached preprocessed datasets.

  • preprocessing_num_workers

    (int | None, default: None ) –

    Number of processes for dataset preprocessing.

  • dataloader_num_workers

    (int, default: 0 ) –

    Number of worker processes for data loading. Default is 0 (safe for low CPU/GPU memory). Set to 2 or more for faster calibration if you have sufficient RAM. Custom data collators may not work with multiprocessing.

  • min_tokens_per_module

    (float | None, default: None ) –

    Minimum percentage of tokens per module, relevant for MoE models.

  • moe_calibrate_all_experts

    (bool, default: True ) –

    Whether to calibrate all experts during MoE model calibration. When True, all experts will see all tokens during calibration, ensuring proper quantization statistics. When False, only routed experts will be used. Only relevant for MoE models. Default is True.

  • pipeline

    (str | None, default: 'independent' ) –

    Calibration pipeline used to calibrate the model. Options: ['basic', 'datafree', 'sequential', 'independent']

  • tracing_ignore

    (list[str], default: ['_update_causal_mask', 'create_causal_mask', '_update_mamba_mask', 'make_causal_mask', 'get_causal_mask', 'mask_interface', 'mask_function', '_prepare_4d_causal_attention_mask', '_prepare_fsmt_decoder_inputs', '_prepare_4d_causal_attention_mask_with_cache_position', '_update_linear_attn_mask', 'project_per_layer_inputs'] ) –

    List of functions to ignore during tracing, either {module}.{method_name} or {function_name}.

  • sequential_targets

    (list[str] | None, default: None ) –

    List of layer targets for the sequential pipeline. This is typically a single DecoderLayer. Not specifying this argument will cause the sequential pipeline to default to using the no_split_params specified by the HF model definition.

  • sequential_offload_device

    (str, default: 'cpu' ) –

    Device used to offload intermediate activations between sequential layers. It is recommended to use cuda:1 if using more than one gpu. Default is cpu.

  • quantization_aware_calibration

    (bool, default: True ) –

    Whether to enable quantization-aware calibration in the sequential pipeline. When True, quantization is applied during forward pass in calibration. When False, quantization is disabled during forward pass in calibration. Default is set to True.

  • sequential_prefetch

    (bool, default: False ) –

    When using the sequential pipeline, prefetch the next batch in a background thread to overlap onload with forward. Default False; set True for faster calibration when GPU memory allows.

Miscellaneous arguments

  • output_dir

    (str | None, default: None ) –

    Path to save the output model after calibration. Nothing is saved if None.

  • log_dir

    (str | None, default: None ) –

    Path to save logs during oneshot run. Nothing is logged to file if None.

Returns:

  • PreTrainedModel

    The calibrated PreTrainedModel

Source code in llmcompressor/entrypoints/oneshot.py
def oneshot(
    # Model arguments
    model: str | PreTrainedModel,
    config_name: str | None = None,
    tokenizer: str | PreTrainedTokenizerBase | None = None,
    processor: str | ProcessorMixin | None = None,
    use_auth_token: bool = False,
    precision: str = "auto",
    tie_word_embeddings: bool = True,
    trust_remote_code_model: bool = False,
    save_compressed: bool = True,
    model_revision: str = "main",
    # Recipe arguments
    recipe: str | list[str] | None = None,
    recipe_args: list[str] | None = None,
    clear_sparse_session: bool = False,
    stage: str | None = None,
    # Dataset arguments
    dataset: str | Dataset | DatasetDict | None = None,
    dataset_config_name: str | None = None,
    dataset_path: str | None = None,
    splits: str | list[str] | dict[str, str] | None = None,
    batch_size: int = 1,
    data_collator: str | Callable = "truncation",
    num_calibration_samples: int = 512,
    shuffle_calibration_samples: bool = True,
    max_seq_length: int = 384,
    pad_to_max_length: bool = True,
    text_column: str = "text",
    concatenate_data: bool = False,
    streaming: bool = False,
    overwrite_cache: bool = False,
    preprocessing_num_workers: int | None = None,
    dataloader_num_workers: int = 0,
    min_tokens_per_module: float | None = None,
    moe_calibrate_all_experts: bool = True,
    pipeline: str | None = "independent",
    tracing_ignore: list[str] = [
        "_update_causal_mask",
        "create_causal_mask",
        "_update_mamba_mask",
        "make_causal_mask",
        "get_causal_mask",
        "mask_interface",
        "mask_function",
        "_prepare_4d_causal_attention_mask",
        "_prepare_fsmt_decoder_inputs",
        "_prepare_4d_causal_attention_mask_with_cache_position",
        "_update_linear_attn_mask",
        "project_per_layer_inputs",
    ],
    sequential_targets: list[str] | None = None,
    sequential_offload_device: str = "cpu",
    quantization_aware_calibration: bool = True,
    sequential_prefetch: bool = False,
    # Miscellaneous arguments
    output_dir: str | None = None,
    log_dir: str | None = None,
    **kwargs,
) -> PreTrainedModel:
    """
    Performs oneshot calibration on a model.

    # Model arguments
    :param model: A pretrained model identifier from huggingface.co/models or a path
        to a local model. Required parameter.
    :param distill_teacher: Teacher model (a trained text generation model)
        for distillation.
    :param config_name: Pretrained config name or path if not the same as
        model_name.
    :param tokenizer: Pretrained tokenizer name or path if not the same as
        model_name.
    :param processor: Pretrained processor name or path if not the same as
        model_name.
    :param use_auth_token: Whether to use Hugging Face auth token for private
        models.
    :param precision: Precision to cast model weights to, default to auto.
    :param tie_word_embeddings: Whether the model's input and output word embeddings
        should be left tied if possible. False means always untie.
    :param trust_remote_code_model: Whether to allow for custom models to execute
        their own modeling files.
    :param save_compressed: Whether to compress sparse models during save.
    :param model_revision: The specific model version to use (can be branch name,
        tag, or commit id).

    # Recipe arguments
    :param recipe: Path to a LLM Compressor recipe, or a list of paths
      to multiple LLM Compressor recipes.
    :param recipe_args: List of recipe arguments to evaluate, in the
        format "key1=value1", "key2=value2".
    :param clear_sparse_session: Whether to clear CompressionSession/
        CompressionLifecycle data between runs.
    :param stage: The stage of the recipe to use for oneshot.

    # Dataset arguments
    :param dataset: The name of the dataset to use (via the datasets
        library).
    :param dataset_config_name: The configuration name of the dataset
        to use.
    :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
    :param splits: Optional percentages of each split to download.
    :param batch_size: Calibration dataset batch size. During calibration,
        LLM Compressor disables lm_head output computations to reduce memory
        usage from large calibration batch sizes. Large batch sizes may result
        in excess padding or truncation, depending on the data_collator.
    :param data_collator: The function to use to form a batch from the dataset. Can
        also specify 'truncation' or 'padding' to truncate or pad non-uniform sequence
        lengths in a batch. Defaults to 'truncation'.
    :param num_calibration_samples: Number of samples to use for one-shot
        calibration.
    :param shuffle_calibration_samples: Whether to shuffle the dataset before
        calibration.
    :param max_seq_length: Maximum total input sequence length after tokenization.
    :param pad_to_max_length: Whether to pad all samples to `max_seq_length`.
    :param text_column: Key to use as the `text` input to tokenizer/processor.
    :param concatenate_data: Whether to concatenate datapoints to fill
        max_seq_length.
    :param streaming: True to stream data from a cloud dataset.
    :param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
    :param preprocessing_num_workers: Number of processes for dataset preprocessing.
    :param dataloader_num_workers: Number of worker processes for data loading. Default
        is 0 (safe for low CPU/GPU memory). Set to 2 or more for faster calibration if
        you have sufficient RAM. Custom data collators may not work with
        multiprocessing.
    :param min_tokens_per_module: Minimum percentage of tokens per
        module, relevant for MoE models.
    :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
        model calibration. When True, all experts will see all tokens during
        calibration, ensuring proper quantization statistics. When False, only
        routed experts will be used. Only relevant for MoE models. Default is True.
    :param pipeline: Calibration pipeline used to calibrate the model. Options:
        ['basic', 'datafree', 'sequential', 'independent']
    :param tracing_ignore: List of functions to ignore during tracing, either
        {module}.{method_name} or {function_name}
    :param sequential_targets: List of layer targets for the sequential pipeline.
        This is typically a single DecoderLayer. Not specifying this argument will
        cause the sequential pipeline to default to using the `no_split_params`
        specified by the HF model definition
    :param sequential_offload_device: Device used to offload intermediate activations
        between sequential layers. It is recommended to use `cuda:1` if using more
        than one gpu. Default is cpu.
    :param quantization_aware_calibration: Whether to enable quantization-aware
        calibration in the sequential pipeline. When True, quantization is applied
        during forward pass in calibration. When False, quantization is disabled
        during forward pass in calibration. Default is set to True.
    :param sequential_prefetch: When using the sequential pipeline, prefetch the
        next batch in a background thread to overlap onload with forward. Default
        False; set True for faster calibration when GPU memory allows.

    # Miscellaneous arguments
    :param output_dir: Path to save the output model after calibration.
        Nothing is saved if None.
    :param log_dir: Path to save logs during oneshot run.
        Nothing is logged to file if None.

    :return: The calibrated PreTrainedModel
    """

    # pass all args directly into Oneshot
    local_args = {
        k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
    }
    one_shot = Oneshot(**local_args, **kwargs)
    one_shot()

    return one_shot.model
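The locals()-filtering trick at the end of oneshot() can be shown in isolation. The entrypoint below is a hypothetical stand-in with just two named parameters; the pattern is the same:

```python
def entrypoint(model: str, batch_size: int = 1, **kwargs) -> dict:
    # At this point locals() holds every named parameter plus the catch-all
    # 'kwargs' dict itself. Filtering out the bookkeeping names lets the
    # named arguments and the extras be merged back into one flat dict,
    # as oneshot() does before constructing Oneshot(**local_args, **kwargs).
    local_args = {
        k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
    }
    return {**local_args, **kwargs}

merged = entrypoint("facebook/opt-125m", batch_size=4, save_compressed=True)
# merged: {'model': 'facebook/opt-125m', 'batch_size': 4, 'save_compressed': True}
```

Note that `locals()` is evaluated in the function's own scope (the outermost iterable of a comprehension is evaluated eagerly in the enclosing frame), so the filter sees the named parameters rather than the comprehension's internals.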