llmcompressor.pipelines.sequential.transformers_helpers

Classes:

  • HFCacheProxy –

    Proxy that represents an instance of transformers.cache_utils.Cache.

  • HFProxy –

    Proxy that uses metadata to handle data-dependent control-flow.

  • HFProxyableClassMeta –

    Metaclass that creates a class with its main methods wrapped to be proxyable.

  • HFTracer –

    Tracer that is able to symbolically trace models from the library. To do that, it uses HFProxy instead of the regular PyTorch torch.fx.Proxy.

Functions:

  • gen_constructor_wrapper –

    Wraps target to be proxyable. Used for tensor creators like torch.ones, torch.arange and so on.

  • symbolic_trace –

    Performs symbolic tracing on the model.

HFCacheProxy

Bases: HFProxy

Proxy that represents an instance of transformers.cache_utils.Cache.

HFProxy

Bases: Proxy

Proxy that uses metadata to handle data-dependent control flow.
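A plain torch.fx.Proxy has no concrete value, so tracing fails as soon as control flow inspects it (for example via bool()). The toy class below does not reproduce the real HFProxy; it is only a minimal sketch of how carrying example metadata lets such checks resolve during tracing:

```python
class MetaProxy:
    """Toy stand-in for a tracing proxy that carries example metadata.

    A metadata-free proxy would have to raise inside __bool__; consulting a
    representative concrete value instead lets `if proxy:` be traced.
    """

    def __init__(self, metadata):
        self._metadata = metadata  # a concrete example value, e.g. from a dummy input

    def __bool__(self):
        # Resolve data-dependent control flow via the metadata.
        return bool(self._metadata)


p = MetaProxy(metadata=0)
taken = "true-branch" if p else "false-branch"  # resolves to "false-branch"
```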

HFProxyableClassMeta

Bases: type

Metaclass that creates a class with its main methods wrapped to be proxyable.
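The pattern of wrapping a class's main methods at creation time can be sketched in plain Python. This is an illustrative analog, not the actual HFProxyableClassMeta: the real metaclass wraps methods so they work on torch.fx proxies, while the sketch merely records calls.

```python
import functools


def _record_calls(fn):
    """Hypothetical wrapper: log each invocation before delegating."""
    @functools.wraps(fn)
    def wrapped(self, *args, **kwargs):
        self.calls.append(fn.__name__)
        return fn(self, *args, **kwargs)
    return wrapped


class ProxyableMeta(type):
    """Wrap the methods listed in `_proxyable` when the class is created."""

    def __new__(mcs, name, bases, namespace):
        for meth in namespace.get("_proxyable", ()):
            namespace[meth] = _record_calls(namespace[meth])
        return super().__new__(mcs, name, bases, namespace)


class ToyCache(metaclass=ProxyableMeta):
    _proxyable = ("update",)

    def __init__(self):
        self.calls = []
        self.store = {}

    def update(self, key, value):
        self.store[key] = value


c = ToyCache()
c.update("k", 1)
# c.calls == ["update"], c.store == {"k": 1}
```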

HFTracer

HFTracer(autowrap_modules=(math,), autowrap_functions=())

Bases: Tracer

Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.

Methods:

  • keys –

    Called when a proxy object has its keys() method called.

  • path_of_module –

    Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo with a submodule named bar, passing bar returns "foo.bar".

  • trace –

    Traces root and returns the corresponding FX torch.fx.Graph representation. root can be a torch.nn.Module instance or a Python callable.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
    super().__init__(
        autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions
    )

keys

keys(obj: Proxy) -> Any

Called when a proxy object has its keys() method called. This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in your custom tracer.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
@compatibility(is_backward_compatible=True)
def keys(self, obj: "Proxy") -> Any:
    """Called when a proxy object has its `keys()` method called.
    This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
    your custom tracer.
    """
    attribute = HFAttribute(obj, "keys")()
    if obj.node.target.startswith("**"):
        return attribute._metadata
    return attribute
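For context, `**obj` in a call expression invokes `obj.keys()` and then indexes each key; that is the hook `HFTracer.keys` customizes so `**` works on proxies during tracing. A minimal plain-Python illustration of the protocol itself:

```python
class KwargsView:
    """Any object offering keys() and __getitem__ can be splatted with **."""

    def __init__(self, data):
        self._data = data

    def keys(self):
        # ** unpacking calls this first; HFTracer.keys intercepts the
        # proxied equivalent during tracing.
        return self._data.keys()

    def __getitem__(self, key):
        return self._data[key]


def f(**kwargs):
    return kwargs


result = f(**KwargsView({"a": 1, "b": 2}))
# result == {"a": 1, "b": 2}
```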

path_of_module

path_of_module(mod: Module) -> str

Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo, which has a submodule named bar, passing bar into this function will return the string "foo.bar".

Args:

  • mod (torch.nn.Module) – The Module to retrieve the qualified name for.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
def path_of_module(self, mod: nn.Module) -> str:
    """
    Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
    a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
    string "foo.bar".

    Args:
        mod (torch.nn.Module): The `Module` to retrieve the qualified name for.
    """
    try:
        return super().path_of_module(mod)
    except NameError as e:
        if (
            self.allow_insert_stateless_mods
            and len(list(mod.parameters())) == 0
            and len(list(mod.buffers())) == 0
        ):
            path = self._insert_module_as_submodule(mod)
            return path
        raise e
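The lookup this method extends is a walk over the module hierarchy. The sketch below is a pure-Python stand-in (class and function names are hypothetical) that mirrors how a dotted path like "foo.bar" is resolved:

```python
class Node:
    """Toy module with named children, standing in for torch.nn.Module."""

    def __init__(self, **children):
        self.children = children


def qualified_name(root, target, prefix=""):
    """Return the dotted path of `target` under `root`, or None if absent."""
    for name, child in root.children.items():
        path = prefix + name
        if child is target:
            return path
        found = qualified_name(child, target, prefix=path + ".")
        if found is not None:
            return found
    return None


bar = Node()
root = Node(foo=Node(bar=bar))
# qualified_name(root, bar) == "foo.bar"
```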

trace

trace(
    root: Module | Callable[..., Any],
    concrete_args: dict[str, Any] | None = None,
    dummy_inputs: dict[str, Any] | None = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph

Traces root and returns the corresponding FX torch.fx.Graph representation. root can either be a torch.nn.Module instance or a Python callable. Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create a torch.nn.Module instance to use as the root and add embedded constants to.

Args:

  • root (torch.nn.Module or Callable) – Either a torch.nn.Module or a function to be traced through. If root is not a ~transformers.PreTrainedModel, then dummy_inputs must be passed, otherwise tracing will fail.
  • concrete_args (dict[str, Any], optional) – Concrete arguments that should not be treated as Proxies.
  • dummy_inputs (dict[str, Any], optional) – The dummy inputs needed to handle data-dependent control-flow if root is not a ~transformers.PreTrainedModel. It can also be used when root is a ~transformers.PreTrainedModel to specify custom dummy inputs for a subset or all of the model inputs.
  • complete_concrete_args_with_inputs_not_in_dummy_inputs (bool, optional, defaults to True) – If True, and dummy_inputs is specified, every argument that root can take that is not in dummy_inputs and not in concrete_args will be added to concrete_args; otherwise does nothing.

Returns:

  torch.fx.Graph – An FX torch.fx.Graph representing the semantics of the passed-in root.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
def trace(
    self,
    root: torch.nn.Module | Callable[..., Any],
    concrete_args: dict[str, Any] | None = None,
    dummy_inputs: dict[str, Any] | None = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
    """
    Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
    `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
    the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
    `torch.nn.Module` instance to use as the root and add embedded constants to.

    Args:
        root (`torch.nn.Module` or  `Callable`):
            Either a `torch.nn.Module` or a function to be traced through. If root is not a
            [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
        concrete_args (`dict[str, Any], *optional*):
            Concrete arguments that should not be treated as Proxies
        dummy_inputs (`dict[str, Any]`, *optional*):
            The dummy inputs needed to handle data-dependent control-flow if `root` is not a
            [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
            [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
        complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
            If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
            `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

    Returns:
        `torch.fx.Graph`:
            A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

    """
    sig = inspect.signature(
        root.forward if isinstance(root, torch.nn.Module) else root
    )

    if concrete_args is None:
        concrete_args = {}

    if (
        dummy_inputs is not None
        and complete_concrete_args_with_inputs_not_in_dummy_inputs
    ):
        for param in sig.parameters.values():
            if param.name in dummy_inputs:
                continue
            if param.default is inspect.Parameter.empty:
                raise ValueError(
                    f"You need to specify a default value for the parameter {param.name}."
                )
        concrete_args.update(
            {
                p.name: p.default
                for p in sig.parameters.values()
                if (p.name not in dummy_inputs and p.name not in concrete_args)
            }
        )

    input_names = sig.parameters.keys() - concrete_args.keys()

    # Creating a random input shape to generate dummy inputs.
    batch_size = _generate_random_int()
    sequence_length = _generate_random_int()
    shape = [batch_size, sequence_length]

    if root.__class__.__name__ in get_values(
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
    ):
        num_choices = _generate_random_int(low=2, high=5)
        shape.insert(1, num_choices)

    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
    for input_name in input_names:
        if input_name in inputs:
            continue
        # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
        # be able to use HFTracer._generate_dummy_input.
        if isinstance(root, self.supported_archs) or type(
            root
        ).__qualname__.startswith(("_deserialize_graph_module", "_CodeOnlyModule")):
            inputs.update(
                self._generate_dummy_input(
                    root, input_name, shape, input_names=input_names
                )
            )
        else:
            raise RuntimeError(
                f"Could not generate input named {input_name} because root is not a"
                " transformers.PreTrainedModel."
            )

    def to_meta(value):
        if isinstance(value, torch.Tensor):
            return value.to("meta")
        return value

    concrete_metas = pytree.tree_map(to_meta, inputs)

    for param in sig.parameters.values():
        if (
            param.kind == inspect.Parameter.VAR_KEYWORD
            and param.name not in input_names
        ):
            concrete_metas[f"**{param.name}"] = {}
    self.meta_args = concrete_metas

    global _CURRENT_TRACER
    _CURRENT_TRACER = self
    with self.patch_for_tracing(root):
        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            _CURRENT_TRACER = None

    # This is necessary because concrete args are added as input to the traced module since
    # https://github.com/pytorch/pytorch/pull/55888.
    for node in self.graph.nodes:
        if node.op == "placeholder":
            # Removing default values for inputs as the forward pass will fail with them.
            if node.target in input_names:
                node.args = ()
                # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                # It cannot infer on the attributes and methods the input should have, and fails.
                node.type = torch.Tensor
            # It is a concrete arg so it is not used and should be removed.
            else:
                to_visit = [node]
                to_delete = collections.OrderedDict()
                while to_visit:
                    n = to_visit.pop(0)
                    to_delete[n] = None
                    to_visit += list(n.users.keys())

                for user in reversed(to_delete.keys()):
                    self.graph.erase_node(user)

        # TODO: solves GraphModule creation.
        # Without this, return type annotation "Tuple" is causing code execution failure.
        if node.op == "output":
            node.type = None

    return self.graph
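The placeholder cleanup at the end of trace uses a small pattern worth noting: collect a node plus all of its transitive users breadth-first, then erase in reverse discovery order, so every node's users are gone before the node itself (graph.erase_node requires a node to have no remaining users). A standalone sketch of that pattern, with hypothetical names:

```python
import collections


def erase_with_users(node, users_of, erase):
    """Delete `node` and everything that (transitively) consumes it.

    `users_of` maps a node to the nodes consuming it; `erase` removes one
    node and is called users-first, mirroring the requirement that a node
    have no remaining users when erased.
    """
    to_visit = [node]
    to_delete = collections.OrderedDict()  # keeps discovery order, dedupes
    while to_visit:
        n = to_visit.pop(0)
        to_delete[n] = None
        to_visit += list(users_of.get(n, ()))
    # Reverse discovery order: deepest users first, the root node last.
    for n in reversed(to_delete):
        erase(n)


erased = []
users = {"x": ["y"], "y": ["z"]}  # x is consumed by y, y by z
erase_with_users("x", users, erased.append)
# erased == ["z", "y", "x"]
```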

gen_constructor_wrapper

gen_constructor_wrapper(
    target: Callable,
) -> tuple[Callable, Callable]

Wraps target to be proxyable. Used for tensor creators like torch.ones, torch.arange and so on.

Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
    """
    Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
    """
    wrapper = create_wrapper(target, "call_function")
    return wrapper, target
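Returning the pair (wrapper, target) lets the tracer install the wrapper for the duration of tracing and restore the original afterwards. A simplified analog of that shape (names hypothetical; the real helper emits a call_function proxy node, while this sketch merely records the call):

```python
def gen_wrapper(target, log):
    """Return (wrapper, original) so a monkeypatch can later be undone."""
    def wrapper(*args, **kwargs):
        log.append(target.__name__)  # record instead of emitting a proxy node
        return target(*args, **kwargs)
    return wrapper, target


calls = []
wrapped_range, original = gen_wrapper(range, calls)
values = list(wrapped_range(3))
# values == [0, 1, 2]; calls == ["range"]; `original` undoes the patch
```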

symbolic_trace

symbolic_trace(
    model: PreTrainedModel,
    input_names: list[str] | None = None,
    disable_check: bool = False,
    tracer_cls: type[HFTracer] = HFTracer,
) -> GraphModule

Performs symbolic tracing on the model.

Args:

  • model (PreTrainedModel) – The model to trace.
  • input_names (list[str], optional) – The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
  • disable_check (bool, optional, defaults to False) – If True, no check is done before trying to trace the model; this is mostly useful for debugging purposes.
  • tracer_cls (type[HFTracer], optional, defaults to HFTracer) – The tracer class to use for instantiating the tracer. If unset, HFTracer is used instead.

Returns:

  torch.fx.GraphModule – A GraphModule constructed by recording operations seen while tracing the model.

Example:

```python
from transformers.utils.fx import symbolic_trace

traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```
Source code in llmcompressor/pipelines/sequential/transformers_helpers.py
def symbolic_trace(
    model: "PreTrainedModel",
    input_names: list[str] | None = None,
    disable_check: bool = False,
    tracer_cls: type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PreTrainedModel`]):
            The model to trace.
        input_names (`list[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model; this is mostly useful for debugging purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    if "past_key_values" in input_names and not getattr(
        model.config, "use_cache", False
    ):
        logger.warning(
            "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
            "unexpected behavior."
        )
    if "past_key_values" not in input_names and getattr(
        model.config, "use_cache", False
    ):
        logger.warning(
            "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
            "model.config.use_cache = False."
        )
        model.config.use_cache = False

    # Tracing.
    tracer = tracer_cls()
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced
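The division of labor between input_names and concrete_args follows a simple rule: every forward parameter not named as an input is pinned to its default value and baked into the traced graph as a constant. get_concrete_args itself is a helper imported by this module; the function below is only an illustrative stand-in for that rule:

```python
import inspect


def concrete_args_for(fn, input_names):
    """Pin every parameter of `fn` not listed in `input_names` to its default."""
    sig = inspect.signature(fn)
    return {
        p.name: p.default
        for p in sig.parameters.values()
        if p.name not in input_names
    }


def forward(input_ids, attention_mask=None, use_cache=True):
    ...


# Only input_ids stays symbolic; the rest become constants in the graph.
concrete = concrete_args_for(forward, ["input_ids"])
# concrete == {"attention_mask": None, "use_cache": True}
```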