mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high-level overview. Let's break down the parts of this PR: * Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This makes `enable_torch_dispatch_mode(current_mode.inner)` work, which is the basis for how we temporarily disable fake tensor mode. * Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR, but it affects where the logic for allowing non-fake-tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right). * New `lift_fresh` view operator. Note that, like lift, we have to manually write the functionalize kernel for these functions. * The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test). This is mildly BC-breaking if anyone was previously interposing on at::lift, but this operator was relatively new and I checked functorch, which has no explicit reference to lift. So I think it should not be too disruptive. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192 Approved by: https://github.com/samdow, https://github.com/bdhirsh
338 lines
11 KiB
Python
338 lines
11 KiB
Python
# Owner(s): ["oncall: fx"]
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch
|
|
import unittest
|
|
import warnings
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_methods_invocations import DecorateInfo
|
|
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
|
|
from torch._subclasses.fake_tensor import DynamicOutputShapeException
|
|
|
|
from torch._decomp import decomposition_table
|
|
from torch.testing._internal.common_device_type import ops
|
|
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
|
|
from torch.utils._pytree import tree_map
|
|
|
|
# Copied from functorch
|
|
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a skip-list entry that marks an op's test as an expected failure.

    Returns the 5-tuple ``(op_name, variant_name, device_type, dtypes, True)``;
    the trailing ``True`` is the ``expected_failure`` flag that ``skipOps``
    unpacks from each entry.
    """
    expected_failure = True
    entry = (op_name, variant_name, device_type, dtypes, expected_failure)
    return entry
|
|
|
|
|
|
def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a skip-list entry that marks an op's test to be skipped outright.

    Returns the 5-tuple ``(op_name, variant_name, device_type, dtypes, False)``;
    the trailing ``False`` is the ``expected_failure`` flag that ``skipOps``
    unpacks from each entry.
    """
    expected_failure = False
    entry = (op_name, variant_name, device_type, dtypes, expected_failure)
    return entry
|
|
|
|
|
|
def skipOps(test_case_name, base_test_name, to_skip):
    """Attach skip / expected-failure decorators to the OpInfos named in ``to_skip``.

    Each element of ``to_skip`` is a 5-tuple as produced by ``xfail`` / ``skip``:
    ``(op_name, variant_name, device_type, dtypes, expected_failure)``.  For every
    entry in ``op_db`` whose ``name`` and ``variant_test_name`` match, a
    ``DecorateInfo`` scoped to ``test_case_name`` / ``base_test_name`` is appended
    to that OpInfo's ``decorators`` tuple (the OpInfo is mutated in place).

    Raises ``AssertionError`` if an entry matches no OpInfo.

    Returns a decorator that hands back the decorated function untouched; all of
    the work happens as a side effect when ``skipOps(...)`` itself is called.
    """
    opinfos = op_db
    for entry in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = entry
        matches = [
            candidate for candidate in opinfos
            if candidate.name == op_name and candidate.variant_test_name == variant_name
        ]
        assert len(matches) >= 1, f"Couldn't find OpInfo for {entry}"

        for opinfo in matches:
            # Pick the unittest marker once; everything else is shared by both kinds.
            marker = unittest.expectedFailure if expected_failure else unittest.skip("Skipped!")
            opinfo.decorators = (
                *opinfo.decorators,
                DecorateInfo(marker, test_case_name, base_test_name,
                             device_type=device_type, dtypes=dtypes),
            )

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
|
|
|
|
|
|
# torchvision is an optional dependency: tests that need it are gated on
# USE_TORCHVISION, and a warning is emitted when the import fails.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )
else:
    USE_TORCHVISION = True
|
|
|
|
|
|
def _create_new_input(x):
|
|
if not isinstance(x, torch.Tensor):
|
|
return x
|
|
if x.dtype != torch.float:
|
|
return x + 1
|
|
if x.is_leaf:
|
|
return torch.rand_like(x, requires_grad=True)
|
|
else:
|
|
return torch.rand_like(x)
|
|
|
|
class TestProxyTensor(TestCase):
    """Tests for ``make_fx`` tracing and for ``DecompositionInterpreter``."""

    def _test(self, f, inps):
        """Trace ``f`` with ``make_fx`` on ``inps`` and compare against eager.

        The comparison is run on fresh inputs produced by mapping
        ``_create_new_input`` over ``inps``, not on the inputs used for tracing.
        """
        fx_f = make_fx(f)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        self.assertEqual(fx_f(*new_inps), f(*new_inps))

    def test_make_fx_simple(self, device):
        def f(x):
            return torch.sin(x)

        # Allocate on the device under test, as test_scalar_device does.
        self._test(f, (torch.randn(3, device=device),))

    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        # The second operand is a 0-dim tensor created without an explicit device.
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            # Reset gradients first so the returned list reflects this call only.
            for a in mod.parameters():
                a.grad = None
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        self._test(f, [inp])

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            # The in-place op must be reflected in the metadata seen while tracing.
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f)(torch.randn(3))
        randn_traced = any(
            node.target == torch.ops.aten.randn.default
            for node in traced.graph.nodes
        )
        self.assertTrue(randn_traced)

    def test_mode_tracing_factory_function_no_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # setting the flag to false should not trace factory functions
        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
        randn_traced = any(
            node.target == torch.ops.aten.randn.default
            for node in traced.graph.nodes
        )
        self.assertFalse(randn_traced)

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f)(torch.randn(3))

        # Every call_function node must target a resolved overload.
        self.assertTrue(all(
            isinstance(node.target, torch._ops.OpOverload)
            for node in traced.graph.nodes
            if node.op == 'call_function'
        ))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_constant_proxy_tensor(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        g = make_fx(f)()
        self.assertEqual(g(), f())

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        # Same checks with the default tracer and with use_fake=True.
        for extra_kwargs in ({}, {'use_fake': True}):
            g = make_fx(f, **extra_kwargs)()
            self.assertEqual(g(), f())
            # In case we mutated shared state in the g graph!
            self.assertEqual(g(), f())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, decomposition_table=None)(x)

        # Accept either the overload packet or the resolved default overload.
        silu_targets = (torch.ops.aten.silu, torch.ops.aten.silu.default)

        # Without a decomposition table, silu must appear in the traced graph.
        self.assertTrue(any(n.target in silu_targets for n in fx_module.graph.nodes))

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        # After re-interpretation with the silu decomposition, silu must be gone.
        for n in decomposed_module.graph.nodes:
            self.assertNotIn(n.target, silu_targets)

        self.assertEqual(fx_module(x), decomposed_module(x))
