Skip to content

llmcompressor.utils

General utility functions used throughout LLM Compressor.

Modules:

Functions:

DisableQuantization

DisableQuantization(module: Module)

Disable quantization during forward passes after applying a quantization config

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def DisableQuantization(module: torch.nn.Module):
    """
    Context manager which disables quantized forward passes for every submodule
    of ``module``, restoring quantization when the context exits.
    """
    try:
        # recursively flip every submodule into non-quantized mode
        module.apply(disable_quantization)
        yield
    finally:
        # always restore quantization, even if the body raised
        module.apply(enable_quantization)

calibration_forward_context

calibration_forward_context(model: Module)

Context in which all calibration forward passes should occur.

  • Remove gradient calculations
  • Disable the KV cache
  • Disable train mode and enable eval mode
  • Disable hf kernels which could bypass hooks
  • Disable lm head (input and weights can still be calibrated, output will be meta)
Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def calibration_forward_context(model: torch.nn.Module):
    """
    Context in which all calibration forward passes should occur.

    - Remove gradient calculations
    - Disable the KV cache
    - Disable train mode and enable eval mode
    - Disable hf kernels which could bypass hooks
    - Disable lm head (input and weights can still be calibrated, output will be meta)
    """
    # nested `with` enters left-to-right and exits in reverse order,
    # exactly like the ExitStack formulation it replaces
    with (
        torch.no_grad(),
        disable_cache(model),
        eval_context(model),
        disable_hf_kernels(model),
        disable_lm_head(model),
    ):
        yield

disable_cache

disable_cache(module: Module)

Temporarily disable the key-value cache for transformer models. Used to prevent excess memory use in one-shot cases where the model only performs the prefill phase and not the generation phase.

Example:

>>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
>>> input = torch.randint(0, 32, size=(1, 32))
>>> with disable_cache(model):
...     output = model(input)

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_cache(module: torch.nn.Module):
    """
    Temporarily disable the key-value cache for transformer models. Used to prevent
    excess memory use in one-shot cases where the model only performs the prefill
    phase and not the generation phase.

    Example:
    >>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    >>> input = torch.randint(0, 32, size=(1, 32))
    >>> with disable_cache(model):
    ...     output = model(input)
    """
    if not isinstance(module, PreTrainedModel):
        # non-HF modules have no cache config to patch
        yield
        return

    # multimodal models keep text settings under `text_config`; fall back to
    # the top-level config otherwise
    config = getattr(module.config, "text_config", module.config)
    with patch_attr(config, "use_cache", False):
        yield

disable_hf_kernels

disable_hf_kernels(module: Module)

In transformers>=4.50.0, some module forward methods may be replaced by calls to hf hub kernels. This has the potential to bypass hooks added by LLM Compressor

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_hf_kernels(module: torch.nn.Module):
    """
    In transformers>=4.50.0, some module forward methods may be replaced by
    calls to hf hub kernels, which has the potential to bypass hooks added by
    LLM Compressor. This context forces those custom kernels off for HF models.
    """
    if not isinstance(module, PreTrainedModel):
        # nothing to patch on plain torch modules
        yield
        return

    with patch_attr(module.config, "disable_custom_kernels", True):
        yield

disable_lm_head

disable_lm_head(model: Module)

Disable the lm_head of a model by moving it to the meta device. This function does not untie parameters, and restores proper model loading upon exit.

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def disable_lm_head(model: torch.nn.Module):
    """
    Disable the lm_head of a model by moving it to the meta device. This function
    does not untie parameters, and restores proper model loading upon exit.

    :param model: model whose lm_head should be disabled
    """
    _, lm_head = get_embeddings(model)
    if lm_head is None:
        logger.warning(
            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
            "but was unable to find lm_head. This may lead to unexpected OOM."
        )
        yield
        return

    elif not isinstance(lm_head, torch.nn.Linear):
        logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
        yield
        return

    else:
        # a meta-device weight makes the matmul produce meta outputs without
        # allocating real logits memory
        dummy_weight = lm_head.weight.to("meta")

        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
            return input.to("meta") @ dummy_weight.T

        with contextlib.ExitStack() as stack:
            # bind dummy_forward as an instance method on this lm_head
            lm_head_forward = dummy_forward.__get__(lm_head)
            stack.enter_context(patch_attr(lm_head, "forward", lm_head_forward))

            # NOTE(review): presumably prevents accelerate hooks from trying to
            # move meta outputs back to the input device — confirm against
            # accelerate's AlignDevicesHook behavior
            if hasattr(model, "_hf_hook"):
                stack.enter_context(patch_attr(model._hf_hook, "io_same_device", False))

            yield

dispatch_for_generation

