pytorch/test/test_proxy_tensor.py
Horace He 159d459c50 Switched to tracing overloads by default (#80013)
There are many cases where it's more convenient to use overloads, but we've hesitated in doing so since we can't torchscript it directly.

Luckily, it's pretty easy to strip overloads. See https://github.com/pytorch/functorch/pull/899
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80013
Approved by: https://github.com/zou3519
2022-06-22 18:55:06 +00:00

228 lines
7.7 KiB
Python

# Owner(s): ["oncall: fx"]
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import unittest
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch.testing._internal.common_device_type import ops
from torch.fx.experimental.proxy_tensor import make_fx
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
return (op_name, variant_name, device_type, dtypes, True)
def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
return (op_name, variant_name, device_type, dtypes, False)
def skipOps(test_case_name, base_test_name, to_skip):
all_opinfos = op_db
for xfail in to_skip:
op_name, variant_name, device_type, dtypes, expected_failure = xfail
matching_opinfos = [o for o in all_opinfos
if o.name == op_name and o.variant_test_name == variant_name]
assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
for opinfo in matching_opinfos:
decorators = list(opinfo.decorators)
if expected_failure:
decorator = DecorateInfo(unittest.expectedFailure,
test_case_name, base_test_name,
device_type=device_type, dtypes=dtypes)
decorators.append(decorator)
else:
decorator = DecorateInfo(unittest.skip("Skipped!"),
test_case_name, base_test_name,
device_type=device_type, dtypes=dtypes)
decorators.append(decorator)
opinfo.decorators = tuple(decorators)
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
USE_TORCHVISION = False
try:
import torchvision
USE_TORCHVISION = True
except ImportError:
warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
"to install it with commands from pytorch.org, post-fixed with "
"`--no-deps` to avoid overwriting the pytorch installation",
UserWarning)
class TestProxyTensor(TestCase):
def test_make_fx(self, device):
def f(x):
return torch.sin(x)
inp = torch.randn(3)
fx_f = make_fx(f)(inp)
new_inp = torch.randn(3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_scalar_device(self, device):
def f(a, b):
return a + b
inps = [torch.randn(3, device=device), torch.tensor(5)]
fx_f = make_fx(f)(*inps)
self.assertEqual(fx_f(*inps), f(*inps))
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_resnet18_backward_trace(self, device):
mod = torchvision.models.resnet18()
def f(x):
out = mod(x)
out.sum().backward()
return [a.grad for a in mod.parameters()]
inp = torch.randn(3, 3, 250, 250, requires_grad=True)
grads = f(inp)
mod.zero_grad()
mod(inp).sum().backward()
grads2 = [a.grad for a in mod.parameters()]
self.assertEqual(grads, grads2)
def test_proxy_tensor(self):
def f_grad(x):
val = x.cos().cos().sum()
return torch.autograd.grad(val, x)
def f_backward(x):
val = x.cos().cos().sum()
val.backward()
return x.grad
for f in [f_grad, f_backward]:
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
inp = torch.randn(3, requires_grad=True)
traced_graph_out = traced_graph(inp)
assert inp.grad is None
torch.testing.assert_close(traced_graph_out, f(inp))
def test_mode_tracing_factory_function(self):
def f(x):
return x + torch.randn(x.shape)
# default behavior should trace factory functions
traced = make_fx(f)(torch.randn(3))
self.assertTrue(
any(
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
for node in traced.graph.nodes
)
)
def test_mode_tracing_factory_function_no_factory_function(self):
def f(x):
return x + torch.randn(x.shape)
traced = make_fx(f, trace_factory_functions=False)(torch.randn(3)) # default behavior should not trace factory functions
self.assertFalse(
any(
isinstance(node.target, torch._ops.OpOverloadPacket) and node.target._qualified_op_name == 'aten::randn'
for node in traced.graph.nodes
)
)
make_fx_failures = {
xfail('allclose'),
xfail('equal'),
xfail('linalg.eigvals'),
xfail('nn.functional.max_pool1d', device_type='cpu'),
# empty
skip('new_empty'),
skip('empty_like'),
skip('empty'),
# flaky
skip('linalg.lstsq', 'grad_oriented'),
skip('nn.functional.max_unpool1d', '', device_type='cpu'),
skip('nn.functional.max_unpool2d', '', device_type='cpu'),
skip('nn.functional.max_unpool3d', '', device_type='cpu'),
skip('linalg.lstsq'), # flaky, probably just a precision issue
# data-dependent control flow
xfail('cov'),
xfail('istft'),
xfail('nanquantile'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('quantile'),
xfail('tensor_split'),
xfail('corrcoef'),
# Masked failures (creating a scalar tensor just to call `.item` on it)
xfail('_masked.amax'),
xfail('_masked.amax'),
xfail('_masked.amin'),
xfail('_masked.argmax'),
xfail('_masked.argmin'),
xfail('_masked.cumprod'),
xfail('_masked.cumsum'),
xfail('_masked.log_softmax'),
xfail('_masked.logaddexp'),
xfail('_masked.logsumexp'),
xfail('_masked.mean'),
xfail('_masked.median'),
xfail('_masked.norm'),
xfail('_masked.prod'),
xfail('_masked.softmax'),
xfail('_masked.softmin'),
xfail('_masked.std'),
xfail('_masked.sum'),
xfail('_masked.var'),
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('nn.functional.ctc_loss'),
}
class TestProxyTensorOpInfo(TestCase):
@ops(op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures
)
def test_make_fx_exhaustive(self, device, dtype, op):
def f(args, kwargs):
return op.op(*args, **kwargs)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
new_f = None
for sample_input in sample_inputs_itr:
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
for arg in args:
if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
arg.uniform_(0, 1)
try:
old_out = f(args, kwargs)
except Exception:
continue
new_out = wrapper_set_seed(new_f, args, kwargs)
self.assertEqual(new_out, old_out)
only_for = ("cpu")
instantiate_device_type_tests(
TestProxyTensor,
globals(),
only_for=only_for,
)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
if __name__ == '__main__':
run_tests()