llmcompressor.pipelines.cache

Classes:

  • IntermediateValue

    Dataclass which recursively defines offloaded values and the device to onload them to

  • IntermediatesCache

    Cache which stores intermediate values (activations) produced by batched, sequential execution of models

  • OverrideEqMode

    When using a torch.Tensor as a key in a dictionary, the equality check must return a single value instead of a torch.Tensor of bool values

IntermediateValue dataclass

IntermediateValue(
    value: Tensor | "IntermediateValue" | Any,
    device: device | None,
)

Dataclass which recursively defines offloaded values and the device to onload them to

Parameters:

  • value

    (Tensor | 'IntermediateValue' | Any) –

    either an offloaded Tensor, a primitive value, or a recursable value

  • device

    (device | None) –

    if the value is a Tensor, then the device to onload the tensor to, otherwise None

IntermediatesCache

IntermediatesCache(
    batch_intermediates: list[IntermediateValues]
    | None = None,
    offload_device: device | None = "cpu",
)

Cache which stores intermediate values (activations) produced by batched, sequential execution of models. Values are offloaded to the offload_device when stored in the cache and onloaded to their original device when fetched from the cache. If offload_device is None, values will not be offloaded at all.

Currently supports nested offloading of dataclass instances and tuples.

Construct using the empty and from_dataloader class methods.
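The store/fetch lifecycle described above can be sketched in plain Python. SketchCache below is a hypothetical stand-in, not the real IntermediatesCache: it mirrors the empty/update/fetch/delete flow with ordinary dicts and omits tensor offloading entirely.

```python
# Hypothetical stand-in for IntermediatesCache: same batch-indexed dict layout,
# no tensor offloading/onloading.
class SketchCache:
    def __init__(self, num_batches):
        # one dict of intermediates per batch, like IntermediatesCache.empty
        self.batch_intermediates = [{} for _ in range(num_batches)]

    def update(self, batch_index, values):
        # put values belonging to a batch
        self.batch_intermediates[batch_index].update(values)

    def fetch(self, batch_index, input_names=None):
        # return all values for the batch, or only the requested keys
        items = self.batch_intermediates[batch_index]
        return {k: v for k, v in items.items()
                if input_names is None or k in input_names}

    def delete(self, batch_index, consumed_names=None):
        # drop consumed values; default is to drop everything in the batch
        items = self.batch_intermediates[batch_index]
        for name in consumed_names or list(items):
            del items[name]

cache = SketchCache(num_batches=2)            # like empty(2, offload_device)
cache.update(0, {"hidden_states": [1.0, 2.0]})
print(cache.fetch(0))                          # {'hidden_states': [1.0, 2.0]}
cache.delete(0)                                # values consumed, memory freed
print(cache.fetch(0))                          # {}
```

The real cache additionally moves tensors to offload_device on update and back to their original device on fetch; the indexing and lifecycle are the same.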

Methods:

  • append

    Append new values to the cache. The new values will be assigned the next available batch index

  • delete

    Delete values from the cache

  • empty

    Construct an empty cache

  • fetch

    Fetch values belonging to a batch

  • from_dataloader

    Initialize a cache with data from the provided dataloader

  • size

    Returns the memory used by cached values, keyed by device, in bytes

  • update

    Update/put values belonging to a batch

Source code in llmcompressor/pipelines/cache.py
def __init__(
    self,
    batch_intermediates: list[IntermediateValues] | None = None,
    offload_device: torch.device | None = "cpu",
):
    self.batch_intermediates = batch_intermediates or []
    self.offload_device = offload_device

append

append(values: dict[str, Any])

Append new values to the cache. The new values will be assigned the next available batch index

Parameters:

  • values

    (dict[str, Any]) –

    dictionary mapping keys to values used for update

Source code in llmcompressor/pipelines/cache.py
def append(self, values: dict[str, Any]):
    """
    Append new values to the cache. The new values will be assigned the next
    available batch index

    :param values: dictionary mapping keys to values used for update
    """
    batch_index = len(self.batch_intermediates)
    self.batch_intermediates.append({})
    self.update(batch_index, values)
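The indexing rule in the source above can be shown in isolation: the new batch always lands at the next free index, which is simply the current length of the list. A minimal sketch with plain dicts:

```python
# Sketch of append's indexing rule: next available index = len of the list.
batch_intermediates = [{"k": 1}, {"k": 2}]   # two existing batches

def append(values):
    batch_index = len(batch_intermediates)   # next available batch index
    batch_intermediates.append({})           # reserve the slot
    batch_intermediates[batch_index].update(values)
    return batch_index

idx = append({"k": 3})
print(idx)                       # 2
print(batch_intermediates[idx])  # {'k': 3}
```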

delete

delete(
    batch_index: int,
    consumed_names: list[str] | None = None,
)

Delete values from the cache

Parameters:

  • batch_index

    (int) –

    index of batch whose values will be deleted

  • consumed_names

    (list[str] | None, default: None ) –

    list of keys whose values will be deleted, defaults to removing all keys

Source code in llmcompressor/pipelines/cache.py
def delete(self, batch_index: int, consumed_names: list[str] | None = None):
    """
    Delete values from the cache

    :param batch_index: index of batch whose values will be deleted
    :param consumed_names: list of keys whose values will be deleted, defaults to
        removing all keys
    """
    intermediates = self.batch_intermediates[batch_index]

    if consumed_names is None:
        consumed_names = list(intermediates.keys())

    for name in consumed_names:
        del intermediates[name]

empty classmethod

empty(num_batches: int, offload_device: device)

Construct an empty cache

Parameters:

  • num_batches

    (int) –

    the expected number of batches to be stored

  • offload_device

    (device) –

    device to offload values to

Source code in llmcompressor/pipelines/cache.py
@classmethod
def empty(cls, num_batches: int, offload_device: torch.device):
    """
    Construct an empty cache

    :param num_batches: the expected number of batches to be stored
    :param offload_device: device to offload values to
    """
    batch_intermediates = [{} for _ in range(num_batches)]
    return cls(batch_intermediates, offload_device)

fetch

fetch(
    batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]

Fetch values belonging to a batch

Parameters:

  • batch_index

    (int) –

    index of batch whose values are being fetched

  • input_names

    (list[str] | None, default: None ) –

    list of keys whose values are being fetched

Returns:

  • dict[str, Any]

    dictionary mapping keys to onloaded values

Source code in llmcompressor/pipelines/cache.py
def fetch(
    self, batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]:
    """
    Fetch values belonging to a batch

    :param batch_index: index of batch whose values are being fetched
    :param input_names: list of keys whose values are being fetched
    :return: dictionary mapping keys to onloaded values
    """
    intermediates = self.batch_intermediates[batch_index]

    return {
        key: self._onload_value(subgraph_input)
        for key, subgraph_input in intermediates.items()
        if input_names is None or key in input_names
    }
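The filtering behavior of input_names mirrors the dict comprehension in the source above and can be sketched without the onloading step (the key names below are illustrative, not fixed by the API):

```python
# Sketch of fetch's input_names filtering; onloading to the original
# device is omitted.
intermediates = {"hidden_states": [0.1], "attention_mask": [1], "position_ids": [0]}

def fetch(input_names=None):
    # None means "fetch everything"; otherwise keep only the requested keys
    return {k: v for k, v in intermediates.items()
            if input_names is None or k in input_names}

print(sorted(fetch()))             # all keys when input_names is None
print(fetch(["hidden_states"]))    # only the requested key
```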

from_dataloader classmethod

from_dataloader(
    dataloader: DataLoader,
    model_device: device = torch.device("cpu"),
    offload_device: device | None = torch.device("cpu"),
)

Initialize a cache with data from the provided dataloader

This method iterates through all batches in the dataloader and offloads them to the specified device. For faster cache preparation, consider:

  • Increasing batch_size to reduce the number of iterations

  • Using num_workers > 0 in the DataLoader for parallel loading (e.g. the calibration DataLoader from format_calibration_data uses dataloader_num_workers; when > 0, pin_memory and prefetch_factor are also set where applicable, which speeds up both cache build and calibration)

  • Ensuring data preprocessing is done before creating the dataloader

Parameters:

  • dataloader

    (DataLoader) –

    dataloader which generates values to be cached

  • model_device

    (device, default: device('cpu') ) –

    device which values will be onloaded to when fetched

  • offload_device

    (device | None, default: device('cpu') ) –

    device to offload values to

Source code in llmcompressor/pipelines/cache.py
@classmethod
def from_dataloader(
    cls,
    dataloader: torch.utils.data.DataLoader,
    model_device: torch.device = torch.device("cpu"),
    offload_device: torch.device | None = torch.device("cpu"),
):
    """
    Initialize a cache with data from the provided dataloader

    This method iterates through all batches in the dataloader and offloads
    them to the specified device. For faster cache preparation, consider:
    - Increasing batch_size to reduce the number of iterations
    - Using num_workers > 0 in the DataLoader for parallel loading (e.g. the
      calibration DataLoader from format_calibration_data uses
      dataloader_num_workers; when > 0, pin_memory and prefetch_factor are
      also set where applicable, which speeds both cache build and calibration)
    - Ensuring data preprocessing is done before creating the dataloader

    :param dataloader: dataloader which generates values to be cached
    :param model_device: device which values will be onloaded to when fetched
    :param offload_device: device to offload values to
    """
    batch_intermediates = [
        {
            key: cls._offload_value(value, offload_device, model_device)
            for key, value in batch.items()
        }
        for batch in tqdm(dataloader, desc="Preparing cache")
    ]

    return cls(batch_intermediates, offload_device)
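The construction loop above is a per-batch dict comprehension. It can be sketched with a plain list of dict batches standing in for a torch DataLoader, and an identity function standing in for the real _offload_value:

```python
# Sketch of from_dataloader's construction loop. "offload" is a stand-in:
# the real cache moves tensors to offload_device here.
def offload(value):
    return value

dataloader = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]  # two batches

batch_intermediates = [
    {key: offload(value) for key, value in batch.items()}
    for batch in dataloader
]
print(len(batch_intermediates))   # one cache entry per dataloader batch
```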

size

size() -> dict[torch.device, int]

Returns the memory used by cached values, keyed by device, in bytes

Returns:

  • dict[device, int]

    dictionary mapping torch device to number of bytes in cache

Source code in llmcompressor/pipelines/cache.py
def size(self) -> dict[torch.device, int]:
    """
    Returns the memory used by cached values, keyed by device, in bytes

    :return: dictionary mapping torch device to number of bytes in cache
    """
    sizes = defaultdict(lambda: 0)
    memo = set()

    def _size_helper(intermediate: IntermediateValue) -> int:
        value = intermediate.value

        match value:
            case torch.Tensor():
                if value not in memo:
                    sizes[value.device] += value.nbytes
                memo.add(value)
            case list() | tuple():
                for v in value:
                    _size_helper(v)
            case dict():
                for v in value.values():
                    _size_helper(v)
            case _ if is_dataclass(value):
                for field in fields(value):
                    _size_helper(getattr(value, field.name))
            case _:
                # this handles primitive values that don't match any other cases
                sizes[torch.device("cpu")] += sys.getsizeof(value, 0)

    for intermediates in self.batch_intermediates:
        for value in intermediates.values():
            _size_helper(value)

    return dict(sizes)

update

update(batch_index: int, values: dict[str, Any])

Update/put values belonging to a batch

Parameters:

  • batch_index

    (int) –

    index of batch whose values will be updated

  • values

    (dict[str, Any]) –

    dictionary mapping keys to values used for update

Source code in llmcompressor/pipelines/cache.py
def update(self, batch_index: int, values: dict[str, Any]):
    """
    Update/put values belonging to a batch

    :param batch_index: index of batch whose values will be updated
    :param values: dictionary mapping keys to values used for update
    """
    device = self.offload_device
    intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
    self.batch_intermediates[batch_index].update(intermediates)

OverrideEqMode

Bases: TorchDispatchMode

When using a torch.Tensor as a key in a dictionary, the equality check must return a single value instead of a torch.Tensor of bool values. Use this override context for such cases, to swap out the torch.eq equality check for a check on id

>>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([1, 2, 3])
>>> a == b
tensor([True, True, True])
>>> with OverrideEqMode():
...     a == b
tensor(True)
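The underlying problem, scalar identity versus elementwise equality for dict keys, can be illustrated without torch. The IdKey wrapper below is a hypothetical analogue, not the library's mechanism: it keys by object identity, which is the behavior OverrideEqMode enables for raw tensors:

```python
# Hypothetical analogue of id-based equality for dict keys (NOT the
# library's TorchDispatchMode mechanism, just an illustration).
class IdKey:
    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return id(self.obj)          # hash by identity, not contents

    def __eq__(self, other):
        return isinstance(other, IdKey) and self.obj is other.obj

a = [1, 2, 3]   # stands in for a tensor
b = [1, 2, 3]   # equal contents, different object
d = {IdKey(a): "a", IdKey(b): "b"}
print(d[IdKey(a)], d[IdKey(b)])   # identity, not contents, selects the entry
```

With plain list-equality (or elementwise tensor equality) these two keys would collide or fail to hash; keying by identity keeps them distinct.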