dispatch_for_generation(*args, **kwargs) -> PreTrainedModel

Dispatch a model for autoregressive generation. This means that modules are dispatched evenly across available devices and kept onloaded if possible.

Parameters:

  • model

    model to dispatch

  • hint_batch_size

    reserve memory for batch size of inputs

  • hint_batch_seq_len

    reserve memory for sequence length of inputs

  • hint_model_dtype

    reserve memory for model's dtype. Will be inferred from model if none is provided

  • hint_extra_memory

    extra memory reserved for model serving

  • no_split_modules

    names of module classes which should not be split across multiple devices

Returns:

  • PreTrainedModel

    dispatched model

Source code in llmcompressor/utils/dev.py
@deprecated("compressed_tensors.offload::dispatch_model")
@wraps(dispatch_model)
def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
    """
    Dispatch a model autoregressive generation. This means that modules are dispatched
    evenly across avaiable devices and kept onloaded if possible.

    :param model: model to dispatch
    :param hint_batch_size: reserve memory for batch size of inputs
    :param hint_batch_seq_len: reserve memory for sequence of length of inputs
    :param hint_model_dtype: reserve memory for model's dtype.
        Will be inferred from model if none is provided
    :param hint_extra_memory: extra memory reserved for model serving
    :param no_split_modules: names of module classes which should not be split
        across multiple devices
    :return: dispatched model
    """
    return dispatch_model(*args, **kwargs)

eval_context

eval_context(module: Module)

Disable pytorch training mode for the given module

Source code in llmcompressor/utils/helpers.py
@contextlib.contextmanager
def eval_context(module: torch.nn.Module):
    """
    Temporarily put the given module into eval mode, restoring the previous
    training flag when the context exits (even on exception).
    """
    previous_mode = module.training
    try:
        module.eval()  # same as module.train(False)
        yield
    finally:
        module.train(previous_mode)

get_embeddings

get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]

Returns input and output embeddings of a model. If get_input_embeddings/ get_output_embeddings is not implemented on the model, then None will be returned instead.

Parameters:

  • model

    (PreTrainedModel) –

    model to get embeddings from

Returns:

  • tuple[Module | None, Module | None]

    tuple containing the embedding modules, or None for any that are unavailable

