Add OpInfo for torch.equal and fix support for non-standard bools

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79389

Approved by: https://github.com/mruberry
Authored by Peter Bell on 2022-06-20 19:58:41 +01:00; committed by PyTorch MergeBot
parent aa911efdeb
commit 9bf52f4be8
7 changed files with 48 additions and 2 deletions

View File

@@ -1987,7 +1987,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
       char* other_data = data[1];
       for (const auto i : c10::irange(dim_size)) {
         (void)i; //Suppress unused variable warning
-        if (*((scalar_t*)self_data) != *((scalar_t*)other_data)) {
+        if (c10::load<scalar_t>(self_data) != c10::load<scalar_t>(other_data)) {
           result = false;
           return;
         }
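The "non-standard bool" this fixes is a bool element whose underlying byte is neither 0 nor 1, where a raw byte-wise compare can disagree with logical equality. A minimal Python-level sketch of the idea (assuming uint8-to-bool reinterpretation via Tensor.view as one way such values arise; not taken from the PR's tests):

import torch

# Two logically-true bool tensors backed by different byte values.
a = torch.tensor([1], dtype=torch.uint8).view(torch.bool)
b = torch.tensor([2], dtype=torch.uint8).view(torch.bool)  # non-standard: byte value 2

# Both elements should read as logically true, but comparing raw bytes would
# say the tensors differ. Loading each element through c10::load<scalar_t>
# normalizes the value first, so torch.equal should report True here.
print(bool(a), bool(b))
print(torch.equal(a, b))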

View File

@@ -4946,6 +4946,9 @@ class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @ops(op_db, dtypes=OpDTypes.supported)
     def test_nvfuser_correctness(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("nvfuser requires tracing support")
+
         variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)

         for variant, sample in variant_sample_pairs:
@@ -4972,6 +4975,9 @@ class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
     @ops(op_db, allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32,
                                 torch.float64, torch.complex64, torch.complex128))
     def test_nvfuser_extremal_values(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("nvfuser requires tracing support")
+
         variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)

         def _get_extremal_tensor(x, val, dtype):

View File

@@ -2624,6 +2624,9 @@ def f({', '.join(param_names)}):
     @onlyCPU
     @ops(op_db, dtypes=OpDTypes.supported)
     def test_nnc_correctness(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("Requires tracing support")
+
         with NoTracerWarnContextManager() as no_warn:
             variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)

View File

@@ -460,6 +460,7 @@ meta_function_skips = {
     torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.diff: {b8},
+    torch.equal: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
     torch.functional.cdist: {f32, f64},
     torch.nanmean: {bf16, f16, f32, f64},
     torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
@@ -611,6 +612,7 @@ meta_dispatch_expected_failures = {
     aten.convolution.default: {c64, i64, f64, c128, bf16, f32},
     aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
+    aten.equal.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.floor_divide.default: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
     aten.floor_divide.out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
     aten.frexp.Tensor: {bf16, f16, f64, f32},
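These entries are expected: equal is value-dependent and returns a plain Python bool, which cannot be computed from meta tensors that carry only shape and dtype. A minimal sketch of the failure mode (exact error type and message may vary by version):

import torch

a = torch.empty(2, 3, device="meta")
b = torch.empty(2, 3, device="meta")

# Meta tensors have no data, so a data-dependent result like torch.equal
# cannot be produced; hence the skip / expected-failure entries above.
try:
    torch.equal(a, b)
except Exception as e:
    print(type(e).__name__, e)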

View File

@@ -107,7 +107,7 @@ class TestJit(JitCommonTestCase):
         # Check traced forward, grad, and grad grad
         # TODO: fix tracing here
-        supports_tracing = not has_fake_function
+        supports_tracing = op.supports_tracing and not has_fake_function
         if op.assert_jit_shape_analysis:
             self.assertTrue(supports_tracing)

View File

@@ -137,6 +137,7 @@ class TestProxyTensor(TestCase):
 make_fx_failures = {
     xfail('allclose'),
+    xfail('equal'),
     xfail('linalg.eigvals'),
     xfail('nn.functional.max_pool1d', device_type='cpu'),
     # empty
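The xfail mirrors the meta-tensor case: make_fx traces through proxy tensors, and an op whose result is a data-dependent Python bool cannot be captured in the traced graph. A rough sketch of what the xfail covers (assuming the make_fx entry point in torch.fx.experimental.proxy_tensor; failure details may differ across versions):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    # torch.equal collapses tensor data to a Python bool, which proxy
    # tracing cannot represent symbolically.
    return torch.equal(x, y)

try:
    make_fx(f)(torch.randn(3), torch.randn(3))
except Exception as e:
    print(type(e).__name__)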

View File

@@ -771,6 +771,9 @@ class OpInfo(object):
     # only run tracing tests
     supports_scripting: bool = True

+    # if the operator can be traced
+    supports_tracing: bool = True
+
     # the following metadata relates to sparse csr support and is used in test_sparse_csr.py
     # whether the op supports sparse csr inputs
@@ -2408,6 +2411,29 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
             lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input
         )


+def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((), (S,)),
+        ((S, 1), (S,)),
+        ((M, S), ()),
+        ((S, S), (S, S))
+    )
+
+    for shape_lhs, shape_rhs in shapes:
+        lhs = make_arg(shape_lhs)
+        rhs = make_arg(shape_rhs)
+        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)
+
+        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
+        if shape_lhs == shape_rhs:
+            yield SampleInput(lhs, args=(lhs.clone().detach_(),))
+
+
 def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
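A quick illustration of the broadcasts_input flag computed above, with concrete sizes substituted for S (illustrative values only):

import torch

# lhs (5, 1) vs rhs (5,): the broadcast result shape differs from the lhs
# shape, so that sample is marked broadcasts_input=True.
print(torch.broadcast_shapes((5, 1), (5,)))    # torch.Size([5, 5])

# lhs (5, 5) vs rhs (5, 5): result shape equals the lhs shape, so
# broadcasts_input is False and a same-shape clone sample is also yielded.
print(torch.broadcast_shapes((5, 5), (5, 5)))  # torch.Size([5, 5])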
@@ -11533,6 +11559,14 @@ op_db: List[OpInfo] = [
            supports_fwgrad_bwgrad=True,
            supports_two_python_scalars=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    OpInfo('equal',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           ref=lambda input, other: (input == other).all(),
+           sample_inputs_func=sample_inputs_equal,
+           supports_autograd=False,
+           supports_tracing=False,
+           skips=(
+           )),
     UnaryUfuncInfo('exp',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
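For reference, the semantics the new OpInfo exercises: torch.equal returns a plain Python bool and requires matching sizes as well as matching values, while the ref above is a simple elementwise comparison. A short usage sketch:

import torch

a = torch.arange(6.).reshape(2, 3)
b = a.clone()

print(torch.equal(a, b))      # True: same size and same values
print(torch.equal(a, b + 1))  # False: values differ
print(torch.equal(a, a[:1]))  # False: sizes differ
print((a == b).all())         # tensor(True): the elementwise reference check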