
llmcompressor.modifiers.autoround

Classes:

AutoRoundModifier

Bases: Modifier, QuantizationMixin

Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf. This modifier uses the signed gradient descent (SignSGD) optimizer and a block-wise output reconstruction loss to optimize weight rounding values and clipping ranges in relatively few steps (200 iterations per decoding layer by default).
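The SignSGD update mentioned above can be illustrated with a minimal pure-Python sketch. The names `sign` and `signsgd_minimize` and the toy objective are assumptions made for illustration; this is not the modifier's or AutoRound's implementation. Each step moves the parameter by exactly `lr` against the sign of the gradient, so the gradient's magnitude never changes the step size.

```python
def sign(g: float) -> int:
    """Return +1, -1, or 0 according to the sign of g."""
    return (g > 0) - (g < 0)


def signsgd_minimize(grad_fn, x0: float, lr: float, steps: int) -> float:
    """Signed gradient descent: step by lr against the sign of the gradient."""
    x = x0
    for _ in range(steps):
        x -= lr * sign(grad_fn(x))
    return x


# Toy objective f(x) = (x - 0.3) ** 2, whose gradient is 2 * (x - 0.3).
x = signsgd_minimize(lambda v: 2 * (v - 0.3), x0=0.0, lr=0.01, steps=100)
print(x)
```

Because every step has the same length, the parameter ends up oscillating within `lr` of the minimizer, and after `steps` iterations it can have moved at most `steps * lr` from where it started.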

Sample YAML recipe:

test_stage:
  modifiers:
    AutoRoundModifier:
      iters: 200
      config_groups:
        group_0:
          targets:
            - "Linear"
          input_activations: null
          output_activations: null
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: group
            group_size: 128
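To make the `weights` scheme above concrete, here is a minimal pure-Python sketch of symmetric, group-wise 4-bit fake quantization with one scale per group of `group_size` weights. It assumes the simple convention `scale = max(|w|) / (2**(num_bits - 1) - 1)`; the exact scale and rounding rules used by llm-compressor and compressed-tensors may differ, and `fake_quantize_groupwise` is a hypothetical name. Plain round-to-nearest like this is the baseline that AutoRound's tuned rounding and clipping aim to improve on.

```python
def fake_quantize_groupwise(weights, num_bits=4, group_size=128):
    """Quantize then dequantize a flat list of weights, one symmetric scale per group."""
    qmax = 2 ** (num_bits - 1) - 1  # 7 for 4-bit
    qmin = -qmax - 1                # -8 for 4-bit
    dequantized, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid dividing by zero
        scales.append(scale)
        for w in group:
            q = max(qmin, min(qmax, round(w / scale)))  # integer code in [qmin, qmax]
            dequantized.append(q * scale)
    return dequantized, scales


# 256 deterministic weights in [-0.5, 0.5], i.e. two groups of 128.
weights = [((i * 37) % 101 - 50) / 100 for i in range(256)]
dq, scales = fake_quantize_groupwise(weights)
print(len(scales), max(abs(a - b) for a, b in zip(weights, dq)))
```

With one scale per group, each weight's reconstruction error is bounded by half of its own group's scale, so a single outlier only coarsens the 128 weights that share its group.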

Lifecycle:

  • on_initialize
    • apply config to model
  • on_start
    • add input capture hooks to decoding layers
  • on_sequential_epoch_end
    • apply_autoround
    • post_autoround_cleanup
  • on_finalize
    • remove_hooks()
    • model.apply(freeze_module_quantization)
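The lifecycle above can be mimicked with a toy driver that only records the order in which hooks fire. Everything in this sketch (`ToyLifecycle`, `run_pipeline`, and the event strings) is hypothetical and is not llm-compressor's session or pipeline API. It assumes, as `apply_autoround` expects, that the sequential pipeline emits one sequential-epoch-end event per subgraph and that each subgraph contains at most one decoding layer.

```python
class ToyLifecycle:
    """Hypothetical stand-in that logs lifecycle events in the order they occur."""

    def __init__(self):
        self.events = []

    def on_initialize(self):
        self.events.append("initialize: apply config to model")

    def on_start(self):
        self.events.append("start: add input capture hooks")

    def on_sequential_epoch_end(self, layer_name):
        # In the real modifier, apply_autoround and post_autoround_cleanup run here.
        self.events.append(f"tune: {layer_name}")

    def on_finalize(self):
        self.events.append("finalize: remove hooks, freeze quantization")


def run_pipeline(modifier, decoding_layers):
    modifier.on_initialize()
    modifier.on_start()
    for name in decoding_layers:  # assumed: one subgraph per decoding layer
        modifier.on_sequential_epoch_end(name)
    modifier.on_finalize()
    return modifier.events


events = run_pipeline(ToyLifecycle(), ["layers.0", "layers.1"])
for event in events:
    print(event)
```

Processing one decoding layer per event is what lets `apply_autoround` carry the previous layer's quantized outputs forward through `self._q_input`.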

Parameters:

  • config_groups

    dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized.

  • targets

    list of layer names to quantize if a scheme is provided. Defaults to Linear layers.

  • ignore

    optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list.

  • scheme

    a single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which are taken from the modifier-level targets parameter.

  • sequential_targets

    class names of decoding layers to tune sequentially. If None, targets are inferred via get_no_split_params() to respect no-split constraints for large models. Defaults to None.

  • iters

    number of tuning iterations per block (decoding layer). Higher values typically improve accuracy at the cost of longer tuning time. Defaults to 200.

  • enable_torch_compile

    whether to enable torch.compile to accelerate the tuning loop. Disable if your environment or model encounters compilation issues. Defaults to True.

  • batch_size

    calibration/tuning batch size used by AutoRound when optimizing rounding/clipping parameters. Larger values can improve stability but require more memory. Defaults to 8.

  • device_ids

    optional device map string for layer dispatch during tuning. Examples: "0,1" for cuda:0 and cuda:1, or "auto" to use all available GPUs. When None, no dispatching occurs and the model remains on its current device. Defaults to None.
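The `device_ids` formats described above can be illustrated with a small parser. `parse_device_ids` is a hypothetical helper written for this example only; it is not how llm-compressor or AutoRound interprets the string, which `apply_autoround` passes through unchanged as AutoRound's `device_map` argument.

```python
from typing import List, Optional, Union


def parse_device_ids(device_ids: Optional[str]) -> Union[None, str, List[str]]:
    """Hypothetical helper mapping a device_ids string to device names."""
    if device_ids is None:
        return None  # no dispatching: the model stays on its current device
    device_ids = device_ids.strip()
    if device_ids == "auto":
        return "auto"  # use all available GPUs
    # "0,1" -> ["cuda:0", "cuda:1"]; blank entries are ignored
    return [f"cuda:{idx.strip()}" for idx in device_ids.split(",") if idx.strip()]


print(parse_device_ids("0,1"))   # ['cuda:0', 'cuda:1']
print(parse_device_ids("auto"))  # auto
print(parse_device_ids(None))    # None
```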

Methods:

  • apply_autoround

    Applies AutoRound quantization tuning on the current decoding layer.

  • on_end

    Finish calibrating by removing observers and calibration hooks.

  • on_finalize

    Disable the quantization observers used by the AutoRound algorithm.

  • on_initialize

    Initialize the model state for quantization and calibration.

  • start_calibration

    Register activation calibration hooks and enable quantization as we calibrate.

apply_autoround

apply_autoround(state, subgraph)

Applies AutoRound quantization tuning on the current decoding layer.

The tuning logic is as follows:

for step in range(iters):
    quant_output = forward(layer, cached_inputs)
    loss = mse_loss(quant_output, original_output)
    loss.backward()
    if loss < best_loss:
        best_loss = loss
        best_params = update_params(layer)
    optimizer.step()
    optimizer.zero_grad()

For more details, please refer to the AutoRound repository: https://github.com/intel/auto-round/
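The loop above can be sketched in pure Python for a toy "block" consisting of a single-output linear layer `y = w . x`. This is a hypothetical illustration, not AutoRound's implementation: `fake_quantize`, `block_mse`, and `tune_rounding` are invented names, gradients come from a straight-through estimator (treating `round()` as the identity), and weight clipping tuning, batching, and any learning-rate schedule are omitted. Only the per-weight rounding offsets `v`, kept in `[-0.5, 0.5]`, are learned with SignSGD.

```python
import math


def sign(g):
    """Return +1, -1, or 0 according to the sign of g."""
    return (g > 0) - (g < 0)


def fake_quantize(w, v, scale, qmin, qmax):
    """Round w/scale + v to an integer code, clamp it, and rescale."""
    return [scale * max(qmin, min(qmax, round(wi / scale + vi))) for wi, vi in zip(w, v)]


def block_mse(wq, w, inputs):
    """MSE between the quantized and full-precision outputs of y = w . x."""
    total = 0.0
    for x in inputs:
        err = sum((qi - wi) * xi for qi, wi, xi in zip(wq, w, x))
        total += err ** 2
    return total / len(inputs)


def tune_rounding(w, inputs, iters=200, lr=0.005, num_bits=4):
    qmax = 2 ** (num_bits - 1) - 1
    qmin = -qmax - 1
    scale = max(abs(wi) for wi in w) / qmax
    v = [0.0] * len(w)  # rounding offsets; all zeros is round-to-nearest
    best_loss, best_v = math.inf, list(v)
    for _ in range(iters):
        wq = fake_quantize(w, v, scale, qmin, qmax)
        loss = block_mse(wq, w, inputs)
        if loss < best_loss:  # snapshot before stepping, so loss and params match
            best_loss, best_v = loss, list(v)
        # Straight-through estimator: d(wq_i)/d(v_i) is approximated by scale.
        grad = [0.0] * len(w)
        for x in inputs:
            err = sum((qi - wi) * xi for qi, wi, xi in zip(wq, w, x))
            for i, xi in enumerate(x):
                grad[i] += 2 * err * xi * scale / len(inputs)
        # SignSGD step, then clamp each offset back into [-0.5, 0.5].
        v = [max(-0.5, min(0.5, vi - lr * sign(g))) for vi, g in zip(v, grad)]
    return best_v, best_loss, scale


w = [0.31, -0.42, 0.18, 0.07, -0.25, 0.5]
inputs = [
    [1.0, 0.5, -0.3, 0.8, 0.2, -1.0],
    [0.4, -0.7, 1.2, 0.1, -0.5, 0.3],
    [-0.6, 0.9, 0.2, -0.4, 1.1, 0.7],
]
best_v, best_loss, scale = tune_rounding(w, inputs)
rtn_loss = block_mse(fake_quantize(w, [0.0] * len(w), scale, -8, 7), w, inputs)
print(f"round-to-nearest MSE: {rtn_loss:.6f}, tuned MSE: {best_loss:.6f}")
```

Because the first iteration evaluates the all-zero offsets, the returned `best_loss` can never be worse than plain round-to-nearest on the cached inputs; tracking the best snapshot protects against SignSGD oscillating away from a good solution late in tuning.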

Source code in llmcompressor/modifiers/autoround/base.py
def apply_autoround(self, state, subgraph):
    """
    Applies AutoRound quantization tuning on the current decoding layer.

    The tuning logic is as follows:
    for iter in range(iters):
        quant_output = forward(layer, cached_inputs)
        loss = mse_loss(quant_output, original_output)
        loss.backward()
        optimizer.step()
        if loss < best_loss:
            best_params = update_params(layer)

    For more details, please refer to the AutoRound repository:
    https://github.com/intel/auto-round/
    """
    modules = list(subgraph.submodules(model=state.model))

    decoding_layers = [m for m in modules if self._is_decoding_layer(m)]
    if len(decoding_layers) == 0:
        return
    assert len(decoding_layers) == 1, (
        "Only one decoding layer is expected in the subgraph, "
        f"found {len(decoding_layers)}."
    )
    decoding_layer = decoding_layers[0]

    logger.info("Applying AutoRound on layer {}", decoding_layer._tmp_name)

    # Build wrapped_model for AutoRound initialization
    wrapped_model = _wrap_decoding_layer(decoding_layer)
    wrapped_model.name_or_path = state.model.name_or_path
    wrapped_model.config = state.model.config

    # Build kwargs for AutoRound initialization
    ar_quant_scheme = self._mapping_config_to_autoround()
    fp_layers = self.get_unquantized_layer_names(decoding_layer)
    kwargs = {
        "tokenizer": "",  # A placeholder
        "scheme": ar_quant_scheme,
        "iters": self.iters,
        "lr": self.lr,
        "enable_torch_compile": self.enable_torch_compile,
        "batch_size": self.batch_size,
        "device_map": self.device_ids,
        "fp_layers": ",".join(fp_layers) if fp_layers else "",
    }

    llmc_registered_qparams = self._preprocess_qparams(decoding_layer)
    with (
        torch.enable_grad(),
        align_module_device(decoding_layer),
        suspend_offloading(wrapped_model),
    ):
        ar = AutoRound(
            model=wrapped_model,
            **kwargs,
        )
        # TODO: configure layer-wise config based on self.resolved_config
        ar.configure_layer_config(enable_gguf_official_mixed=False)
        ar.batch_dim = 0
        first_param = next(decoding_layer.parameters())
        device = first_param.device
        cur_inputs = self._all_module_input[decoding_layer._tmp_name]
        decoding_layer.tuning_device = device
        # Leave offload for LLMC to handle if `device_ids` is not set
        auto_offload = False
        if self.device_ids is not None:
            # When device_ids is set, we move decoding layer to CPU first,
            # then the submodules will be re-dispatched by AutoRound.
            decoding_layer.to("cpu")
            auto_offload = True

        q_input, _ = ar.quantize_block(
            block=decoding_layer,
            inputs=cur_inputs,
            q_input=self._q_input,
            device=str(device),
            auto_offload=auto_offload,
        )
        self._q_input = q_input

        decoding_layer = self._unwrapper_quantized_layer(decoding_layer)

    decoding_layer.eval()
    # Update offload parameters and remove temporary attributes
    self._postprocess_qparams(decoding_layer, llmc_registered_qparams)

on_end

on_end(state: State, event: Event, **kwargs)

Finish calibrating by removing observers and calibration hooks.

Source code in llmcompressor/modifiers/autoround/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(self, state.model)
    self._remove_temporary_names(state.model)
    self.remove_hooks()
    self._q_input = None

on_finalize

on_finalize(state: State, **kwargs) -> bool

Disable the quantization observers used by the AutoRound algorithm.

Parameters:

  • state

    (State) –

    session state storing input model and calibration data

Source code in llmcompressor/modifiers/autoround/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    disable the quantization observers used by the AutoRound algorithm

    :param state: session state storing input model and calibration data
    """
    if not self.ended_:
        self.on_end(state, None)

    return True
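The `ended_` flag makes `on_finalize` run the cleanup only when `on_end` has not already done so. A minimal sketch of that guard, using a hypothetical `ToyModifier` class that merely counts cleanups (it performs none of the real observer or hook removal):

```python
class ToyModifier:
    """Hypothetical stand-in for the ended_ guard shared by on_end and on_finalize."""

    def __init__(self):
        self.ended_ = False
        self.cleanup_calls = 0

    def on_end(self):
        self.ended_ = True
        self.cleanup_calls += 1  # stands in for end_calibration / remove_hooks

    def on_finalize(self):
        if not self.ended_:  # run cleanup only if on_end has not run yet
            self.on_end()
        return True


normal = ToyModifier()
normal.on_end()
normal.on_finalize()
print(normal.cleanup_calls)  # 1

skipped_end = ToyModifier()
skipped_end.on_finalize()
print(skipped_end.cleanup_calls)  # 1
```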

on_initialize

on_initialize(state: State, **kwargs) -> bool

Initialize the model state for quantization and calibration.

Parameters:

  • state

    (State) –

    session state storing input model and calibration data

Source code in llmcompressor/modifiers/autoround/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Initialize the model state for quantization and calibration.

    :param state: session state storing input model and calibration data
    """
    # apply config to model and prepare calibration hooks
    if QuantizationMixin.has_config(self):
        QuantizationMixin.initialize_quantization(self, state.model)

    # prepare module names
    self._add_temporary_names(state.model)
    # freeze all model parameters
    for _, param in state.model.named_parameters():
        param.requires_grad_(False)

    self.sequential_targets = self._infer_sequential_targets(state.model)
    return True

start_calibration

start_calibration(model: Module)

Register activation calibration hooks and enable quantization as we calibrate.

Parameters:

  • model

    (Module) –

    model to prepare for calibration

Source code in llmcompressor/modifiers/autoround/base.py
def start_calibration(self, model: torch.nn.Module):
    """
    Register activation calibration hooks and enable quantization as we calibrate

    :param model: model to prepare for calibration
    """
    targets = match_named_modules(model, self.targets, self.ignore)
    if targets_embeddings(model, targets):
        untie_word_embeddings(model)

    for _, module in match_named_modules(model, self.targets, self.ignore):
        # skip register observers for auto-round
        apply_calibration_status(module)

    model.apply(enable_quantization)  # quantize at the same time as calibrate
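The `targets`/`ignore` filtering that `start_calibration` delegates to `match_named_modules` can be approximated as follows. `toy_match` is a hypothetical, simplified matcher that compares exact module names and class names only; the real `match_named_modules` operates on a `torch.nn.Module` and may accept additional pattern forms such as regular expressions.

```python
def toy_match(named_modules, targets, ignore=()):
    """Keep (name, class_name) pairs that match a target and match nothing in ignore."""

    def matches(name, class_name, patterns):
        return name in patterns or class_name in patterns

    return [
        (name, cls)
        for name, cls in named_modules
        if matches(name, cls, targets) and not matches(name, cls, ignore)
    ]


modules = [
    ("model.layers.0.self_attn.q_proj", "Linear"),
    ("model.layers.0.mlp.down_proj", "Linear"),
    ("model.layers.0.input_layernorm", "LlamaRMSNorm"),
    ("lm_head", "Linear"),
]
matched = toy_match(modules, targets=["Linear"], ignore=["lm_head"])
for name, _ in matched:
    print(name)
```

Here `lm_head` is a `Linear` layer but is excluded by `ignore`, mirroring the rule that ignored modules are not quantized even if they match a target.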