Source code in llmcompressor/utils/transformers.py
def get_embeddings(
    model: PreTrainedModel,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
    """
    Returns input and output embeddings of a model. If `get_input_embeddings`/
    `get_output_embeddings` is not implemented on the model, then None will be returned
    instead.

    :param model: model to get embeddings from
    :return: tuple containing the embedding modules, or None for any that are
        unavailable
    """
    try:
        input_embed = model.get_input_embeddings()
    except (AttributeError, NotImplementedError, TypeError):
        input_embed = None

    try:
        output_embed = model.get_output_embeddings()
    # also catch TypeError, for consistency with get_input_embeddings above
    except (AttributeError, NotImplementedError, TypeError):
        output_embed = None

    return input_embed, output_embed

greedy_bin_packing

greedy_bin_packing(
    items: list[T],
    num_bins: int,
    item_weight_fn: Callable[[T], float] = lambda x: 1,
) -> tuple[list[T], list[list[T]], dict[T, int]]

Distribute items across bins using a greedy bin-packing heuristic.

Items are sorted by weight in descending order, then each item is assigned to the bin with the smallest current total weight. This approximates an even distribution of weight across bins.

Parameters:

  • items

    (list[T]) –

    items to distribute. Sorted in-place by descending weight.

  • num_bins

    (int) –

    number of bins to distribute items across.

  • item_weight_fn

    (Callable[[T], float], default: lambda x: 1 ) –

    callable that returns the weight of an item. Defaults to uniform weight of 1.

Returns:

  • tuple[list[T], list[list[T]], dict[T, int]]

    a 3-tuple of:
      - items: the input list, now sorted by descending weight.
      - bin_to_items: list of length num_bins where each element is the list of items assigned to that bin.
      - item_to_bin: mapping from each item to its assigned bin index.

Source code in llmcompressor/utils/dist.py
def greedy_bin_packing(
    items: list[T],
    num_bins: int,
    item_weight_fn: Callable[[T], float] = lambda x: 1,
) -> tuple[list[T], list[list[T]], dict[T, int]]:
    """Distribute items across bins using a greedy bin-packing heuristic.

    Items are sorted by weight in descending order, then each item goes into
    whichever bin currently carries the least total weight. This approximates
    an even spread of weight across all bins.

    :param items: items to distribute. Sorted in-place by descending weight.
    :param num_bins: number of bins to distribute items across.
    :param item_weight_fn: callable that returns the weight of an item.
        Defaults to uniform weight of 1.
    :return: a 3-tuple of:
        - items: the input list, now sorted by descending weight.
        - bin_to_items: list of length ``num_bins`` where each element is
          the list of items assigned to that bin.
        - item_to_bin: mapping from each item to its assigned bin index.
    """
    items.sort(key=item_weight_fn, reverse=True)

    bin_to_items: list[list[T]] = [[] for _ in range(num_bins)]
    item_to_bin: dict[T, int] = {}
    bin_weights: list[float] = [0] * num_bins

    for item in items:
        # pick the lightest bin; `min` keeps the lowest index on ties, same
        # tie-breaking as `bin_weights.index(min(bin_weights))`
        lightest = min(range(num_bins), key=bin_weights.__getitem__)
        bin_to_items[lightest].append(item)
        item_to_bin[item] = lightest
        bin_weights[lightest] += item_weight_fn(item)

    return items, bin_to_items, item_to_bin

import_from_path

import_from_path(path: str) -> str

Import the module and the name of the function/class separated by `:`. Examples:
  path = "/path/to/file.py:func_or_class_name"
  path = "/path/to/file:focn"
  path = "path.to.file:focn"

Parameters:

  • path

    (str) –

    path including the file path and object name

Source code in llmcompressor/utils/helpers.py
def import_from_path(path: str) -> str:
    """
    Import the module and the name of the function/class separated by :
    Examples:
      path = "/path/to/file.py:func_or_class_name"
      path = "/path/to/file:focn"
      path = "path.to.file:focn"
    :param path: path including the file path and object name
    :return: function or class object
    """
    original_path, class_name = path.split(":")

    # strip a trailing ".py" only; the previous `split(".py")[0]` truncated at
    # the first ".py" occurrence anywhere in the path (e.g. "/dir.python/x.py")
    module_path = original_path.removesuffix(".py")
    # convert filesystem separators to module separators
    module_path = re.sub(r"/+", ".", module_path)
    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        raise ImportError(f"Cannot find module with path {original_path}") from err

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise AttributeError(f"Cannot find {class_name} in {original_path}") from err

is_package_available

is_package_available(
    package_name: str, return_version: bool = False
) -> Union[Tuple[bool, str], bool]

A helper function to check if a package is available and optionally return its version. This function enforces a check that the package is available and is not just a directory/file with the same name as the package.

inspired from: https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/utils/import_utils.py#L41

Parameters:

  • package_name

    (str) –

    The package name to check for

  • return_version

    (bool, default: False ) –

    True to return the version of the package if available

Returns:

  • Union[Tuple[bool, str], bool]

    True if the package is available, False otherwise or a tuple of (bool, version) if return_version is True

Source code in llmcompressor/utils/helpers.py
def is_package_available(
    package_name: str,
    return_version: bool = False,
) -> Union[Tuple[bool, str], bool]:
    """
    A helper function to check if a package is available
    and optionally return its version. This function enforces
    a check that the package is available and is not
    just a directory/file with the same name as the package.

    inspired from:
    https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/utils/import_utils.py#L41

    :param package_name: The package name to check for
    :param return_version: True to return the version of
        the package if available
    :return: True if the package is available, False otherwise or a tuple of
        (bool, version) if return_version is True
    """
    found = importlib.util.find_spec(package_name) is not None
    version = "N/A"
    if found:
        try:
            version = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # a directory/file can shadow a real package; treat as unavailable
            found = False
        logger.debug(f"Detected {package_name} version {version}")
    return (found, version) if return_version else found

patch_transformers_logger_level

patch_transformers_logger_level(level: int = logging.ERROR)

Context under which the transformers logger's level is modified

This can be used with skip_weights_download to squelch warnings related to missing parameters in the checkpoint

Parameters:

  • level

    (int, default: ERROR ) –

    new logging level for transformers logger. Logs whose level is below this level will not be logged

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    transformers_logger.setLevel(level=level)
    try:
        yield
    finally:
        # restore in a finally block: otherwise an exception raised in the body
        # would leave the transformers logger permanently modified
        transformers_logger.setLevel(level=restore_log_level)

skip_weights_download

skip_weights_download(
    model_class: Type[
        PreTrainedModel
    ] = AutoModelForCausalLM,
)

Context manager under which models are initialized without having to download the model weight files. This differs from init_empty_weights in that weights are allocated onto assigned devices with random values, as opposed to being on the meta device.

Parameters:

  • model_class

    (Type[PreTrainedModel], default: AutoModelForCausalLM ) –

    class to patch, defaults to AutoModelForCausalLM

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated onto assigned devices with random values, as opposed to being on the meta
    device

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    # hub file patterns that contain weights; these are excluded from download
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
        "*.pt",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        # bound later by `with tempfile.TemporaryDirectory() as tmp_dir` below
        nonlocal tmp_dir

        # intercept model stub
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download files into tmp dir (weights excluded via ignore_patterns)
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir using the unpatched from_pretrained
        model = original_fn(tmp_dir, **kwargs)

        # replace model_path so the model reports the original stub, not tmp_dir
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    with (
        tempfile.TemporaryDirectory() as tmp_dir,
        patch_attr(model_class, "from_pretrained", patched),
        skip_weights_initialize(),
        patch_transformers_logger_level(),
    ):
        yield

targets_embeddings

targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool

Returns True if the given targets target the word embeddings of the model

Parameters:

  • model

    (PreTrainedModel) –

    containing word embeddings

  • targets

    (NamedModules) –

    named modules to check

  • check_input

    (bool, default: True ) –

    whether to check if input embeddings are targeted

  • check_output

    (bool, default: True ) –

    whether to check if output embeddings are targeted

Returns:

  • bool

    True if embeddings are targeted, False otherwise

Source code in llmcompressor/utils/transformers.py
def targets_embeddings(
    model: PreTrainedModel,
    targets: NamedModules,
    check_input: bool = True,
    check_output: bool = True,
) -> bool:
    """
    Returns True if the given targets target the word embeddings of the model

    :param model: containing word embeddings
    :param targets: named modules to check
    :param check_input: whether to check if input embeddings are targeted
    :param check_output: whether to check if output embeddings are targeted
    :return: True if embeddings are targeted, False otherwise
    """
    input_embed, output_embed = get_embeddings(model)
    # if an embedding we were asked to check cannot be found, we cannot
    # determine targeting; warn and conservatively answer False.
    # (rewritten from `(check_input and input_embed) is None`, whose grouping
    # only worked incidentally for boolean flags)
    if (check_input and input_embed is None) or (
        check_output and output_embed is None
    ):
        logger.warning(
            "Cannot check embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return False

    # avoid shadowing the `targets` parameter
    target_modules = set(module for _, module in targets)
    return (check_input and input_embed in target_modules) or (
        check_output and output_embed in target_modules
    )

untie_word_embeddings

untie_word_embeddings(model: PreTrainedModel)

Untie word embeddings, if possible. This function raises a warning if embeddings cannot be found in the model definition.

The model config will be updated to reflect that embeddings are now untied

Parameters:

  • model

    (PreTrainedModel) –

    transformers model containing word embeddings

Source code in llmcompressor/utils/transformers.py
def untie_word_embeddings(model: PreTrainedModel):
    """
    Untie word embeddings, if possible. This function raises a warning if
    embeddings cannot be found in the model definition.

    The model config will be updated to reflect that embeddings are now untied

    :param model: transformers model containing word embeddings
    """
    input_embed, output_embed = get_embeddings(model)
    if input_embed is None or output_embed is None:
        logger.warning(
            "Cannot untie embeddings. If this model has word embeddings, please "
            "implement `get_input_embeddings` and `get_output_embeddings`"
        )
        return

    # give each embedding its own cloned weight tensor so the two modules no
    # longer share storage
    for embedding in (input_embed, output_embed):
        old_weight = embedding.weight
        new_param = Parameter(
            old_weight.data.clone(), requires_grad=old_weight.requires_grad
        )
        embedding.register_parameter("weight", new_param)

    # record the untied state in the config so save/reload behaves correctly
    if hasattr(model.config, "tie_word_embeddings"):
        model.config.tie_word_embeddings = False

wait_for_comms

wait_for_comms(pending_comms: list[Work]) -> None

Block until all pending async distributed operations complete.

Calls wait() on each work handle, then clears the list in-place so it can be reused for the next batch of operations.

Parameters:

  • pending_comms

    (list[Work]) –

    mutable list of async communication handles (returned by dist.reduce, dist.broadcast, etc. with async_op=True). The list is cleared after all operations have completed.

Source code in llmcompressor/utils/dist.py
def wait_for_comms(pending_comms: list[dist.Work]) -> None:
    """Block until all pending async distributed operations complete.

    Each handle's ``wait()`` is invoked in order, after which the list is
    emptied in place so the caller can reuse it for the next batch.

    :param pending_comms: mutable list of async communication handles
        (returned by ``dist.reduce``, ``dist.broadcast``, etc. with
        ``async_op=True``). The list is cleared after all operations
        have completed.
    """
    # iterate over a snapshot so the clear below cannot interfere
    for work in tuple(pending_comms):
        work.wait()
    pending_comms.clear()