Replace the forward method of the given module with a recompiled version where all untraceable code patterns are removed and replaced with torch.fx function wrappers.
For a list of untraceable code patterns and their explanations, see https://github.com/vllm-project/llm-compressor/pull/1411
Parameters:
- module (Module) – module whose forward method should be replaced
- ignore (list[str]) – explicit list of function names to wrap
Source code in llmcompressor/pipelines/sequential/ast_helpers.py
@contextlib.contextmanager
def autowrap_forward(module: torch.nn.Module, ignore: list[str]):
    """
    Replace the `forward` method of the given module with a recompiled version where
    all untraceable code patterns are removed and replaced with torch.fx function
    wrappers.

    For a list of untraceable code patterns and their explanations, see
    https://github.com/vllm-project/llm-compressor/pull/1411

    On context exit the original `forward` is restored (via `patch_attr`).

    :param module: module whose forward method should be replaced
    :param ignore: explicit list of function names to wrap
    :raises ValueError: if the module never implemented a `forward` method
    """
    # check forward method is implemented
    # (torch.nn.Module's default forward is named `_forward_unimplemented`)
    if module.forward.__name__ == "_forward_unimplemented":
        raise ValueError(
            "Cannot calibrate model which does not implement `forward` method. Please "
            "either implement a forward method on the model, or pass a submodule to "
            "`oneshot`. For example, `oneshot(model.thinker, ...)`"
        )
    # get source code of module forward; dedent so the (indented) method body
    # parses as a top-level function definition
    source = inspect.getsource(module.forward)
    source = textwrap.dedent(source)
    tree = ast.parse(source)
    # construct namespace for our new code: start from the globals of the module
    # that defined this class so every name the original forward referenced still
    # resolves when the recompiled forward runs
    defining_module = sys.modules[module.__class__.__module__]
    namespace = defining_module.__dict__.copy()
    # NOTE(review): the literal key "torch.fx.wrap" is not a valid identifier —
    # exec'd code writing `torch.fx.wrap(...)` resolves through the `torch` entry
    # copied from the defining module's globals, not through this key; confirm
    # whether this entry is actually needed (or whether AutoWrapper emits code
    # that reads it some other way)
    namespace.update({"torch.fx.wrap": torch.fx.wrap})
    # the recompiled forward is a plain function, so `self` must be injected
    namespace.update({"self": module})
    # autowrap untraceable code: rewrite the AST, replacing untraceable patterns
    # with calls to torch.fx-wrapped helpers, then turn it back into source text
    auto_wrapper = AutoWrapper(namespace, ignore)
    tree = auto_wrapper.auto_wrap(tree)
    source = ast.unparse(tree)
    # compile new forward function from autowrapped code; the synthetic filename
    # makes frames from this generated code identifiable in tracebacks
    filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
    code = compile(source, filename=filename, mode="exec")
    with append_autowrap_source_on_fail():
        exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
    # enable better tracebacks if autowrapped code fails: register the generated
    # source under the synthetic filename so traceback machinery can display it
    linecache.cache[filename] = (
        len(source),  # entry size (character count of the generated source)
        None,  # mtime None keeps linecache.checkcache from evicting this entry
        [line + "\n" for line in source.splitlines()],
        filename,
    )
    # patch forward with autowrapped forward: `__get__` binds the compiled
    # function to this module instance; patch_attr restores the original on exit
    new_forward = namespace["forward"].__get__(module)
    with patch_attr(module, "forward", new_forward):
        yield