# Owner(s): ["oncall: fx"] from torch.testing._internal.common_utils import TestCase, run_tests import torch import unittest import warnings from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed from torch.testing._internal.common_device_type import ops from torch.fx.experimental.proxy_tensor import make_fx # Copied from functorch def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): return (op_name, variant_name, device_type, dtypes, True) def skip(op_name, variant_name='', *, device_type=None, dtypes=None): return (op_name, variant_name, device_type, dtypes, False) def skipOps(test_case_name, base_test_name, to_skip): all_opinfos = op_db for xfail in to_skip: op_name, variant_name, device_type, dtypes, expected_failure = xfail matching_opinfos = [o for o in all_opinfos if o.name == op_name and o.variant_test_name == variant_name] assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" for opinfo in matching_opinfos: decorators = list(opinfo.decorators) if expected_failure: decorator = DecorateInfo(unittest.expectedFailure, test_case_name, base_test_name, device_type=device_type, dtypes=dtypes) decorators.append(decorator) else: decorator = DecorateInfo(unittest.skip("Skipped!"), test_case_name, base_test_name, device_type=device_type, dtypes=dtypes) decorators.append(decorator) opinfo.decorators = tuple(decorators) # This decorator doesn't modify fn in any way def wrapped(fn): return fn return wrapped USE_TORCHVISION = False try: import torchvision USE_TORCHVISION = True except ImportError: warnings.warn("Couldn't import torchvision. Some of our tests use it, try " "to install it with commands from pytorch.org, post-fixed with " "`--no-deps` to avoid overwriting the pytorch installation", UserWarning) class TestProxyTensor(TestCase): def test_make_fx(self, device): def f(x): return torch.sin(x) inp = torch.randn(3) fx_f = make_fx(f)(inp) new_inp = torch.randn(3) self.assertEqual(fx_f(new_inp), f(new_inp)) def test_scalar_device(self, device): def f(a, b): return a + b inps = [torch.randn(3, device=device), torch.tensor(5)] fx_f = make_fx(f)(*inps) self.assertEqual(fx_f(*inps), f(*inps)) @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") def test_resnet18_backward_trace(self, device): mod = torchvision.models.resnet18() def f(x): out = mod(x) out.sum().backward() return [a.grad for a in mod.parameters()] inp = torch.randn(3, 3, 250, 250, requires_grad=True) grads = f(inp) mod.zero_grad() mod(inp).sum().backward() grads2 = [a.grad for a in mod.parameters()] self.assertEqual(grads, grads2) def test_proxy_tensor(self): def f_grad(x): val = x.cos().cos().sum() return torch.autograd.grad(val, x) def f_backward(x): val = x.cos().cos().sum() val.backward() return x.grad for f in [f_grad, f_backward]: traced_graph = make_fx(f)(torch.randn(3, requires_grad=True)) inp = torch.randn(3, requires_grad=True) traced_graph_out = traced_graph(inp) assert inp.grad is None torch.testing.assert_close(traced_graph_out, f(inp)) def test_mode_tracing_factory_function(self): def f(x): return x + torch.randn(x.shape) # default behavior should trace factory functions traced = make_fx(f)(torch.randn(3)) self.assertTrue( any( isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn' for node in traced.graph.nodes ) ) def test_mode_tracing_factory_function_no_factory_function(self): def f(x): return x + torch.randn(x.shape) traced = make_fx(f, trace_factory_functions=False)(torch.randn(3)) # default behavior should not trace factory functions self.assertFalse( any( isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn' for node in traced.graph.nodes ) ) make_fx_failures = { xfail('allclose'), xfail('equal'), xfail('linalg.eigvals'), xfail('nn.functional.max_pool1d', device_type='cpu'), # empty skip('new_empty'), skip('empty_like'), skip('empty'), # flaky skip('linalg.lstsq', 'grad_oriented'), skip('nn.functional.max_unpool1d', '', device_type='cpu'), skip('nn.functional.max_unpool2d', '', device_type='cpu'), skip('nn.functional.max_unpool3d', '', device_type='cpu'), skip('linalg.lstsq'), # flaky, probably just a precision issue xfail('histogram'), xfail('scatter'), # data-dependent control flow xfail('cov'), xfail('istft'), xfail('nanquantile'), xfail('nn.functional.gaussian_nll_loss'), xfail('quantile'), xfail('tensor_split'), xfail('corrcoef'), # Masked failures (creating a scalar tensor just to call `.item` on it) xfail('_masked.amax'), xfail('_masked.amax'), xfail('_masked.amin'), xfail('_masked.argmax'), xfail('_masked.argmin'), xfail('_masked.cumprod'), xfail('_masked.cumsum'), xfail('_masked.log_softmax'), xfail('_masked.logaddexp'), xfail('_masked.logsumexp'), xfail('_masked.mean'), xfail('_masked.median'), xfail('_masked.norm'), xfail('_masked.prod'), xfail('_masked.softmax'), xfail('_masked.softmin'), xfail('_masked.std'), xfail('_masked.sum'), xfail('_masked.var'), # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse xfail('sparse.sampled_addmm'), # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse xfail('nn.functional.ctc_loss'), } class TestProxyTensorOpInfo(TestCase): @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures ) def test_make_fx_exhaustive(self, device, dtype, op): def f(args, kwargs): return op.op(*args, **kwargs) sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) new_f = None for sample_input in sample_inputs_itr: args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs new_f = make_fx(f, trace_factory_functions=True)(args, kwargs) for arg in args: if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: arg.uniform_(0, 1) try: old_out = f(args, kwargs) except Exception: continue new_out = wrapper_set_seed(new_f, args, kwargs) self.assertEqual(new_out, old_out) only_for = ("cpu") instantiate_device_type_tests( TestProxyTensor, globals(), only_for=only_for, ) instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for) if __name__ == '__main__': run_tests()