
llmcompressor.modifiers.awq

Classes:

  • AWQMapping

    Dataclass storing config of activation mappings to smooth

  • AWQModifier

    Implements the AWQ (Activation-aware Weight Quantization) algorithm.

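Functions:

  • get_layer_mappings_from_architecture

    Return the AWQ mappings registered for a model architecture, or the default mappings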

AWQMapping dataclass

AWQMapping(
    smooth_layer: str,
    balance_layers: list[str],
    activation_hook_target: str | None = None,
)

Dataclass storing the config of an activation mapping to smooth. The output activations of smooth_layer are the input activations to the balance_layers.

AWQMappings are resolved into ResolvedMappings, which hold references to the actual torch.nn.Module objects plus additional metadata at runtime.

Parameters:

  • smooth_layer

    (str) –

    regex or name of the activation layer to smooth

  • balance_layers

    (list[str]) –

    list of regex or names of weight layers that must be balanced to offset the smoothing

  • activation_hook_target

    (str | None, default: None ) –

    optional dotted attribute path relative to the parent module (lowest common ancestor of balance_layers) specifying which submodule to hook for activation caching. Useful for parallel transformer blocks (e.g. Cohere, Gemma 3) where the first balance layer is not the correct place to capture activations. When None (default), the hook is placed on balance_layers[0].
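To make the smooth/balance relationship concrete, here is a stdlib-only sketch of how name patterns could be resolved against module names. MappingSketch, match_names and resolve are hypothetical helpers, not the library's API. The "re:" prefix follows the recipe convention used on this page, matching with re.match (anchored at the start of the name) is an assumption, and pairing balance layers by a shared parent prefix is a simplification of the real "largest overlap in module name" rule.

```python
from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass
class MappingSketch:
    """Hypothetical stand-in mirroring the three AWQMapping fields."""

    smooth_layer: str
    balance_layers: list[str]
    activation_hook_target: str | None = None


def match_names(pattern: str, names: list[str]) -> list[str]:
    """Match a "re:"-prefixed regex (via re.match) or an exact module name."""
    if pattern.startswith("re:"):
        regex = re.compile(pattern[len("re:"):])
        return [n for n in names if regex.match(n)]
    return [n for n in names if n == pattern]


def resolve(mapping: MappingSketch, names: list[str]) -> list[tuple[str, list[str]]]:
    """Pair each matched smooth layer with the balance layers under the same parent."""
    resolved = []
    for smooth in match_names(mapping.smooth_layer, names):
        parent = smooth.rsplit(".", 1)[0]
        balance = [
            name
            for pattern in mapping.balance_layers
            for name in match_names(pattern, names)
            if name.startswith(parent + ".")  # trailing dot keeps layers.1 apart from layers.10
        ]
        resolved.append((smooth, balance))
    return resolved


# Toy module names for a two-layer, OPT-style decoder (illustrative only).
leaves = ("self_attn_layer_norm", "self_attn.q_proj", "self_attn.k_proj",
          "self_attn.v_proj", "final_layer_norm", "fc1")
names = [f"model.decoder.layers.{i}.{leaf}" for i in range(2) for leaf in leaves]

attn = MappingSketch("re:.*self_attn_layer_norm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"])
for smooth, balance in resolve(attn, names):
    print(smooth, "->", balance)
```

Each smooth layer ends up paired with the projections that consume its output in the same decoder layer, which is the relationship a ResolvedMapping captures with real module references.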

AWQModifier

Bases: Modifier, QuantizationMixin

Implements the AWQ (Activation-aware Weight Quantization) algorithm, as described in https://arxiv.org/pdf/2306.00978. The algorithm significantly reduces quantization error by protecting only 1% of the most salient weight channels.

Instead of relying on raw weight values, AWQ identifies important channels by analyzing activation patterns, focusing on the channels in the weight tensor that are most responsive to the input. To reduce quantization error, it scales these channels in a way that preserves the model's original behavior, using scaling factors computed offline from activation statistics.
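The property this relies on is that multiplying input channel j of a weight matrix by s_j and dividing activation channel j by s_j leaves the layer output unchanged, so the scales only change what the quantizer sees. A minimal pure-Python sketch, assuming the per-channel scale s = mean(|x|)^alpha from the AWQ paper; alpha = 0.5, matvec and the toy values are illustrative choices, not the modifier's internals.

```python
def matvec(w, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


# Toy linear layer: 2 output channels x 3 input channels.
w = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.3]]
# Toy calibration batch: input channel 1 carries much larger activations.
batch = [[0.1, 8.0, -0.2], [0.05, -6.0, 0.3]]

# Per-input-channel mean activation magnitude (the "activation statistics").
x_mean = [sum(abs(x[j]) for x in batch) / len(batch) for j in range(len(w[0]))]

alpha = 0.5  # illustrative exponent; AWQ searches for it rather than fixing it
s = [m ** alpha for m in x_mean]

# Scale weight input channels up by s and the activations down by s.
w_scaled = [[wij * sj for wij, sj in zip(row, s)] for row in w]
for x in batch:
    x_scaled = [xj / sj for xj, sj in zip(x, s)]
    # W x == (W diag(s)) (diag(s)^-1 x), up to floating-point rounding.
    assert all(abs(a - b) < 1e-9 for a, b in zip(matvec(w, x), matvec(w_scaled, x_scaled)))
```

The salient channel (index 1) receives the largest scale, so its weights occupy more of the quantization grid. In the modifier, the 1/s side is presumably absorbed by smooth_layer, whose output feeds the balance_layers, rather than applied at runtime.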

Because this modifier manipulates the weights of the model, it can only be used in one-shot mode, not during training. Activation ranges are determined by running a small set of calibration data through the model.

Example recipe:

AWQModifier:
  mappings:
    - smooth_layer: "re:.*self_attn_layer_norm"
      balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
    - smooth_layer: "re:.*final_layer_norm"
      balance_layers: ["re:.*fc1"]
    # activation_hook_target (optional) names the submodule of the parent to hook
    # for activation caching. It is only needed for models with parallel
    # transformer blocks; in most cases, leave it at the default (None).
  ignore: ["lm_head"]
  config_groups:
    group_0:
      targets:
        - "Linear"
      input_activations: null
      output_activations: null
      weights:
        num_bits: 4
        type: int
        symmetric: false
        strategy: group
        group_size: 128
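The weights section of this recipe asks for 4-bit asymmetric integer quantization with one scale and zero point per group of 128 consecutive input channels. As a rough illustration, here is a min-max fake-quantization sketch in plain Python. fake_quant_group and fake_quant_row are hypothetical helpers; the unsigned [0, 15] integer range and the min-max range estimate are assumptions and may differ from the observers and integer range the library actually uses.

```python
def fake_quant_group(values, num_bits=4):
    """Asymmetric min-max fake quantization of one weight group."""
    q_min, q_max = 0, 2 ** num_bits - 1  # unsigned range [0, 15] for 4 bits
    # Include 0.0 in the range so that zero is exactly representable.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (q_max - q_min) or 1.0  # guard against an all-zero group
    zero_point = round(q_min - lo / scale)
    quantized = [min(max(round(v / scale) + zero_point, q_min), q_max) for v in values]
    dequantized = [(q - zero_point) * scale for q in quantized]
    return dequantized, scale, zero_point


def fake_quant_row(row, group_size=128, num_bits=4):
    """Quantize one output row group by group (strategy: group)."""
    out = []
    for start in range(0, len(row), group_size):
        dequantized, _, _ = fake_quant_group(row[start:start + group_size], num_bits)
        out.extend(dequantized)
    return out


row = [((i * 37) % 101) / 50.0 - 1.0 for i in range(256)]  # toy row in [-1, 1]
deq = fake_quant_row(row)
print("max abs error:", max(abs(a - b) for a, b in zip(row, deq)))
```

Each group gets its own scale and zero point, so an outlier only coarsens the 128 weights it shares a group with.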

Lifecycle:

  • on_initialize
    • resolve mappings
    • capture kwargs needed for forward passes into modules
  • on_start
    • set up activation cache hooks to capture input activations to balance layers
  • on sequential epoch end
    • apply smoothing to each smoothing layer
      • consume cached activations across all batches
        • clear cached activations as they are used
      • find best smoothing scale for each smoothing layer via grid search
      • apply best scales to model weights
      • raise error if any unused activations remain
  • on_end
    • re-run logic of sequential epoch end (in case of basic pipeline)
    • set scales and zero points
    • remove activation hooks
  • on_finalize
    • clear resolved mappings and captured activations
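The cache, consume and verify steps in this lifecycle can be modelled with a small stdlib class. ActivationCacheSketch and its methods are hypothetical: in the modifier, caching happens in PyTorch forward hooks and consumption feeds the grid search, both replaced here by plain callables and integers.

```python
from collections import defaultdict


class ActivationCacheSketch:
    """Hypothetical model of the cache -> consume -> verify steps (not the real hooks)."""

    def __init__(self, mapping_names):
        self.mapping_names = list(mapping_names)
        self.cache = defaultdict(list)  # mapping name -> one cached activation per batch
        self.ended = False

    def hook(self, mapping_name, activation):
        # Stands in for a forward hook on balance_layers[0] (or activation_hook_target).
        self.cache[mapping_name].append(activation)

    def sequential_epoch_end(self, smooth_fn):
        # Consume all cached batches per mapping; pop() clears them as they are used.
        for name in self.mapping_names:
            batches = self.cache.pop(name, [])
            if batches:
                smooth_fn(name, batches)

    def leftover(self):
        return sorted(name for name, batches in self.cache.items() if batches)

    def end(self, smooth_fn):
        # Re-run the epoch-end logic (covers a basic, non-sequential pipeline),
        # then fail loudly if anything was cached but never consumed.
        self.sequential_epoch_end(smooth_fn)
        if self.leftover():
            raise RuntimeError(f"unused cached activations for: {self.leftover()}")
        self.ended = True


seen = {}
cache = ActivationCacheSketch(["layers.0.attn", "layers.0.mlp"])
for batch_idx in range(3):  # three calibration batches
    cache.hook("layers.0.attn", batch_idx)
    cache.hook("layers.0.mlp", batch_idx * 10)
cache.end(lambda name, batches: seen.__setitem__(name, batches))
print(seen)  # {'layers.0.attn': [0, 1, 2], 'layers.0.mlp': [0, 10, 20]}
```

Clearing each entry as soon as it is consumed keeps peak memory to roughly one mapping's worth of activations, and the final check turns a misconfigured mapping into an error instead of a silently skipped layer.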

Parameters:

  • sequential_targets

    list of module names to compress in the same calibration pass

  • mappings

    List of activation layers to smooth and the layers to balance against them. Each entry is an AWQMapping (in a recipe, an entry with smooth_layer and balance_layers keys): smooth_layer names the layer whose output is scaled to smooth the activations, and balance_layers lists the layers that share that activation as input and whose weights are scaled to compensate. If regex is used, it matches layers with the largest overlap in module name. Each mapping may also include an activation_hook_target: a dotted attribute path relative to the parent module (lowest common ancestor) specifying which submodule to hook for activation caching. This is useful for parallel transformer blocks, where the default (hooking balance_layers[0]) would capture the wrong activations. If no mappings are provided, they are inferred from the model architecture.

  • ignore

    List of layers to ignore during quantization (these layers are not smoothed). Entries should match the names of layers whose outputs are scaled to achieve smoothing (the smooth_layer of a mapping).

  • offload_device

    Device to offload cached args to. Offloading reduces memory requirements but takes extra time to move data between the CPU and the execution device. If not provided, it defaults to torch.device("cpu") for MoE models and to None (no offloading) for all other models; an explicitly provided value, including None, is kept as-is. Consider setting it to torch.device("cpu") if you encounter OOM errors.

  • duo_scaling

    Whether to use duo scaling, which uses both input activations and weights to determine the scaling factor. If True (the default), both activations and weights are used. If False, only activations are used. If "both", half of the grid search is performed with duo_scaling=False and the other half with duo_scaling=True. Any value other than False requires a per-channel weight quantization strategy (group or channel); the TENSOR strategy raises a ValueError.

  • n_grid

    Number of grid points to use in the best-scales grid search for each mapping. Decreasing it reduces runtime, at the possible cost of slightly worse scales. Defaults to 20.
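The grid search that duo_scaling and n_grid control can be sketched in plain Python. Everything here is an illustrative assumption: the helper names are hypothetical, the scale formula s = x_mean^ratio / w_mean^(1 - ratio) (or x_mean^ratio without duo scaling) follows the one used by the open-source AutoAWQ implementation, the loss is the squared error of the layer output, and the "both" split (each ratio tried once per setting, n_grid points in total) is one reading of the duo_scaling description.

```python
def matvec(w, x):
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def fake_quant_rows(w, num_bits=4):
    """Symmetric per-output-channel min-max fake quantization (used only for scoring)."""
    q_max = 2 ** (num_bits - 1) - 1
    out = []
    for row in w:
        scale = max(abs(v) for v in row) / q_max or 1.0
        out.append([max(-q_max - 1, min(q_max, round(v / scale))) * scale for v in row])
    return out


def search_scales(w, batch, n_grid=20, duo_scaling=True):
    """Grid-search per-input-channel scales that minimise output error after quantization."""
    n_in = len(w[0])
    x_mean = [sum(abs(x[j]) for x in batch) / len(batch) for j in range(n_in)]
    w_mean = [sum(abs(row[j]) for row in w) / len(w) for j in range(n_in)]
    reference = [matvec(w, x) for x in batch]

    if duo_scaling == "both":
        half = n_grid // 2  # each ratio is tried without and with duo scaling
        grid = [(i / half, duo) for i in range(half) for duo in (False, True)]
    else:
        grid = [(i / n_grid, duo_scaling) for i in range(n_grid)]

    best_loss, best = float("inf"), None
    for ratio, duo in grid:
        if duo:  # activation and weight statistics
            s = [xm ** ratio / (wm ** (1 - ratio) + 1e-4) for xm, wm in zip(x_mean, w_mean)]
        else:  # activation statistics only
            s = [xm ** ratio for xm in x_mean]
        s = [max(v, 1e-4) for v in s]
        norm = (max(s) * min(s)) ** 0.5  # centre the scales around 1
        s = [v / norm for v in s]
        # Quantize W diag(s), then fold the scales back out: Q(W diag(s)) diag(s)^-1.
        wq = fake_quant_rows([[wij * sj for wij, sj in zip(row, s)] for row in w])
        wq = [[wij / sj for wij, sj in zip(row, s)] for row in wq]
        loss = sum(
            (a - b) ** 2
            for x, ref in zip(batch, reference)
            for a, b in zip(matvec(wq, x), ref)
        )
        if loss < best_loss:
            best_loss, best = loss, (ratio, duo, s)
    return best_loss, best


w = [[0.5, -1.0, 0.25, 0.9], [2.0, 0.1, -0.3, -0.7], [-0.4, 0.6, 1.5, 0.2]]
batch = [[0.1, 8.0, -0.2, 0.3], [0.05, -6.0, 0.3, -0.1], [0.2, 7.0, 0.1, 0.2]]
loss, (ratio, duo, scales) = search_scales(w, batch, n_grid=20, duo_scaling="both")
print(f"best ratio={ratio:.2f}, duo_scaling={duo}, loss={loss:.6f}")
```

Because ratio = 0 without duo scaling reduces to plain round-to-nearest, the search can never do worse than unscaled quantization on the calibration data when that grid point is included; a smaller n_grid just samples the ratio axis more coarsely.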

Methods:

  • on_end

    Finish calibrating by setting scales and zero-points, removing observers and calibration hooks

  • on_finalize

    Clean up by clearing the activations and mapping data

  • on_initialize

    Initialize AWQ on the given state

  • validate_duo_scaling

    Validate that duo_scaling is either True, False, or 'both' (lowercase)

on_end

on_end(state: State, event: Event, **kwargs)

Finish calibrating by setting scales and zero-points, removing observers and calibration hooks

Source code in llmcompressor/modifiers/awq/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by setting scales and zero-points,
     removing observers and calibration hooks
    """
    self._assert_all_activations_consumed()

    self.ended_ = True

    named_modules = list(
        match_named_modules(state.model, self.resolved_targets, self.ignore)
    )

    # For TENSOR_GROUP (nvfp4), calculate global scales after smoothing
    for _, module in tqdm(named_modules, desc="Updating global scales"):
        update_weight_global_scale(module)

    # For TENSOR_GROUP (nvfp4), fuse global scales for attention and MLP layers
    # This is a requirement for vLLM inference.
    for module in tqdm(state.model.modules(), desc="Fusing global scales"):
        update_fused_layer_weight_global_scales(module)

    # Calculate scales and zero points using the fused global scales
    for _, module in tqdm(named_modules, desc="Calibrating weights"):
        update_weight_zp_scale(module)

    QuantizationMixin.end_calibration(self, state.model)

    # remove activation hooks
    self.remove_hooks()

on_finalize

on_finalize(state: State, **kwargs) -> bool

Clean up by clearing the activations and mapping data

Parameters:

  • state

    (State) –

    unused

Returns:

  • bool

    True

Source code in llmcompressor/modifiers/awq/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    Clean up by clearing the activations and mapping data

    :param state: unused
    :return: True
    """
    if not self.ended_:
        self.on_end(state, None)

    self._log_error_metrics()

    self._parent_args_cache.clear()
    self._smooth_activation_means.clear()
    self._resolved_mappings.clear()
    self._error_metrics.clear()

    return True

on_initialize

on_initialize(state: State, **kwargs) -> bool

Initialize AWQ on the given state: initialize quantization, resolve mappings, and cache module kwargs.

Parameters:

  • state

    (State) –

    state to run AWQ on

Returns:

  • bool

    True on a successful run, False otherwise

Source code in llmcompressor/modifiers/awq/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Initialize AWQ on the given state
    Initialize quantization, resolve mappings, cache module kwargs

    :param state: state to run AWQ on
    :return: True on a successful run, False otherwise
    """

    # apply config to model and prepare calibration hooks
    if QuantizationMixin.has_config(self):
        QuantizationMixin.initialize_quantization(self, state.model)

    # Validate that duo_scaling is only used with per-channel quantization
    if self.duo_scaling is not False:
        for _, module in match_named_modules(
            state.model, self.resolved_targets, self.ignore
        ):
            if (
                hasattr(module, "quantization_scheme")
                and hasattr(module.quantization_scheme, "weights")
                and module.quantization_scheme.weights.strategy
                == QuantizationStrategy.TENSOR
            ):
                raise ValueError(
                    "duo_scaling is only supported with per-channel quantization "
                    "strategies (group or channel), but found TENSOR strategy. "
                    "Please set duo_scaling=False or use a per-channel "
                    "quantization strategy."
                )

    if self.mappings is None:
        logger.info("No AWQModifier.mappings provided, inferring from model...")
        self.mappings = get_layer_mappings_from_architecture(
            architecture=state.model.__class__.__name__
        )

    # Set default offload_device
    if self.offload_device == Sentinel("not_provided"):
        # Check if we have a MoE model
        if is_moe_model(state.model):
            self.offload_device = torch.device("cpu")
            logger.info(
                "MoE model detected: setting offload_device to 'cpu' by default "
                "to reduce memory usage. You can override this by explicitly "
                "setting offload_device in your recipe."
            )
        else:
            # For non-MoE models, convert sentinel to None
            # (no offloading by default)
            self.offload_device = None

    self._set_resolved_mappings(state.model)

    return True
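The offload_device default relies on a "not provided" sentinel rather than None, so an explicit None can still switch offloading off for a MoE model. A stdlib sketch of that decision follows; default_offload_device and the plain object() sentinel are hypothetical stand-ins for the Sentinel("not_provided") and torch.device("cpu") used in on_initialize, and the device is represented as a string only to keep the example dependency-free.

```python
_NOT_PROVIDED = object()  # hypothetical stand-in for Sentinel("not_provided")


def default_offload_device(offload_device=_NOT_PROVIDED, is_moe=False):
    """Mirror of the offload_device defaulting in on_initialize (sketch)."""
    if offload_device is _NOT_PROVIDED:
        # Omitted by the user: offload to CPU only for MoE models.
        return "cpu" if is_moe else None
    # Explicit values, including None, are kept unchanged.
    return offload_device


print(default_offload_device(is_moe=True))        # cpu
print(default_offload_device())                   # None
print(default_offload_device(None, is_moe=True))  # None
```

With None as the default instead, "not set" and "explicitly disabled" would be indistinguishable, and a MoE user could not opt out of CPU offloading.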

validate_duo_scaling classmethod

validate_duo_scaling(v)

Validate that duo_scaling is either True, False, or 'both' (lowercase)

Source code in llmcompressor/modifiers/awq/base.py
@field_validator("duo_scaling")
@classmethod
def validate_duo_scaling(cls, v):
    """Validate that duo_scaling is either True, False, or 'both' (lowercase)"""
    if v not in (True, False, "both"):
        raise ValueError(f"duo_scaling must be True, False, or 'both', got {v!r}")
    return v
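Because the validator uses a membership test, it compares by equality: "both" must be lowercase, and in plain Python the integers 1 and 0 also pass, since 1 == True and 0 == False. Whether such values ever reach the validator depends on pydantic's type handling for the field, which is not shown here. A standalone sketch of the check, with a hypothetical stricter variant for comparison:

```python
def passes_membership_check(v) -> bool:
    # Same test as validate_duo_scaling: `in` compares with ==.
    return v in (True, False, "both")


def passes_strict_check(v) -> bool:
    # Hypothetical stricter variant: identity checks reject 1, 0 and 1.0.
    return v is True or v is False or (isinstance(v, str) and v == "both")


for value in (True, False, "both", "Both", 1, 0, 1.0, 2):
    print(repr(value), passes_membership_check(value), passes_strict_check(value))
```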

get_layer_mappings_from_architecture

get_layer_mappings_from_architecture(
    architecture: str,
) -> list[AWQMapping]

Parameters:

  • architecture

    (str) –

    The architecture of the model, given as its class name (on_initialize passes state.model.__class__.__name__)

Returns:

  • list[AWQMapping]

    The layer mappings registered for the given architecture, or the default mappings if the architecture is not in the registry

Source code in llmcompressor/modifiers/awq/mappings.py
def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
    """
    :param architecture: str: The architecture of the model
    :return: list: The layer mappings for the given architecture
    """

    if architecture not in AWQ_MAPPING_REGISTRY:
        logger.info(
            f"Architecture {architecture} not found in mappings. "
            f"Using default mappings: {_default_mappings}"
        )

    return AWQ_MAPPING_REGISTRY.get(architecture, _default_mappings)