|
|
|
|
# Skip-list entries built with xfail(...) (expected failure) and skip(...)
# (skipped outright); passed to skipOps for the exhaustive make_fx tests.
make_fx_failures = {
    # Cause unknown
    xfail('allclose'),
    xfail('equal'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d', device_type='cpu'),

    # "empty" factories
    # NOTE(review): presumably skipped because their output is uninitialized
    # and cannot be compared -- confirm.
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),

    # Flaky
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    skip('linalg.lstsq', variant_name='grad_oriented'),
    skip('nn.functional.max_unpool1d', device_type='cpu'),
    skip('nn.functional.max_unpool2d', device_type='cpu'),
    skip('nn.functional.max_unpool3d', device_type='cpu'),

    # Data-dependent control flow
    xfail('cov'),
    xfail('corrcoef'),
    xfail('istft'),
    xfail('nanquantile'),
    xfail('quantile'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),

    # ??? (no reason recorded)
    xfail('nn.functional.ctc_loss'),

    # Sparse tensors are not supported with faketensors for now
    xfail('to_sparse'),

    # Segfaults
    skip('block_diag'),
}
|
|
|
|
# Additional skip-list entries that apply only when tracing with use_fake=True;
# unioned with make_fx_failures for test_make_fx_fake_exhaustive.
fake_tensor_failures = {
    # Needs complex-value support
    xfail('polar'),
    xfail('complex'),
    xfail('linalg.eig'),

    # FakeTensor fallback doesn't work
    xfail('linalg.matrix_power'),
    xfail('segment_reduce', variant_name='lengths'),
    xfail('multinomial'),
    xfail('mvlgamma', variant_name='mvlgamma_p_1'),
    xfail('mvlgamma', variant_name='mvlgamma_p_3'),
    xfail('mvlgamma', variant_name='mvlgamma_p_5'),
    xfail('cholesky'),
    xfail('cholesky_inverse'),

    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),
}
|
|
|
|
|
|
def _test_make_fx_helper(self, device, dtype, op, use_fake):
    """Trace ``op`` with ``make_fx`` on each of its sample inputs and compare to eager.

    Args:
        self: the running test case; used for ``skipTest`` and ``assertEqual``.
        device, dtype: forwarded to ``op.sample_inputs``.
        op: object exposing ``op.op`` (the callable under test) and
            ``op.sample_inputs(device, dtype, requires_grad=False)``.
        use_fake: forwarded to ``make_fx(..., use_fake=use_fake)``.

    For every sample the op is traced, the ``torch.float`` positional tensor
    arguments are refilled in place with uniform values in [0, 1), and the traced
    function's output is asserted equal to the eager output.

    A ``DynamicOutputShapeException`` during tracing skips the whole test, not
    just the current sample.  A sample whose eager call raises after the refill
    is silently dropped.
    """
    def f(args, kwargs):
        return op.op(*args, **kwargs)

    for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
        args = [sample_input.input, *sample_input.args]
        kwargs = sample_input.kwargs

        try:
            new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
        except DynamicOutputShapeException:
            # skipTest raises, so no later sample is tried.
            self.skipTest("Dynamic output shape operation in trace")

        # Change the values of float inputs (in place) so the traced graph is
        # exercised on data other than what it was traced with.
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)

        try:
            old_out = f(args, kwargs)
        except Exception:
            # Deliberate best-effort: presumably the refilled values can be
            # invalid for some ops, so an eager failure only drops this sample.
            continue

        # NOTE(review): wrapper_set_seed presumably fixes the RNG seed before
        # calling new_f -- confirm against its definition.
        new_out = wrapper_set_seed(new_f, args, kwargs)
        self.assertEqual(new_out, old_out)
|
|
|
|
class TestProxyTensorOpInfo(TestCase):
    """Runs ``_test_make_fx_helper`` over ``op_db`` for ``torch.float``.

    Both tests share the helper and differ only in ``use_fake`` and in the
    skip list handed to ``skipOps``.
    """

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, use_fake=False)

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, use_fake=True)
|
|
|
|
|
|
|
|
# Device types to instantiate the test classes for.  The trailing comma is
# required: ("cpu") is just the string "cpu", and any `in` test against a
# string is a substring match rather than element membership.
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensor, globals(), only_for=only_for)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()
|