
llmcompressor.pipelines.sequential.ast_utils.auto_wrapper

Classes:

  • AutoWrapper

    Automatically wraps untraceable code according to the following patterns:

AutoWrapper

AutoWrapper(namespace: dict[str, Any], ignore: list[str])

Bases: NodeTransformer

Automatically wraps untraceable code. The following patterns are wrapped:

  1. If statements whose conditions cannot be statically evaluated
  2. Calls to ignored functions (e.g. _update_causal_mask)
  3. Starred tuple unpacking
  4. Starred argument unpacking

See also: https://github.com/vllm-project/llm-compressor/pull/1411
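Patterns (2) to (4) are purely syntactic, so they can be spotted with the standard-library ast module alone. The sketch below is a hypothetical helper (detect_patterns is not part of llmcompressor) that labels them in a forward function's source. Unlike visit_Call, it matches ignored functions by bare name instead of evaluating the callee, and it skips pattern (1), which needs a runtime namespace.

```python
import ast


def detect_patterns(source: str, ignore: list[str]) -> list[str]:
    """Hypothetical helper: label syntactic patterns (2)-(4) found in source.

    Pattern (1), non-static `if` conditions, needs a runtime namespace and is
    not handled here.
    """
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            # (4) starred positional arguments, e.g. g(*args)
            if any(isinstance(arg, ast.Starred) for arg in node.args):
                found.append("starred_args")
            # (2) ignored functions, matched by bare name or attribute name only
            func = node.func
            name = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
            if name in ignore:
                found.append("ignored_call")
        elif isinstance(node, ast.Tuple):
            # (3) starred tuple unpacking, e.g. (*args, x)
            if any(isinstance(elt, ast.Starred) for elt in node.elts):
                found.append("starred_tuple")
    return found


src = """
def forward(self, x, *args):
    mask = self._update_causal_mask(x)
    y = torch.cat((*args, x))
    return g(*args)
"""
print(sorted(detect_patterns(src, ["_update_causal_mask"])))
# ['ignored_call', 'starred_args', 'starred_tuple']
```

Only parsing happens here, so names such as torch and g never need to exist.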

Methods:

  • auto_wrap

    Modify the AST by automatically wrapping any untraceable code segments. Segments to wrap are determined through analysis of the code and basic pattern matching

  • visit_Call

    Wrap any function calls which use (4) variadic arguments or (2) call a function in the ignore list

  • visit_Delete

    Remove any deleted names from self._local_names, which are used to determine function wrapper arguments

  • visit_FunctionDef

    Remove decorators which prevent forward function recompilation

  • visit_If

    Attempt to statically evaluate the condition of the if statement. If the condition cannot be statically evaluated (1), then attempt to wrap the if statement

  • visit_Name

    Add any newly assigned names to self._local_names, which are used to determine function wrapper arguments

  • visit_Tuple

    (3) Wrap any tuples which use starred unpacking

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def __init__(self, namespace: dict[str, Any], ignore: list[str]):
    self.namespace = namespace
    self.ignore = ignore
    self._wrapper_fn_defs: list[ast.FunctionDef] = list()
    self._local_names = set()
    self._wrapped_counter = 0

auto_wrap

auto_wrap(tree: Module) -> ast.Module

Modify the AST by automatically wrapping any untraceable code segments. Segments to wrap are determined through analysis of the code and basic pattern matching.

Parameters:

  • tree

    (Module) –

    module containing the definition of the original forward function

Returns:

  • Module

    module with added wrapper function definitions and function calls

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def auto_wrap(self, tree: ast.Module) -> ast.Module:
    """
    Modify ast by automatically wrapping any untraceable code segments. Segments to
    wrap are determined through analysis of the code and basic pattern matching

    :param tree: module containing a definition to an original forward function
    :return: module with added wrapper function definitions and function calls
    """
    tree = self.visit(tree)
    for fn_def in self._wrapper_fn_defs:
        tree.body.insert(0, fn_def)

    return ast.fix_missing_locations(tree)
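Because auto_wrap inserts each wrapper definition at index 0, the wrappers end up ahead of the forward function but in reverse creation order. That is harmless as long as the wrappers do not depend on one another. The following sketch reproduces the prepend-then-compile step; the wrapped_<i> functions are hypothetical stand-ins for generated wrappers.

```python
import ast

# Stand-in for the original forward function; it calls a wrapper by name.
tree = ast.parse("def forward(x):\n    return wrapped_0(x)\n")

# Hypothetical wrapper definitions, created in order 0, 1.
wrapper_defs = [
    ast.parse(f"def wrapped_{i}(x):\n    return x + {i}\n").body[0]
    for i in range(2)
]

# Same pattern as auto_wrap: prepend each definition to the module body.
for fn_def in wrapper_defs:
    tree.body.insert(0, fn_def)

# auto_wrap calls this because newly built nodes may lack lineno/col_offset;
# here every node came from ast.parse, so it is effectively a no-op.
tree = ast.fix_missing_locations(tree)

namespace: dict = {}
exec(compile(tree, "<auto_wrapped>", "exec"), namespace)

print([fn.name for fn in tree.body])  # ['wrapped_1', 'wrapped_0', 'forward']
print(namespace["forward"](41))  # 41
```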

visit_Call

visit_Call(node: Call) -> ast.Call

Wrap any function calls which use (4) variadic arguments or (2) call a function in the ignore list

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Call(self, node: ast.Call) -> ast.Call:
    """
    Wrap any functions which use (4) variadic arguments or (2) match the ignore list
    """
    # check for variadic starred
    if any(isinstance(elem, ast.Starred) for elem in node.args):
        return self._wrap_if_possible(node)

    # attempt to evaluate caller and check against ignore list
    try:
        caller = self._eval_expr(node.func)

    except Exception:
        caller = None

    finally:
        if (
            isinstance(caller, (FunctionType, MethodType))
            and caller.__name__ in self.ignore
        ):
            return self._wrap_if_possible(node)

    return super().generic_visit(node)
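The control flow of visit_Call can be condensed into a single predicate. The sketch below is a hypothetical re-creation: it assumes the unshown _eval_expr helper compiles the callee expression and evaluates it against the namespace, and it uses an early return in place of the finally block. Two details fall out of it. Builtins such as len are neither FunctionType nor MethodType, so they never match the ignore list. And **kwargs is stored in Call.keywords rather than as an ast.Starred in Call.args, so only *args triggers check (4).

```python
import ast
from types import FunctionType, MethodType


def should_wrap_call(call: ast.Call, namespace: dict, ignore: list[str]) -> bool:
    """Hypothetical re-creation of the two checks performed by visit_Call."""
    # (4) a starred positional argument such as g(*args)
    if any(isinstance(arg, ast.Starred) for arg in call.args):
        return True
    # (2) evaluate the callee in a copy of the namespace and compare __name__
    try:
        code = compile(ast.Expression(body=call.func), "<callee>", "eval")
        caller = eval(code, dict(namespace))
    except Exception:
        return False  # unresolvable callee: leave the call alone
    return isinstance(caller, (FunctionType, MethodType)) and caller.__name__ in ignore


class Model:
    def _update_causal_mask(self, x):
        return x


namespace = {"self": Model(), "g": len}
ignore = ["_update_causal_mask"]


def parse_call(source: str) -> ast.Call:
    return ast.parse(source, mode="eval").body


print(should_wrap_call(parse_call("self._update_causal_mask(x)"), namespace, ignore))  # True
print(should_wrap_call(parse_call("g(*xs)"), namespace, ignore))  # True
print(should_wrap_call(parse_call("g(xs)"), namespace, ignore))  # False
print(should_wrap_call(parse_call("h(**kwargs)"), namespace, ignore))  # False
```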

visit_Delete

visit_Delete(node: Delete)

Remove any deleted names from self._local_names, which are used to determine function wrapper arguments

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Delete(self, node: ast.Delete):
    """
    Remove any deleted names from `self._local_names`,
    which are used to determine function wrapper arguments
    """
    ret = super().generic_visit(node)

    for target in node.targets:
        if isinstance(target, ast.Name):
            self._local_names.remove(target.id)

    return ret
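The bookkeeping shared by visit_Name and visit_Delete can be sketched as a small ast.NodeVisitor. LocalNameTracker is hypothetical, and it uses set.discard instead of the original's set.remove, so deleting a name that was never recorded does not raise KeyError.

```python
import ast


class LocalNameTracker(ast.NodeVisitor):
    """Hypothetical tracker mirroring the visit_Name / visit_Delete bookkeeping."""

    def __init__(self) -> None:
        self.local_names: set[str] = set()

    def visit_Name(self, node: ast.Name) -> None:
        # Only names being assigned to (Store context) become local names.
        if isinstance(node.ctx, ast.Store):
            self.local_names.add(node.id)
        self.generic_visit(node)

    def visit_Delete(self, node: ast.Delete) -> None:
        # Visit children first (their ctx is ast.Del, so nothing is added),
        # then forget each deleted plain name.
        self.generic_visit(node)
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.local_names.discard(target.id)


tracker = LocalNameTracker()
tracker.visit(ast.parse("a = 1\nb = a + 1\ndel a\n"))
print(sorted(tracker.local_names))  # ['b']
```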

visit_FunctionDef

visit_FunctionDef(node: FunctionDef) -> ast.FunctionDef

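Remove decorators which prevent forward function recompilation, such as add_start_docstrings_to_model_forward. The can_return_tuple decorator is kept because it modifies the function signature. When the visited function is forward, all of its parameter names are also recorded in self._local_names.

Because _wrapper_fn_defs are inserted into the module only after visit finishes, this function does not affect wrapper functions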

Parameters:

  • node

    (FunctionDef) –

    function definition whose decorators will be stripped

Returns:

  • FunctionDef

    function definition with every decorator except can_return_tuple removed

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
    """
    Remove decorators which prevent forward function recompilation
    For example, add_start_docstrings_to_model_forward

    Because `_wrapper_fn_defs` are appended after `visit` finishes, this function
    will not affect wrapper functions

    :param node: function definition whose decorators will be stripped
    :return: function definition without decorators
    """
    node.decorator_list = [
        decorator_name
        for decorator_name in node.decorator_list
        if isinstance(decorator_name, ast.Name)
        and decorator_name.id in ("can_return_tuple",)  # modifies func signature
    ]

    if node.name == "forward":
        for arg in node.args.args:
            self._local_names.add(arg.arg)
        for arg in node.args.posonlyargs:
            self._local_names.add(arg.arg)
        for arg in node.args.kwonlyargs:
            self._local_names.add(arg.arg)
        if node.args.vararg:
            self._local_names.add(node.args.vararg.arg)
        if node.args.kwarg:
            self._local_names.add(node.args.kwarg.arg)
    return super().generic_visit(node)
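The decorator filter and parameter collection in visit_FunctionDef can be tried out on a standalone definition. The snippet below applies the same two steps to a hypothetical forward signature. A decorator written as a call, such as add_start_docstrings_to_model_forward("docs"), is an ast.Call node, so the isinstance(..., ast.Name) test alone is enough to drop it.

```python
import ast

SOURCE = '''
@add_start_docstrings_to_model_forward("docs")
@can_return_tuple
def forward(self, input_ids, /, *args, use_cache=False, **kwargs):
    return input_ids
'''

fn = ast.parse(SOURCE).body[0]

# Keep only bare-name decorators on the allow-list; call-style decorators
# are ast.Call nodes and are filtered out.
fn.decorator_list = [
    d for d in fn.decorator_list
    if isinstance(d, ast.Name) and d.id in ("can_return_tuple",)
]

# Collect every parameter name of forward, as visit_FunctionDef does.
params = fn.args
names = {a.arg for a in params.posonlyargs + params.args + params.kwonlyargs}
if params.vararg:
    names.add(params.vararg.arg)
if params.kwarg:
    names.add(params.kwarg.arg)

print([d.id for d in fn.decorator_list])  # ['can_return_tuple']
print(sorted(names))  # ['args', 'input_ids', 'kwargs', 'self', 'use_cache']
```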

visit_If

visit_If(node: If) -> ast.If | ast.Assign

Attempt to statically evaluate the condition of the if statement. If the condition cannot be statically evaluated (1), or if the statement contains an assignment expression, then attempt to wrap the if statement

Parameters:

  • node

    (If) –

    if statement which may be wrapped

Returns:

  • If | Assign

    if the condition can be statically evaluated, return the if statement with the condition replaced by True or False. Otherwise, return a wrapper function call

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_If(self, node: ast.If) -> ast.If | ast.Assign:
    """
    Attempt to statically evaluate the condition of the `if` statement. If the
    condition cannot be statically evaluated (1), then attempt to wrap the `if`
    statement

    :param node: `if` statement which may be wrapped
    :return: if the condition can be statically evaluated, return the
        `if` statement with the condition replaced by `True` or `False`.
        Otherwise, return a wrapper function call
    """
    try:
        value = bool(self._eval_expr(node.test))

        # force a wrap if any assignments occur within the if statement
        for expr in ast.walk(node):
            if isinstance(expr, ast.NamedExpr):
                raise Exception("If statement contains assignment")

    except Exception:
        return self._wrap_if_possible(node)

    else:
        node.test = ast.Constant(value=value)
        return super().generic_visit(node)
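The decision made by visit_If can be isolated into a function that either folds the condition to a constant or reports that the statement must be wrapped. fold_if is a hypothetical sketch: it assumes the condition is evaluated with eval against a copy of the namespace (the real _eval_expr is not shown here), and it needs Python 3.9+ for ast.unparse. A static config flag such as config.use_cache folds, whereas a condition on a runtime tensor or one containing a walrus assignment does not.

```python
import ast
from types import SimpleNamespace
from typing import Optional


def fold_if(source: str, namespace: dict) -> Optional[str]:
    """Hypothetical sketch of visit_If's decision for one if statement.

    Returns the rewritten source when the condition can be evaluated from
    namespace; returns None when the statement would be wrapped instead.
    """
    node = ast.parse(source).body[0]
    assert isinstance(node, ast.If)
    try:
        code = compile(ast.Expression(body=node.test), "<test>", "eval")
        value = bool(eval(code, dict(namespace)))
        # Assignment expressions anywhere in the statement force a wrap.
        if any(isinstance(n, ast.NamedExpr) for n in ast.walk(node)):
            return None
    except Exception:
        return None
    node.test = ast.Constant(value=value)
    return ast.unparse(node)


ns = {"config": SimpleNamespace(use_cache=False)}
static_src = "if config.use_cache:\n    x = 1\nelse:\n    x = 2\n"
print(fold_if(static_src, ns).splitlines()[0])  # if False:
print(fold_if("if hidden_states.shape[0] > 1:\n    pass\n", ns))  # None
print(fold_if("if (n := 3) > 1:\n    pass\n", ns))  # None
```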

visit_Name

visit_Name(node: Name)

Add any newly assigned names to self._local_names, which are used to determine function wrapper arguments

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Name(self, node: ast.Name):
    """
    Add any new names in `self._local_names`,
    which are used to determine function wrapper arguments
    """
    if isinstance(node.ctx, ast.Store):
        self._local_names.add(node.id)

    return super().generic_visit(node)
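The local names collected by visit_Name matter because a wrapped segment must receive, as parameters, any locals it reads. The sketch below shows one way that selection could be computed; wrapper_arguments is hypothetical, and the library's own logic lives in _wrap_if_possible, which is not shown on this page and may differ.

```python
import ast


def wrapper_arguments(segment_source: str, local_names: set[str]) -> list[str]:
    """Hypothetical: which known local names does a code segment read?

    A wrapper for the segment would take these as parameters; other names
    (e.g. module globals such as torch) can be resolved from the namespace.
    """
    loaded = {
        node.id
        for node in ast.walk(ast.parse(segment_source))
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
    }
    return sorted(loaded & local_names)


local_names = {"self", "hidden_states", "attention_mask"}
segment = (
    "if attention_mask is None:\n"
    "    attention_mask = torch.ones_like(hidden_states)\n"
)
print(wrapper_arguments(segment, local_names))  # ['attention_mask', 'hidden_states']
```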

visit_Tuple

visit_Tuple(node: Tuple) -> ast.Tuple | ast.Call

(3) Wrap any tuples which use starred unpacking

Source code in llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple | ast.Call:
    """
    (3) Wrap any tuples which use starred unpacking
    """
    if any(isinstance(elem, ast.Starred) for elem in node.elts):
        return self._wrap_if_possible(node)

    return super().generic_visit(node)
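Because visit_Tuple matches any ast.Tuple with an ast.Starred element, it covers both starred tuple displays and starred assignment targets such as first, *rest = outputs, while starred call arguments are left to visit_Call. The hypothetical helper below checks the same condition across a whole snippet.

```python
import ast


def has_starred_tuple(source: str) -> bool:
    """Return True if any tuple in source contains a starred element."""
    return any(
        isinstance(node, ast.Tuple)
        and any(isinstance(elt, ast.Starred) for elt in node.elts)
        for node in ast.walk(ast.parse(source))
    )


print(has_starred_tuple("x = (*layers, head)"))  # True: starred tuple display
print(has_starred_tuple("first, *rest = outputs"))  # True: starred assignment target
print(has_starred_tuple("pair = (a, b)"))  # False
print(has_starred_tuple("f(*args)"))  # False: a Call, handled as pattern (4)
```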