Experimental MetaTensorTracer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76003
Approved by: https://github.com/jansel
This commit is contained in:
parent f1f99ab310
commit 15e36f03ad
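For context, MetaTracer is an FX tracer that propagates meta-device tensors alongside proxies during symbolic tracing, so Python control flow that depends on shapes or dtypes can be evaluated at trace time rather than failing under plain symbolic tracing. A minimal usage sketch (illustrative only; MyModule is a made-up module, while the MetaTracer/meta_args API is the one added by this commit and exercised in the new test below):

import torch
import torch.fx
from torch.fx.experimental.meta_tracer import MetaTracer

class MyModule(torch.nn.Module):
    def forward(self, x):
        # shape-dependent branch: plain torch.fx.symbolic_trace cannot evaluate this
        return x.relu() if x.shape[0] < 30 else x.sigmoid()

m = MyModule()
x = torch.randn(15, 16)
# Pass meta-device example inputs keyed by argument name; their shapes/dtypes are
# attached to the placeholder proxies and used to resolve the branch at trace time.
graph = MetaTracer().trace(m, meta_args={'x': x.to(device='meta')})
gm = torch.fx.GraphModule(m, graph)  # specialized to the relu branch for this shape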
@@ -26,6 +26,7 @@ from torch.fx.experimental.partitioner_utils import (
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.experimental.meta_tracer import MetaTracer
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
@@ -667,6 +668,28 @@ class TestFXExperimental(JitTestCase):
        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        mt = MetaTracer()

        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            graph = mt.trace(mttm, meta_args={'x' : x.to(device='meta')})
            gm = torch.fx.GraphModule(mttm, graph)
            torch.testing.assert_close(gm(x), mttm(x))


    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
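The branch `lol.shape[0] < 30` in the test above is evaluated against the concrete meta shape, so each batch size specializes the traced graph to a single branch. A sketch of how that specialization could be observed (illustrative only; assumes MetaTracer and mttm = MetaTracerTestModule() from the diff above are in scope):

for bs, expected in [(15, torch.relu), (35, torch.sigmoid)]:
    x = torch.zeros(bs, dtype=torch.long).random_(42)
    graph = MetaTracer().trace(mttm, meta_args={'x': x.to(device='meta')})
    targets = {n.target for n in graph.nodes if n.op == 'call_function'}
    assert expected in targets  # only the branch taken for this shape is recorded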
256 torch/fx/experimental/meta_tracer.py (new file)
@@ -0,0 +1,256 @@
import torch
import torch.fx
import warnings
import functools
import builtins

from typing import Callable, Dict

def embedding_override(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device='meta')


def nn_layernorm_override(self, input):
    return input


def torch_relu_override(x):
    return x


def torch_nn_relu_override(self, x):
    return x


def functional_relu_override(x, inplace=False):
    assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
    return x


def torch_where_override(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')


def torch_abs_override(input, *, out=None):
    assert out is None, 'Dont support in-place abs for MetaTensor analysis'
    return input

manual_meta_overrides : Dict[Callable, Callable] = {
    torch.nn.Embedding: embedding_override,
    torch.nn.LayerNorm: nn_layernorm_override,
    torch.relu: torch_relu_override,
    torch.nn.functional.relu: functional_relu_override,
    torch.nn.ReLU: torch_nn_relu_override,
    torch.where: torch_where_override,
    torch.abs: torch_abs_override,
}

def gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy('call_function', target, args, kwargs)
        else:
            return target(*args, **kwargs)
    return wrapper, target

class MetaProxy(torch.fx.Proxy):
    def install_tensor_meta(self, tensor_meta):
        self._tensor_meta = tensor_meta

    def size(self, dim=None):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.size(*[dim] if dim else [])
        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})

    def dim(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dim()
        return self.tracer.create_proxy('call_method', 'dim', (self,), {})

    @property
    def shape(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.shape
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})

    @property
    def dtype(self):
        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
            return self._tensor_meta.dtype
        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, 'device')

    def __getattr__(self, k):
        if k == '_tensor_meta':
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return MetaAttribute(self, k)

class MetaAttribute(MetaProxy):
    def __init__(self, root, attr: str):

        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

class MetaDeviceAttribute(MetaAttribute):
    pass

def proxys_to_metas(v):
    if isinstance(v, MetaDeviceAttribute):
        return 'meta'
    if isinstance(v, torch.fx.Proxy):
        assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
        assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
        return v._tensor_meta
    return v

class MetaTracer(torch.fx.Tracer):
    allow_insert_stateless_mods : bool = True

    _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == 'placeholder' and target in self.meta_args:
            rv.install_tensor_meta(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if 'device' in kwargs:
                kwargs['device'] = 'meta'

        try:
            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)

            if kind == 'call_function':
                meta_target = manual_meta_overrides.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == 'call_method':
                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
            elif kind == 'call_module':
                assert hasattr(self, 'orig_forward')
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in manual_meta_overrides:
                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == 'get_attr':
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split('.')
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    assert isinstance(attr_itr, torch.Tensor)
                    meta_out = attr_itr.to(device='meta')
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            # TODO
            assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
            rv.install_tensor_meta(meta_out)
        except Exception as e:
            warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')

        return rv

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, '_disable_module_getattr', False):
            return attr_val
        else:
            return super()._module_getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.
        """
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        while hasattr(self.root, path):
            path = f"{mod_name}_{idx}"
            idx += 1

        self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: torch.nn.Module) -> str:
        try:
            return super().path_of_module(mod)
        except NameError as e:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                self.prev_module = path
                return path
            raise

    def proxy(self, node):
        return MetaProxy(node, self)

    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
        assert isinstance(meta_args, dict)
        self.meta_args = meta_args

        self.patched_torch_methods = {
            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            return super().trace(root, concrete_args)
        finally:
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)
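The manual_meta_overrides table above is a plain dict keyed by functions and module types, so analysis rules for additional ops can be registered the same way. A hypothetical example (not part of this commit) covering torch.tanh, which is elementwise and therefore shape- and dtype-preserving:

import torch
from torch.fx.experimental.meta_tracer import manual_meta_overrides

def torch_tanh_override(input):
    # elementwise op: the meta output has the same shape/dtype as the input
    return input

manual_meta_overrides[torch.tanh] = torch_tanh_override
# Subsequent MetaTracer.trace() calls will use this rule whenever a
# call_function node with target torch.tanh is created.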