pytorch/test/test_ops.py
kshitij12345 49c2da0ee0 [testing] improve broadcasts_input error message (#58295)
Summary:
Context:
The Error message when `broadcasts_input` is marked incorrectly is uninformative [See Error Currently]
https://github.com/pytorch/pytorch/pull/57941#discussion_r631749435

Error Currently
```
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 326, in test_variant_consistency_eager
    _test_consistency_helper(samples, variants)
  File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 310, in _test_consistency_helper
    variant_forward = variant(cloned,
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: RuntimeError not raised
```

Error After PR
```
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 329, in test_variant_consistency_eager
    _test_consistency_helper(samples, variants)
  File "/home/kshiteej/Pytorch/pytorch_i0_promotion/test/test_ops.py", line 313, in _test_consistency_helper
    variant_forward = variant(cloned,
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 227, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.8/unittest/case.py", line 164, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: RuntimeError not raised : inplace variant either allowed resizing or you have marked the sample SampleInput(input=Tensor, args=(tensor([[[ 2.1750, -8.5027, -3.1403, -6.9942,  3.2609],
         [-2.5057, -5.9123, -5.4633,  6.1203, -8.2124],
         [-3.5802, -8.4869, -6.0700,  2.3431, -8.1955],
         [-7.3316,  1.3248, -6.8661,  7.1483, -8.0719],
         [ 4.5977, -4.0448, -6.2044, -2.1314, -8.4956]],

        [[ 3.2769, -8.4360,  1.2826,  7.1749,  4.7653],
         [-0.2816, -2.5997, -4.7659, -3.7814,  3.9704],
         [-2.1778, -3.8117, -6.0276, -0.8423, -5.9646],
         [ 8.6544, -3.0922,  0.2558, -4.9318, -4.7596],
         [ 4.5583,  4.3830,  5.8793,  0.9713, -2.1481]],

        [[-1.0447,  0.9334,  7.6405, -4.8933, -7.4010],
         [ 7.7168, -8.4266, -5.5980, -6.9368,  7.1309],
         [-8.7720, -5.0890, -0.4975,  1.9518,  1.7074],
         [-8.5783,  8.5510, -8.5459, -3.5451,  8.4319],
         [ 8.5052, -8.9149, -6.6298, -1.2750, -5.7367]],

        [[-6.5625,  8.2795, -4.9311,  1.9501, -7.1777],
         [-8.4035,  1.1136, -7.6418, -7.0726, -2.8281],
         [ 4.2668, -0.2883, -6.2246,  2.3396,  1.2911],
         [ 4.6550, -1.9525,  4.4873, -3.8061, -0.8653],
         [-3.4256,  4.4423,  8.2937, -5.3456, -4.2624]],

        [[ 7.6128, -6.3932,  4.7131, -5.4938,  6.4792],
         [-6.5385,  2.4385,  4.5570,  3.7803, -8.3281],
         [-2.9785, -4.4745, -1.1778, -8.9324,  1.3663],
         [ 3.7437,  3.5171, -6.3135, -8.4519, -2.7033],
         [-5.0568, -8.4630, -4.2870, -3.7284, -1.5238]]], device='cuda:0',
       dtype=torch.float32, requires_grad=True),), broadcasts_input=True) incorrectly with `broadcasts_self=True
```

**NOTE**:
Printing the sample looks very verbose and it may be hard to figure out which sample is incorrectly configured if there are multiple samples with similar input shapes.

Two Options to make this error less verbose
* Don't print the sample and just print `inplace variant either allowed resizing or you have marked one of the sample incorrectly with broadcasts_self=True`
* Have some mechanism to name samples which will be printed in the `repr` (which will need extra machinery)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58295

Reviewed By: ngimel

Differential Revision: D28627308

Pulled By: mruberry

fbshipit-source-id: b3bdeacac3cf9c0d984f0b85410ecce474291d20
2021-05-25 21:14:17 -07:00

744 lines
36 KiB
Python

from functools import partial, wraps
import warnings
import torch
from torch.testing import \
(FileCheck, floating_and_complex_types_and)
from torch.testing._internal.common_utils import \
(TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor,
gradcheck, gradgradcheck)
from torch.testing._internal.common_methods_invocations import \
(op_db, method_tests)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, onlyCPU, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
# Tests that apply to all operators
class TestOpInfo(TestCase):
exact_dtype = True
# Verifies that ops have their unsupported dtypes
# registered correctly by testing that each claimed unsupported dtype
# throws a runtime error
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.unsupported)
def test_unsupported_dtypes(self, device, dtype, op):
# sample_inputs can have a function for generating the input that doesn't work for specified dtype
# https://github.com/pytorch/pytorch/issues/49024
with self.assertRaises(RuntimeError):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
op(sample.input, *sample.args, **sample.kwargs)
# Verifies that ops have their supported dtypes
# registered correctly by testing that each claimed supported dtype
# does NOT throw a runtime error
# In addition verifies that the generated sample_inputs have the requested device and dtype
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.supported)
def test_supported_dtypes(self, device, dtype, op):
for sample in op.sample_inputs(device, dtype):
op(sample.input, *sample.args, **sample.kwargs)
# NOTE: only check the first tensor in the iterable of tensors
sample_input = sample.input[0] if is_iterable_of_tensors(sample.input) else sample.input
self.assertTrue(sample_input.dtype == dtype)
self.assertTrue(sample_input.device.type == self.device_type)
# Verifies that backward for each supported floating or complex dtype
# does NOT throw a runtime error.
# TODO: support multi-tensor outputs
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.supported_backward,
allowed_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16))
def test_supported_backward(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
for sample in op.sample_inputs(device, dtype, requires_grad=True):
result = op(sample.input, *sample.args, **sample.kwargs)
if not isinstance(result, torch.Tensor):
continue
result.sum().backward()
# Verifies that ops do not have an entry in
# `method_tests` (legacy testing infra).
@onlyCPU
@ops(op_db, allowed_dtypes=[torch.float32])
def test_duplicate_method_tests(self, device, dtype, op):
self.assertFalse(op.name in method_tested_operators)
# gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=[torch.double, torch.cdouble])
class TestGradients(TestCase):
exact_dtype = True
# Copies inputs to inplace operations to avoid inplace modifications
# to leaves requiring gradient
def _get_safe_inplace(self, inplace_variant):
@wraps(inplace_variant)
def _fn(t, *args, **kwargs):
return inplace_variant(t.clone(), *args, **kwargs)
return _fn
def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False):
if variant is None:
self.skipTest("Skipped! Variant not implemented.")
if not op.supports_dtype(dtype, torch.device(device).type):
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
def is_inplace(variant):
if hasattr(variant, "__wrapped__"):
return variant.__wrapped__ is op.get_inplace()
return variant is op.get_inplace()
samples = op.sample_inputs(device, dtype, requires_grad=True)
for sample in samples:
if sample.broadcasts_input and is_inplace(variant):
continue
# Note on TensorList inputs
#
# gradcheck does not support TensorList inputs so here we pass TensorList
# inputs of size n as n single Tensor inputs to gradcheck and wrap the op
# in a function that puts the n Tensor inputs back into a TensorList
def fn(*inputs):
# Put tensors back into TensorList since we splat them when passing to gradcheck
if is_iterable_of_tensors(sample.input):
n = len(sample.input)
inputs = (inputs[:n], *inputs[n:])
output = op.gradcheck_wrapper(variant, *inputs, **sample.kwargs)
if sample.output_process_fn_grad is not None:
return sample.output_process_fn_grad(output)
return output
# Splat TensorList inputs into single Tensor inputs
gradcheck_args = (sample.input,) if isinstance(sample.input, torch.Tensor) else tuple(sample.input)
gradcheck_args += sample.args
if check == 'gradcheck':
self.assertTrue(gradcheck(fn, gradcheck_args,
check_batched_grad=op.check_batched_grad,
check_grad_dtypes=True,
nondet_tol=op.gradcheck_nondet_tol,
fast_mode=op.gradcheck_fast_mode,
check_forward_ad=check_forward_ad))
elif check == 'gradgradcheck':
self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
self.assertTrue(gradgradcheck(fn, gradcheck_args,
gen_non_contig_grad_outputs=False,
check_batched_grad=op.check_batched_gradgrad,
check_grad_dtypes=True,
nondet_tol=op.gradcheck_nondet_tol,
fast_mode=op.gradcheck_fast_mode))
self.assertTrue(gradgradcheck(fn, gradcheck_args,
gen_non_contig_grad_outputs=True,
check_batched_grad=op.check_batched_gradgrad,
check_grad_dtypes=True,
nondet_tol=op.gradcheck_nondet_tol,
fast_mode=op.gradcheck_fast_mode))
else:
self.assertTrue(False, msg="Unknown check requested!")
def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False):
return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad)
def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
def _skip_helper(self, op, device, dtype):
if not op.supports_autograd:
self.skipTest("Skipped! autograd not supported.")
if not op.supports_complex_autograd(torch.device(device).type) and dtype.is_complex:
self.skipTest("Skipped! Complex autograd not supported.")
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
self._grad_test_helper(device, dtype, op, op.get_op())
# Method grad (and gradgrad, see below) tests are disabled since they're
# costly and redundant with function grad (and gradgad) tests
# @_gradcheck_ops(op_db)
# def test_method_grad(self, device, dtype, op):
# self._skip_helper(op, device, dtype)
# self._grad_test_helper(device, dtype, op, op.get_method())
@_gradcheck_ops(op_db)
def test_inplace_grad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.inplace_variant or not op.supports_inplace_autograd:
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad:
self.skipTest("Skipped! Operation does not support gradgrad")
self._gradgrad_test_helper(device, dtype, op, op.get_op())
# Test that gradients of gradients are properly raising
@_gradcheck_ops(op_db)
def test_fn_fail_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if op.supports_gradgrad:
self.skipTest("Skipped! Operation does support gradgrad")
err_msg = r"derivative for .* is not implemented"
with self.assertRaisesRegex(RuntimeError, err_msg):
self._gradgrad_test_helper(device, dtype, op, op.get_op())
# Method gradgrad (and grad, see above) tests are disabled since they're
# costly and redundant with function gradgrad (and grad) tests
# @_gradcheck_ops(op_db)
# def test_method_gradgrad(self, device, dtype, op):
# self._skip_helper(op, device, dtype)
# self._gradgrad_test_helper(device, dtype, op, op.get_method())
@_gradcheck_ops(op_db)
def test_inplace_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.inplace_variant or not op.supports_inplace_autograd:
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
@_gradcheck_ops(op_db)
def test_forward_mode_AD(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if op.supports_forward_ad:
self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
else:
err_msg = r"Trying to use forward AD with .* that does not support it\."
hint_msg = ("Running forward AD for an OP that has does not support it did not "
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
with self.assertRaisesRegex(RuntimeError, err_msg, msg=hint_msg):
self._grad_test_helper(device, dtype, op, op.get_op(), check_forward_ad=True)
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
# autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
# functionality with original test_jit.py method operator tests
class TestCommon(JitCommonTestCase):
exact_dtype = True
# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float, torch.cfloat))
# alias testing is only done with troch.float for the same reason
_alias_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float,))
# Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (method, inplace)
# against eager's gold standard op function variant
@_variant_ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op):
# Acquires variants (method variant, inplace variant, aliases)
method = op.method_variant
inplace = op.inplace_variant
# list of all inplace ops: inplace variant + alias inplace variants if exist
inplace_ops = [inplace, ]
variants = [method, inplace]
for a_op in op.aliases:
variants.append(a_op.op)
variants.append(a_op.method_variant)
variants.append(a_op.inplace_variant)
inplace_ops.append(a_op.inplace_variant)
inplace_variants = tuple(filter(None, inplace_ops))
variants = tuple(filter(None, variants))
_requires_grad = (op.supports_autograd and
(dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
def _test_consistency_helper(samples, variants):
for sample in samples:
# TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
# Computes function forward and backward values
tensor.grad = None
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
expected_grad = None
# Skips inplace variants if the output dtype is not the same as
# the input dtype
skip_inplace = False
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is not tensor.dtype):
skip_inplace = True
# TODO: backward consistency only supported for single tensor outputs
# TODO: backward consistency only checked on sample.input, not all
# tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
# derived from each tensor output
if (op.supports_autograd and isinstance(expected_forward, torch.Tensor)
and (dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type))):
expected_forward.sum().backward()
expected_grad = tensor.grad
# Test eager consistency
for variant in variants:
# Skips inplace ops
if variant in inplace_ops and skip_inplace:
continue
# Compares variant's forward
# Note: copies the to-be-modified input when testing the inplace variant
tensor.grad = None
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
if variant in inplace_ops and sample.broadcasts_input:
with self.assertRaises(RuntimeError,
msg=('inplace variant either incorrectly allowed '
'resizing or you have marked the sample {}'
' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
continue
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
self.assertEqual(expected_forward, variant_forward)
# Compares variant's backward
if expected_grad is not None and \
(variant not in inplace_ops or op.supports_inplace_autograd):
variant_forward.sum().backward()
self.assertEqual(expected_grad, tensor.grad)
_test_consistency_helper(samples, variants)
def _test_inplace_preserve_storage(samples, variants):
for sample in samples:
# Skips inplace variants if the output dtype is not the same as
# the input dtype
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
skip_inplace = False
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is not tensor.dtype):
skip_inplace = True
if skip_inplace:
return
for variant in variants:
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
data_ptr = inp_tensor.data_ptr()
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
# TODO Support non-tensor outputs if they exist for inplace ops
if (isinstance(variant_forward, torch.Tensor)):
self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
else:
self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported")
if len(inplace_ops) > 0:
inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples))
_test_inplace_preserve_storage(inplace_samples, inplace_variants)
# Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (function, method, inplace)
# and runtimes (eager, traced, scripted).
# TODO WARNING: inplace x {traced, scripted} not currently tested
@_variant_ops(op_db)
def test_variant_consistency_jit(self, device, dtype, op):
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
op.supports_complex_autograd(torch.device(device).type))
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
for sample in samples:
# Acquires variants to test
func = op.get_op()
method = op.get_method()
variants = {
# TODO: inplace tests currently fail, fix and add inplace variant
'function': func, 'method': method,
}
# Test traced and scripted consistency
for func_type, variant in variants.items():
if variant is None:
continue
# Create accessor for script function variant
name = op.name + '_' if func_type == 'inplace' else op.name
# run with disable_autodiff_subgraph_inlining(True) to test
# autodiff support. Context manager forces the graph to contain
# DifferentiableGraph nodes if they are present
with disable_autodiff_subgraph_inlining():
# Check scripted forward, grad, and grad grad
script_fn = create_script_fn(self, name, func_type)
def out_fn(output):
# Processes the output for autograd
if sample.output_process_fn_grad is not None:
return sample.output_process_fn_grad(output)
return output
check_against_reference(self,
script_fn,
func,
out_fn,
(sample.input,) + sample.args,
sample.kwargs,
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
# Check traced forward, grad, and grad grad
traced_fn = create_traced_fn(self, variant)
check_against_reference(self,
traced_fn,
func,
out_fn,
(sample.input,) + sample.args,
sample.kwargs,
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
# Check alias annotation schema for correctness (make
# sure inputs that aren't supposed to be modified aren't)
# Note: only runs in float32 and int64 because schema isn't affected by dtype,
# so running it on all dtypes is would be excessive
if dtype in [torch.float32, torch.int32]:
check_alias_annotation(name, (sample.input,) + sample.args, sample.kwargs,
func_type=func_type, aten_name=op.aten_name)
# Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
if dtype is torch.float32:
# Sandcastle doesn't fuse nodes
if IS_SANDCASTLE:
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
fusible_nodes = []
else:
nonfusible_nodes = op.autodiff_nonfusible_nodes
fusible_nodes = op.autodiff_fusible_nodes
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
@_alias_ops((op for op in op_db if op.aliases))
def test_jit_alias_remapping(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
# [Scripting Data Preparation]
# Prepare data for test scripting
# Below we prepare strings of args/kwargs with and without type annotations.
# These strings are inserted into function template strings which is then torch scripted.
# - args string is ["t0"] corresponding to the "input" tensor required by the op
# - args_annot_kw is the string for the template function signature, for example,
# ["t0", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
# def fn(t0, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
# - args_kw is the string of args/kwargs used to call the op, same as args_annot_kw but
# without type annotations
args = ["t0"]
def quote_strs(v):
if isinstance(v, str):
return f"'{v}'"
return str(v)
args_annot_kw = args + \
[f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
[f"{k}: {type(v).__name__} = {quote_strs(v)}" for k, v in sample.kwargs.items()]
args_kw = args + \
[f"s{i}" for i in range(len(sample.args))] + \
[f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
# Prepare data for test tracing
sample_args_kwargs = ()
if len(sample.args) > 0:
sample_args_kwargs += (sample.args, )
if len(sample.kwargs) > 0:
sample_args_kwargs += (sample.kwargs, )
original_name = op.aten_name
original_name_inplace = original_name + "_"
expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype
for a_op in op.aliases:
inplace = a_op.inplace_variant
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
# Test scripting:
for variant in variants:
variant_name = variant.__name__
op_name = original_name_inplace if variant is inplace else original_name
if variant in method_or_inplace:
fn_template = '''
def _fn(t0{c}{args_annot_kw}):
return t0.{alias_name}({args_kw})
'''
# remove the first input tensor
script = fn_template.format(
c=", " if len(args_kw[1:]) > 1 or len(args_annot_kw[1:]) >= 1 else "",
args_annot_kw=", ".join(args_annot_kw[1:]),
args_kw=", ".join(args_kw[1:]),
alias_name=variant_name,
)
else:
fn_template = '''
def _fn({args_annot_kw}):
return variant({args_kw})
'''
script = fn_template.format(
args_annot_kw=", ".join(args_annot_kw),
args_kw=", ".join(args_kw),
)
scripted = torch.jit.CompilationUnit(script)._fn
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
try:
inp = clone_input_helper(sample.input)
scripted(inp, *sample.args, **sample.kwargs)
except Exception as e:
continue
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
inp = clone_input_helper(sample.input)
scripted(inp, *sample.args, **sample.kwargs)
inp = clone_input_helper(sample.input)
graph = scripted.graph_for(inp, *sample.args, **sample.kwargs)
FileCheck().check(op.aten_name).check_not(variant_name).run(graph)
# Test tracing:
for variant in variants:
variant_name = variant.__name__
op_name = original_name_inplace if variant is inplace else original_name
def _fn(*sample_args, **sample_kwargs):
return variant(*sample_args, **sample_kwargs)
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
traced = torch.jit.trace(_fn, *inp)
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
traced(*inp)
inp = (clone_input_helper(sample.input),) + sample_args_kwargs
graph = traced.graph_for(*inp)
FileCheck().check(op_name).check_not(variant_name).run(graph)
# Validates ops implement the correct out= behavior
# See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
# for a description of the correct behavior
# TODO: operations that support out= but don't support float
# are not covered by this test.
@ops(op_db, allowed_dtypes=(torch.float,))
def test_out(self, device, dtype, op):
# TODO: verify the op doesn't support the out= kwarg
if not op.supports_out:
self.skipTest("Skipped! Op doesn't support out= kwarg.")
# NOTE: only tests on first sample
samples = op.sample_inputs(device, dtype)
sample = samples[0]
# calls it normally to get the expected result
expected = op(sample.input, *sample.args, **sample.kwargs)
op_out = partial(op, sample.input, *sample.args, **sample.kwargs)
# Short-circuits if output is not a single tensor or an
# iterable of tensors
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
# A wrapper around map that works with single tensors and always
# instantiates the map. Used below to apply transforms to
# single tensor and iterable tensor outputs.
def _apply_out_transform(fn, out):
if isinstance(out, torch.Tensor):
return fn(out)
# assumes (see above) that out is an iterable of tensors
return tuple(map(fn, out))
# Case 0: out= with the correct shape, dtype, and device
# but NaN values for floating point and complex tensors, and
# maximum values for integer tensors.
# Expected behavior: out= values have no effect on the computation.
def _case_zero_transform(t):
try:
info = torch.iinfo(t.dtype)
return torch.full_like(t, info.max)
except TypeError as te:
# for non-integer types fills with NaN
return torch.full_like(t, float('nan'))
out = _apply_out_transform(_case_zero_transform, expected)
result = op_out(out=out)
self.assertEqual(expected, out)
# Checks that the returned value shares storage with out
# NOTE: only checks on the CPU and CUDA device types since some
# device types don't have storage
if self.device_type == 'cpu' or self.device_type == 'cuda':
if isinstance(out, torch.Tensor):
self.assertEqual(out.storage().data_ptr(), result.storage().data_ptr())
else:
for out_t, result_t in zip(out, result):
self.assertEqual(out_t.storage().data_ptr(), result_t.storage().data_ptr())
# Case 1: out= with the correct shape, dtype, and device,
# but noncontiguous.
# Expected behavior: strides are respected and `out` storage is not changed.
def _case_one_transform(t):
return make_tensor(t.shape,
dtype=t.dtype,
device=t.device,
noncontiguous=True)
# Extracts strides from a tensor or iterable of tensors into a tuple
def _extract_strides(out):
if isinstance(out, torch.Tensor):
return (out.stride(),)
# assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.stride(), out))
def _extract_data_ptrs(out):
if isinstance(out, torch.Tensor):
return (out.data_ptr(),)
# assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.data_ptr(), out))
out = _apply_out_transform(_case_one_transform, expected)
original_strides = _extract_strides(out)
original_ptrs = _extract_data_ptrs(out)
op_out(out=out)
final_strides = _extract_strides(out)
final_ptrs = _extract_data_ptrs(out)
self.assertEqual(expected, out)
self.assertEqual(original_strides, final_strides)
self.assertEqual(original_ptrs, final_ptrs)
# Case 2: out= with the correct dtype and device, but the wrong shape
# Expected behavior: resize with a warning.
def _case_two_transform(t):
wrong_shape = list(t.shape)
if len(wrong_shape) == 0:
# Handles scalar tensor case (empty list)
wrong_shape = [2]
else:
wrong_shape[-1] = wrong_shape[-1] + 1
return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)
out = _apply_out_transform(_case_two_transform, expected)
msg_fail = "Resized a non-empty tensor but did not warn about it."
with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
op_out(out=out)
self.assertEqual(expected, out)
# Case 3: out= with the correct dtype and device, but an empty
# tensor.
# Expected behavior: resize without warning.
def _case_three_transform(t):
return make_tensor((0,),
dtype=t.dtype,
device=t.device)
out = _apply_out_transform(_case_three_transform, expected)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
op_out(out=out)
# Verifies no warning is a resize warning
for w in caught:
if "An output with one or more elements" in str(w.message):
self.fail("Resizing an out= argument with no elements threw a resize warning!")
self.assertEqual(expected, out)
# Case 4: out= with correct shape and dtype, but wrong device.
wrong_device = None
if torch.device(device).type != 'cpu':
wrong_device = 'cpu'
elif torch.cuda.is_available():
wrong_device = 'cuda'
if wrong_device is not None:
def _case_four_transform(t):
return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
out = _apply_out_transform(_case_four_transform, expected)
msg_fail = f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}"
with self.assertRaises(RuntimeError, msg=msg_fail):
op_out(out=out)
# Case 5: out= with correct shape and device, but a dtype
# that output cannot be "safely" cast to (long).
# Expected behavior: error.
# NOTE: this case is filtered by dtype since some ops produce
# bool tensors, for example, which can be safely cast to any
# dtype. It is applied when single tensors are floating point or complex
# dtypes, or if an op returns multiple tensors when at least one such
# tensor is a floating point or complex dtype.
_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
(not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))):
def _case_five_transform(t):
return make_tensor(t.shape, dtype=torch.long, device=t.device)
out = _apply_out_transform(_case_five_transform, expected)
msg_fail = "" if not isinstance(expected, torch.Tensor) else \
("Expected RuntimeError when doing an unsafe cast from a result of dtype "
f"{expected.dtype} into an out= with dtype torch.long")
with self.assertRaises(RuntimeError, msg=msg_fail):
op_out(out=out)
instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())
instantiate_device_type_tests(TestCommon, globals())
if __name__ == '__main__':
run_tests()