Add OpInfo for torch.equal and fix support for non-standard bools
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79389
Approved by: https://github.com/mruberry
Parent: aa911efdeb
Commit: 9bf52f4be8
@@ -1987,7 +1987,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
       char* other_data = data[1];
       for (const auto i : c10::irange(dim_size)) {
         (void)i; //Suppress unused variable warning
-        if (*((scalar_t*)self_data) != *((scalar_t*)other_data)) {
+        if (c10::load<scalar_t>(self_data) != c10::load<scalar_t>(other_data)) {
           result = false;
           return;
         }
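The hunk above is the actual fix: the old code reinterpreted the raw bytes through a scalar_t pointer, so for torch.bool a stored byte other than 0 or 1 (a "non-standard" bool that is still logically true) could make two logically equal tensors compare unequal, while c10::load<scalar_t> normalizes the value as it is read. A minimal Python-side sketch of the symptom, assuming a uint8 view is one way such a byte can arise:

import torch

# A bool element whose backing byte is 2 rather than 1; logically it is True.
nonstandard = torch.tensor([2], dtype=torch.uint8).view(torch.bool)
standard = torch.tensor([True])  # backing byte is 1

# With the c10::load change, cpu_equal compares the normalized bool values
# (True == True) instead of the raw bytes (2 != 1), so this prints True.
print(torch.equal(nonstandard, standard))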
@@ -4946,6 +4946,9 @@ class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @ops(op_db, dtypes=OpDTypes.supported)
     def test_nvfuser_correctness(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("nvfuser requires tracing support")
+
         variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
 
         for variant, sample in variant_sample_pairs:
@@ -4972,6 +4975,9 @@ class TestCudaFuserOpInfo(TestCudaFuserOpInfoParent):
     @ops(op_db, allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32,
                                 torch.float64, torch.complex64, torch.complex128))
     def test_nvfuser_extremal_values(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("nvfuser requires tracing support")
+
         variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
 
         def _get_extremal_tensor(x, val, dtype):
@@ -2624,6 +2624,9 @@ def f({', '.join(param_names)}):
     @onlyCPU
     @ops(op_db, dtypes=OpDTypes.supported)
     def test_nnc_correctness(self, device, dtype, op):
+        if not op.supports_tracing:
+            self.skipTest("Requires tracing support")
+
         with NoTracerWarnContextManager() as no_warn:
             variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
 
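The three test hunks above add the same early-out: OpInfo-driven fuser tests now skip any op whose OpInfo sets supports_tracing=False, since these tests exercise traced variants via get_traced_sample_variant_pairs. A hedged sketch of that pattern in isolation; the class and method names below are placeholders, not the real test files:

import unittest
import torch
from torch.testing._internal.common_methods_invocations import op_db

class ExampleOpInfoTest(unittest.TestCase):
    def check_op(self, op, device="cpu", dtype=torch.float32):
        # Same gate as the hunks above: anything that relies on tracing bails out.
        if not op.supports_tracing:
            self.skipTest("requires tracing support")
        for sample in op.sample_inputs(device, dtype):
            op(sample.input, *sample.args, **sample.kwargs)

    def test_equal_is_skipped(self):
        equal_info = next(o for o in op_db if o.name == "equal")
        # 'equal' sets supports_tracing=False (see the op_db entry further down),
        # so check_op raises unittest.SkipTest for it.
        with self.assertRaises(unittest.SkipTest):
            self.check_op(equal_info)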
@@ -460,6 +460,7 @@ meta_function_skips = {
     torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
     torch.diff: {b8},
+    torch.equal: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
     torch.functional.cdist: {f32, f64},
     torch.nanmean: {bf16, f16, f32, f64},
     torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
@@ -611,6 +612,7 @@ meta_dispatch_expected_failures = {
     aten.convolution.default: {c64, i64, f64, c128, bf16, f32},
     aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
+    aten.equal.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
     aten.floor_divide.default: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
     aten.floor_divide.out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
     aten.frexp.Tensor: {bf16, f16, f64, f32},
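Both meta-test tables above gain an entry for torch.equal / aten.equal.default because the meta backend carries only shapes and dtypes, never values, and equality needs real data to compare. A hedged illustration; the exact exception type is an assumption:

import torch

a = torch.empty(2, 2, device="meta")
b = torch.empty(2, 2, device="meta")
try:
    torch.equal(a, b)  # nothing to compare on the meta device
except (NotImplementedError, RuntimeError) as exc:
    print(f"no usable meta kernel for equal: {type(exc).__name__}")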
@@ -107,7 +107,7 @@ class TestJit(JitCommonTestCase):
 
             # Check traced forward, grad, and grad grad
             # TODO: fix tracing here
-            supports_tracing = not has_fake_function
+            supports_tracing = op.supports_tracing and not has_fake_function
             if op.assert_jit_shape_analysis:
                 self.assertTrue(supports_tracing)
 
@@ -137,6 +137,7 @@ class TestProxyTensor(TestCase):
 
 make_fx_failures = {
     xfail('allclose'),
+    xfail('equal'),
     xfail('linalg.eigvals'),
     xfail('nn.functional.max_pool1d', device_type='cpu'),
     # empty
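xfail('equal') fits the same theme: torch.equal collapses tensor data into a plain Python bool, which proxy tracing cannot represent as a graph node. A hedged sketch; whether tracing errors out or silently bakes the bool in as a constant may depend on the tracing mode, so the failure path below is an assumption:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def uses_equal(a, b):
    return torch.equal(a, b)

try:
    gm = make_fx(uses_equal)(torch.randn(3), torch.randn(3))
    print(gm.graph)  # if tracing "succeeds", the result is a baked-in constant
except Exception as exc:
    print(f"tracing failed: {type(exc).__name__}")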
@@ -771,6 +771,9 @@ class OpInfo(object):
     # only run tracing tests
     supports_scripting: bool = True
 
+    # if the operator can be traced
+    supports_tracing: bool = True
+
     # the following metadata relates to sparse csr support and is used in test_sparse_csr.py
 
     # whether the op supports sparse csr inputs
@@ -2408,6 +2411,29 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
         lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input
     )
 
+def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((), (S,)),
+        ((S, 1), (S,)),
+        ((M, S), ()),
+        ((S, S), (S, S))
+    )
+
+    for shape_lhs, shape_rhs in shapes:
+        lhs = make_arg(shape_lhs)
+        rhs = make_arg(shape_rhs)
+        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)
+
+        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
+        if shape_lhs == shape_rhs:
+            yield SampleInput(lhs, args=(lhs.clone().detach_(),))
+
+
 def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
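sample_inputs_equal yields mixed-shape pairs, for which torch.equal is simply False, plus same-shape pairs including an identical clone for which it is True. A hedged sketch of consuming those samples through the standard OpInfo API; the consistency check against the entry's ref only makes sense for same-shape pairs because torch.equal never broadcasts:

import torch
from torch.testing._internal.common_methods_invocations import op_db

equal_info = next(o for o in op_db if o.name == "equal")
for sample in equal_info.sample_inputs("cpu", torch.float32):
    lhs, rhs = sample.input, sample.args[0]
    result = torch.equal(lhs, rhs)
    if lhs.shape == rhs.shape:
        assert result == bool((lhs == rhs).all())
    else:
        assert result is False  # differing shapes are never equal
    print(lhs.shape, rhs.shape, result)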
@@ -11533,6 +11559,14 @@ op_db: List[OpInfo] = [
            supports_fwgrad_bwgrad=True,
            supports_two_python_scalars=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    OpInfo('equal',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           ref=lambda input, other: (input == other).all(),
+           sample_inputs_func=sample_inputs_equal,
+           supports_autograd=False,
+           supports_tracing=False,
+           skips=(
+           )),
     UnaryUfuncInfo('exp',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
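The new OpInfo entry wires the sampler in, disables autograd and tracing, and uses a torch-only ref. A quick hedged sanity check that just reads the registered entry rather than reproducing any particular test:

import torch
from torch.testing._internal.common_methods_invocations import op_db

equal_info = next(o for o in op_db if o.name == "equal")
assert not equal_info.supports_autograd
assert not equal_info.supports_tracing

a = torch.randn(3, 4)
b = a.clone()
# For same-shape inputs the ref lambda agrees with the eager operator.
assert bool(equal_info.ref(a, b)) == torch.equal(a, b)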