# Owner(s): ["oncall: fx"]
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch
|
|
import unittest
|
|
import warnings
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_methods_invocations import DecorateInfo
|
|
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
|
|
|
|
from torch.testing._internal.common_device_type import ops
|
|
from torch.fx.experimental.proxy_tensor import make_fx
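
# Note: roughly, make_fx runs a callable with proxy tensors and records the ATen
# ops it hits into a torch.fx GraphModule; the tests below compare that recorded
# graph against eager execution. See torch.fx.experimental.proxy_tensor for details.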


# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, True)


def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, False)


def skipOps(test_case_name, base_test_name, to_skip):
    all_opinfos = op_db
    for xfail in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = xfail
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
        for opinfo in matching_opinfos:
            decorators = list(opinfo.decorators)
            if expected_failure:
                decorator = DecorateInfo(unittest.expectedFailure,
                                         test_case_name, base_test_name,
                                         device_type=device_type, dtypes=dtypes)
            else:
                decorator = DecorateInfo(unittest.skip("Skipped!"),
                                         test_case_name, base_test_name,
                                         device_type=device_type, dtypes=dtypes)
            decorators.append(decorator)
            opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
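
# Usage sketch (mirrors how skipOps is applied further down in this file):
#
#     @ops(op_db, allowed_dtypes=(torch.float,))
#     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
#     def test_make_fx_exhaustive(self, device, dtype, op): ...
#
# Each xfail(...) entry marks the matching OpInfo as an expected failure and each
# skip(...) entry skips it, scoped to the given test, device type, and dtypes.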


USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it; try "
                  "installing it with the commands from pytorch.org, appending "
                  "`--no-deps` so the PyTorch installation is not overwritten.",
                  UserWarning)


class TestProxyTensor(TestCase):
    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)
        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_scalar_device(self, device):
        def f(a, b):
            return a + b
        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        grads = f(inp)

        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
            inp = torch.randn(3, requires_grad=True)
            traced_graph_out = traced_graph(inp)
            assert inp.grad is None
            torch.testing.assert_close(traced_graph_out, f(inp))

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # Default behavior: factory functions like torch.randn are traced into the graph.
        traced = make_fx(f)(torch.randn(3))
        self.assertTrue(
            any(
                isinstance(node.target, torch._ops.OpOverloadPacket)
                and node.target._qualified_op_name == 'aten::randn'
                for node in traced.graph.nodes
            )
        )

    def test_mode_tracing_factory_function_no_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # With trace_factory_functions=False, the factory call should not appear in the graph.
        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
        self.assertFalse(
            any(
                isinstance(node.target, torch._ops.OpOverloadPacket)
                and node.target._qualified_op_name == 'aten::randn'
                for node in traced.graph.nodes
            )
        )


make_fx_failures = {
    xfail('allclose'),
    xfail('equal'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d', device_type='cpu'),
    # empty
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    # flaky
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    # data-dependent control flow
    xfail('cov'),
    xfail('istft'),
    xfail('nanquantile'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('quantile'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    # Masked failures (creating a scalar tensor just to call `.item` on it)
    xfail('_masked.amax'),
    xfail('_masked.amin'),
    xfail('_masked.argmax'),
    xfail('_masked.argmin'),
    xfail('_masked.cumprod'),
    xfail('_masked.cumsum'),
    xfail('_masked.log_softmax'),
    xfail('_masked.logaddexp'),
    xfail('_masked.logsumexp'),
    xfail('_masked.mean'),
    xfail('_masked.median'),
    xfail('_masked.norm'),
    xfail('_masked.prod'),
    xfail('_masked.softmax'),
    xfail('_masked.softmin'),
    xfail('_masked.std'),
    xfail('_masked.sum'),
    xfail('_masked.var'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('nn.functional.ctc_loss'),
}


class TestProxyTensorOpInfo(TestCase):
    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):

        def f(args, kwargs):
            return op.op(*args, **kwargs)
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        new_f = None
        for sample_input in sample_inputs_itr:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs

            new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
            # Refill float inputs so the traced graph is also checked on fresh values.
            for arg in args:
                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                    arg.uniform_(0, 1)
            try:
                old_out = f(args, kwargs)
            except Exception:
                continue
            # wrapper_set_seed fixes the RNG seed before invoking the traced module.
            new_out = wrapper_set_seed(new_f, args, kwargs)
            self.assertEqual(new_out, old_out)
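

# instantiate_device_type_tests generates per-device variants of these classes
# (e.g. TestProxyTensorCPU) in this module's namespace and supplies the `device`
# argument the tests accept; here they are restricted to CPU only.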
only_for = ("cpu",)
instantiate_device_type_tests(
    TestProxyTensor,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()