Pass rounding_mode for div reference inputs through kwargs (#136308)

Previously, the reference inputs for div with a rounding mode did not supply the rounding_mode keyword argument, so they did not match the sample inputs for this op.
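As a minimal illustration of the mismatch (the tensor values here are arbitrary, chosen only for this example), the sample inputs exercised the rounding step while the reference inputs effectively compared plain true division:

    import torch

    lhs = torch.tensor([7.0, -7.0])
    rhs = torch.tensor([2.0, 2.0])

    # What the sample inputs exercise: division plus truncation toward zero.
    print(torch.div(lhs, rhs, rounding_mode="trunc"))  # tensor([ 3., -3.])

    # What the reference inputs effectively exercised before this change,
    # since the rounding_mode kwarg was dropped: plain true division.
    print(torch.div(lhs, rhs))  # tensor([ 3.5000, -3.5000])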

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136308
Approved by: https://github.com/albanD

Co-authored-by: Xia, Weiwen <weiwen.xia@intel.com>
Co-authored-by: Bob Ren <bobren@meta.com>
Co-authored-by: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Co-authored-by: siahuat0727 <tansiahuat@gmail.com>
Authored by: George Wigley, 2024-11-29 21:28:22 +00:00
Committed by: PyTorch MergeBot
Parent: ed092e2161
Commit: 44707b0667
3 changed files with 85 additions and 25 deletions


@@ -634,6 +634,11 @@ class TestCommon(TestCase):
         # Direct calls to refs and prims are not translated
         if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16:
             self.skipTest("Skipped on ROCm")
+        if op.full_name == "_refs.div.floor_rounding" and dtype == torch.bfloat16:
+            self.skipTest(
+                "Skipped _refs.div.floor_rounding with bfloat16. "
+                "Divide by 0: _refs produces NaN, torch produces +/-inf"
+            )
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
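
The new skip guards a divide-by-zero discrepancy between the _refs decomposition and eager mode. A hedged sketch of the eager side only (the values are arbitrary; the NaN behavior of _refs is taken from the skip message above, not reproduced here):

    import torch

    lhs = torch.tensor([1.0, -1.0], dtype=torch.bfloat16)
    rhs = torch.zeros(2, dtype=torch.bfloat16)

    # Eager floor division by zero keeps the numerator's sign: +/-inf.
    print(torch.div(lhs, rhs, rounding_mode="floor"))  # tensor([inf, -inf], ...)

    # Per the skip message, the _refs decomposition instead yields NaN for
    # these inputs in bfloat16, so an elementwise comparison cannot succeed.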


@@ -13120,7 +13120,8 @@ op_db: List[OpInfo] = [
                     variant_test_name='trunc_rounding',
                     dtypes=all_types_and(torch.half, torch.bfloat16),
                     dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
-                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}),
                     # https://github.com/pytorch/pytorch/issues/80411
                     gradcheck_fast_mode=True,
                     supports_forward_ad=True,
@@ -13149,7 +13150,8 @@ op_db: List[OpInfo] = [
                     variant_test_name='floor_rounding',
                     dtypes=all_types_and(torch.half, torch.bfloat16),
                     dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
-                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}),
                     # https://github.com/pytorch/pytorch/issues/80411
                     gradcheck_fast_mode=True,
                     supports_forward_ad=True,
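
Note the shape of the replacement in both hunks above: rather than baking rounding_mode into sample_inputs_func through partial, each entry now supplies a sample_kwargs callable returning a pair of kwarg dicts. Judging from the generator changes below, index [0] is passed to the op under test; the second element presumably follows the existing OpInfo sample_kwargs convention (kwargs for a reference implementation), though only index [0] is consumed in the hunks shown here. A hypothetical spelled-out equivalent of the lambda:

    # Hypothetical named version of the lambda added above. The tuple holds
    # (kwargs for the op under test, kwargs for a reference function).
    def trunc_sample_kwargs(device, dtype, input):
        return {"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}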


@@ -2008,7 +2008,9 @@ def generate_elementwise_binary_tensors(
     for shape in shapes:
         lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
-        yield SampleInput(lhs, args=(rhs,))
+        yield SampleInput(
+            lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+        )


 def generate_elementwise_binary_arbitrarily_strided_tensors(
@@ -2036,7 +2038,7 @@ def generate_elementwise_binary_arbitrarily_strided_tensors(
             500,
         ).as_strided(shape, strides, offset)
         b = make_arg(shape)
-        yield SampleInput(a, args=(b,))
+        yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0])


 # Returns a generator of pairs of contiguous tensors on the requested device and with
@@ -2100,7 +2102,7 @@ def generate_elementwise_binary_small_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 def generate_elementwise_binary_large_value_tensors(
@@ -2135,7 +2137,7 @@ def generate_elementwise_binary_large_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 def generate_elementwise_binary_extremal_value_tensors(
@@ -2164,7 +2166,7 @@ def generate_elementwise_binary_extremal_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])

     # Test case for NaN propagation
     nan = (
@@ -2179,7 +2181,7 @@ def generate_elementwise_binary_extremal_value_tensors(
     )
     rhs.view(-1)[::3] = nan

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 # Returns a generator of pairs of contiguous and noncontiguous tensors that
@@ -2217,7 +2219,12 @@ def generate_elementwise_binary_broadcasting_tensors(
             shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs
         )

-        yield SampleInput(lhs, args=(rhs,), broadcasts_input=True)
+        yield SampleInput(
+            lhs,
+            args=(rhs,),
+            broadcasts_input=True,
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )


 # Returns a generator of pairs of contiguous tensors and scalars
@@ -2236,17 +2243,27 @@ def generate_elementwise_binary_with_scalar_samples(
             lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
             rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

-            yield SampleInput(lhs, args=(rhs_scalar,))
+            yield SampleInput(
+                lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+            )

             # Extends with scalar lhs
             if op.supports_one_python_scalar:
-                yield SampleInput(lhs_scalar, args=(rhs,))
+                yield SampleInput(
+                    lhs_scalar,
+                    args=(rhs,),
+                    kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+                )

     if op.supports_two_python_scalars:
         lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
         rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

-        yield SampleInput(lhs_scalar, args=(rhs_scalar,))
+        yield SampleInput(
+            lhs_scalar,
+            args=(rhs_scalar,),
+            kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+        )


 # Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion
@@ -2277,10 +2294,16 @@ def generate_elementwise_binary_with_scalar_and_type_promotion_samples(
         lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
         for scalar in values + scalar_tensors:
-            yield SampleInput(lhs, args=(scalar,))
+            yield SampleInput(
+                lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+            )

             # Extends with scalar lhs
             if op.supports_one_python_scalar:
-                yield SampleInput(scalar, args=(rhs,))
+                yield SampleInput(
+                    scalar,
+                    args=(rhs,),
+                    kwargs=op.sample_kwargs(device, dtype, scalar)[0],
+                )


 # Returns a generator of pairs of noncontiguous tensors
@@ -2299,14 +2322,20 @@ def generate_elementwise_binary_noncontiguous_tensors(
     lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs)
     rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs)

-    yield SampleInput(lhs.clone(), args=(rhs.clone(),))
-    yield SampleInput(lhs.contiguous(), args=(rhs,))
+    yield SampleInput(
+        lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )
+    yield SampleInput(
+        lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )

     # Transposed
     lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs)
     rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs)

-    yield SampleInput(lhs.T, args=(rhs.T,))
+    yield SampleInput(
+        lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )

     # More noncontiguity
     shapes = ((5, 7), (1024,))
@@ -2321,8 +2350,16 @@ def generate_elementwise_binary_noncontiguous_tensors(
         rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
         rhs_non_contig.copy_(rhs)

-        yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),))
-        yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,))
+        yield SampleInput(
+            lhs_non_contig.clone(),
+            args=(rhs_non_contig.clone(),),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+        yield SampleInput(
+            lhs_non_contig.contiguous(),
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )

     # Noncontiguous indices
     shape = (2, 2, 1, 2)
@@ -2332,8 +2369,16 @@ def generate_elementwise_binary_noncontiguous_tensors(
     lhs_non_contig = lhs[:, 1, ...]
     rhs_non_contig = rhs[:, 1, ...]

-    yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),))
-    yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,))
+    yield SampleInput(
+        lhs_non_contig.clone(),
+        args=(rhs_non_contig.clone(),),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )
+    yield SampleInput(
+        lhs_non_contig.contiguous(),
+        args=(rhs_non_contig,),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )

     # Expanded tensors
     shapes = ((1, 3), (1, 7), (5, 7))
@@ -2345,7 +2390,11 @@ def generate_elementwise_binary_noncontiguous_tensors(
         lhs_non_contig = lhs.expand(3, -1, -1)
         rhs_non_contig = rhs.expand(3, -1, -1)

-        yield SampleInput(lhs_non_contig, args=(rhs_non_contig,))
+        yield SampleInput(
+            lhs_non_contig,
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )


 # Sample inputs for elementwise binary operators, like add
@@ -2376,15 +2425,16 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
         ((0, 1, XS), (0, _M, XS)),
     )

-    sample_kwargs = kwargs.get("sample_kwargs", {})
-
     for shape_lhs, shape_rhs in shapes:
         lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs)
         broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

         yield SampleInput(
-            lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input
+            lhs,
+            args=(rhs,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+            broadcasts_input=broadcasts_input,
         )
@@ -2409,6 +2459,7 @@ class BinaryUfuncInfo(OpInfo):
         *,
         sample_inputs_func=sample_inputs_elementwise_binary,
         reference_inputs_func=reference_inputs_elementwise_binary,
+        sample_kwargs=lambda device, dtype, input: ({}, {}),
         error_inputs_func=None,
         lhs_make_tensor_kwargs=None,
         rhs_make_tensor_kwargs=None,
@@ -2439,6 +2490,8 @@ class BinaryUfuncInfo(OpInfo):
             **kwargs,
         )

+        self.sample_kwargs = sample_kwargs
+
         # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on.
         if lhs_make_tensor_kwargs is None:
             lhs_make_tensor_kwargs = {}
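
Since the new parameter defaults to lambda device, dtype, input: ({}, {}), every BinaryUfuncInfo entry that does not override the hook keeps yielding empty kwargs and is unaffected; only the div rounding variants inject rounding_mode. A minimal, self-contained sketch of how the generators above unpack the hook (the hook names here are illustrative, not from the patch):

    # Illustrative only: mirrors how the generators call op.sample_kwargs.
    default_hook = lambda device, dtype, input: ({}, {})
    floor_hook = lambda device, dtype, input: (
        {"rounding_mode": "floor"},
        {"rounding_mode": "floor"},
    )

    for hook in (default_hook, floor_hook):
        op_kwargs, ref_kwargs = hook("cpu", None, None)
        print(op_kwargs)  # {} for the default; {'rounding_mode': 'floor'} for div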