Pass rounding_mode for div reference inputs through kwargs (#136308)
Previously, the reference inputs for div with a rounding mode did not supply the rounding_mode keyword argument, which did not match the sample inputs for this op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136308
Approved by: https://github.com/albanD

Co-authored-by: Xia, Weiwen <weiwen.xia@intel.com>
Co-authored-by: Bob Ren <bobren@meta.com>
Co-authored-by: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Co-authored-by: siahuat0727 <tansiahuat@gmail.com>
This commit is contained in:
parent ed092e2161
commit 44707b0667
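Why the kwarg matters: torch.div with a rounding_mode is a different operation from true division, so reference inputs that drop the keyword exercise the wrong variant. A minimal illustration (not part of the PR):

    import torch

    a = torch.tensor([7.0, -7.0])
    b = torch.tensor([2.0, 2.0])

    print(torch.div(a, b))                         # true division: tensor([ 3.5000, -3.5000])
    print(torch.div(a, b, rounding_mode="trunc"))  # round toward zero: tensor([ 3., -3.])
    print(torch.div(a, b, rounding_mode="floor"))  # round toward -inf:  tensor([ 3., -4.])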
@@ -634,6 +634,11 @@ class TestCommon(TestCase):
         # Direct calls to refs and prims are not translated
         if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16:
             self.skipTest("Skipped on ROCm")
+        if op.full_name == "_refs.div.floor_rounding" and dtype == torch.bfloat16:
+            self.skipTest(
+                "Skipped _refs.div.floor_rounding with bfloat16"
+                "Divide by 0: _refs produces NaN, torch produces +/-inf"
+            )
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")

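The new skip records a known divergence rather than something introduced by this change. Below is a hedged repro sketch of what the message describes; whether _refs actually returns NaN for these inputs depends on how the reference decomposition handles bfloat16, so the expected outputs are illustrative only:

    import torch
    import torch._refs as refs

    num = torch.tensor([1.0, -1.0, 0.0], dtype=torch.bfloat16)
    den = torch.zeros_like(num)

    # Eager floor division by zero keeps IEEE semantics: inf, -inf, nan.
    print(torch.div(num, den, rounding_mode="floor"))
    # Per the skip message, the reference path may yield nan where eager yields +/-inf.
    print(refs.div(num, den, rounding_mode="floor"))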
@@ -13120,7 +13120,8 @@ op_db: List[OpInfo] = [
                     variant_test_name='trunc_rounding',
                     dtypes=all_types_and(torch.half, torch.bfloat16),
                     dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
-                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}),
                     # https://github.com/pytorch/pytorch/issues/80411
                     gradcheck_fast_mode=True,
                     supports_forward_ad=True,

@@ -13149,7 +13150,8 @@ op_db: List[OpInfo] = [
                     variant_test_name='floor_rounding',
                     dtypes=all_types_and(torch.half, torch.bfloat16),
                     dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
-                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}),
                     # https://github.com/pytorch/pytorch/issues/80411
                     gradcheck_fast_mode=True,
                     supports_forward_ad=True,

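Both div entries now express the rounding mode through the new sample_kwargs hook instead of a functools.partial over sample_inputs_elementwise_binary. The lambda returns a pair of dicts; by the existing OpInfo sample_kwargs convention (an assumption here, not shown in this diff) the first is attached to the torch samples and the second is meant for the reference comparison. A standalone sketch of what the hook produces:

    # Hypothetical direct use of the hook; device, dtype and input are unused by these lambdas.
    sample_kwargs = lambda device, dtype, input: (
        {"rounding_mode": "floor"},  # kwargs attached to the torch.div SampleInputs
        {"rounding_mode": "floor"},  # kwargs intended for the reference path
    )

    op_kwargs, ref_kwargs = sample_kwargs("cpu", None, None)
    assert op_kwargs == {"rounding_mode": "floor"}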
@@ -2008,7 +2008,9 @@ def generate_elementwise_binary_tensors(
     for shape in shapes:
         lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
-        yield SampleInput(lhs, args=(rhs,))
+        yield SampleInput(
+            lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+        )


 def generate_elementwise_binary_arbitrarily_strided_tensors(

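The remaining generator hunks repeat this one change: every yielded SampleInput now carries kwargs=op.sample_kwargs(device, dtype, lhs)[0], so the reference inputs built from these generators pick up rounding_mode (and stay empty for ops that define no extra kwargs). A reduced, self-contained sketch of the pattern; SimpleOp and SampleInput below are stand-ins, not the real OpInfo machinery:

    class SampleInput:
        """Stand-in for the real SampleInput container."""
        def __init__(self, input, args=(), kwargs=None):
            self.input, self.args, self.kwargs = input, args, dict(kwargs or {})

    class SimpleOp:
        """Stand-in for BinaryUfuncInfo: holds the per-op sample_kwargs hook."""
        def __init__(self, sample_kwargs=lambda device, dtype, input: ({}, {})):
            self.sample_kwargs = sample_kwargs

    def generate(op, device, dtype, lhs, rhs):
        # Same shape as the updated generators: the first dict feeds the sample.
        yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])

    div_like = SimpleOp(
        sample_kwargs=lambda device, dtype, input: (
            {"rounding_mode": "floor"}, {"rounding_mode": "floor"}
        )
    )
    sample = next(generate(div_like, "cpu", None, 7.0, 2.0))
    assert sample.kwargs == {"rounding_mode": "floor"}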
@@ -2036,7 +2038,7 @@ def generate_elementwise_binary_arbitrarily_strided_tensors(
             500,
         ).as_strided(shape, strides, offset)
         b = make_arg(shape)
-        yield SampleInput(a, args=(b,))
+        yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0])


 # Returns a generator of pairs of contiguous tensors on the requested device and with

@@ -2100,7 +2102,7 @@ def generate_elementwise_binary_small_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 def generate_elementwise_binary_large_value_tensors(

@@ -2135,7 +2137,7 @@ def generate_elementwise_binary_large_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 def generate_elementwise_binary_extremal_value_tensors(

@@ -2164,7 +2166,7 @@ def generate_elementwise_binary_extremal_value_tensors(
     lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
     rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])

     # Test case for NaN propagation
     nan = (

@@ -2179,7 +2181,7 @@ def generate_elementwise_binary_extremal_value_tensors(
     )
     rhs.view(-1)[::3] = nan

-    yield SampleInput(lhs, args=(rhs,))
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])


 # Returns a generator of pairs of contiguous and noncontiguous tensors that

@@ -2217,7 +2219,12 @@ def generate_elementwise_binary_broadcasting_tensors(
             shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs
         )

-        yield SampleInput(lhs, args=(rhs,), broadcasts_input=True)
+        yield SampleInput(
+            lhs,
+            args=(rhs,),
+            broadcasts_input=True,
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )


 # Returns a generator of pairs of contiguous tensors and scalars

@@ -2236,17 +2243,27 @@ def generate_elementwise_binary_with_scalar_samples(
         lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
         rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

-        yield SampleInput(lhs, args=(rhs_scalar,))
+        yield SampleInput(
+            lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+        )

         # Extends with scalar lhs
         if op.supports_one_python_scalar:
-            yield SampleInput(lhs_scalar, args=(rhs,))
+            yield SampleInput(
+                lhs_scalar,
+                args=(rhs,),
+                kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+            )

     if op.supports_two_python_scalars:
         lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
         rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

-        yield SampleInput(lhs_scalar, args=(rhs_scalar,))
+        yield SampleInput(
+            lhs_scalar,
+            args=(rhs_scalar,),
+            kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+        )


 # Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion

@@ -2277,10 +2294,16 @@ def generate_elementwise_binary_with_scalar_and_type_promotion_samples(
         lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
         for scalar in values + scalar_tensors:
-            yield SampleInput(lhs, args=(scalar,))
+            yield SampleInput(
+                lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+            )
             # Extends with scalar lhs
             if op.supports_one_python_scalar:
-                yield SampleInput(scalar, args=(rhs,))
+                yield SampleInput(
+                    scalar,
+                    args=(rhs,),
+                    kwargs=op.sample_kwargs(device, dtype, scalar)[0],
+                )


 # Returns a generator of pairs of noncontiguous tensors

@@ -2299,14 +2322,20 @@ def generate_elementwise_binary_noncontiguous_tensors(
     lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs)
     rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs)

-    yield SampleInput(lhs.clone(), args=(rhs.clone(),))
-    yield SampleInput(lhs.contiguous(), args=(rhs,))
+    yield SampleInput(
+        lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )
+    yield SampleInput(
+        lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )

     # Transposed
     lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs)
     rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs)

-    yield SampleInput(lhs.T, args=(rhs.T,))
+    yield SampleInput(
+        lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )

     # More noncontiguity
     shapes = ((5, 7), (1024,))

@@ -2321,8 +2350,16 @@ def generate_elementwise_binary_noncontiguous_tensors(
         rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
         rhs_non_contig.copy_(rhs)

-        yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),))
-        yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,))
+        yield SampleInput(
+            lhs_non_contig.clone(),
+            args=(rhs_non_contig.clone(),),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+        yield SampleInput(
+            lhs_non_contig.contiguous(),
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )

     # Noncontiguous indices
     shape = (2, 2, 1, 2)

@@ -2332,8 +2369,16 @@ def generate_elementwise_binary_noncontiguous_tensors(
     lhs_non_contig = lhs[:, 1, ...]
     rhs_non_contig = rhs[:, 1, ...]

-    yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),))
-    yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,))
+    yield SampleInput(
+        lhs_non_contig.clone(),
+        args=(rhs_non_contig.clone(),),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )
+    yield SampleInput(
+        lhs_non_contig.contiguous(),
+        args=(rhs_non_contig,),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )

     # Expanded tensors
     shapes = ((1, 3), (1, 7), (5, 7))

@@ -2345,7 +2390,11 @@ def generate_elementwise_binary_noncontiguous_tensors(
         lhs_non_contig = lhs.expand(3, -1, -1)
         rhs_non_contig = rhs.expand(3, -1, -1)

-        yield SampleInput(lhs_non_contig, args=(rhs_non_contig,))
+        yield SampleInput(
+            lhs_non_contig,
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )


 # Sample inputs for elementwise binary operators, like add

@@ -2376,15 +2425,16 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
         ((0, 1, XS), (0, _M, XS)),
     )

-    sample_kwargs = kwargs.get("sample_kwargs", {})
-
     for shape_lhs, shape_rhs in shapes:
         lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs)
         rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs)
         broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

         yield SampleInput(
-            lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input
+            lhs,
+            args=(rhs,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+            broadcasts_input=broadcasts_input,
         )


@@ -2409,6 +2459,7 @@ class BinaryUfuncInfo(OpInfo):
         *,
         sample_inputs_func=sample_inputs_elementwise_binary,
         reference_inputs_func=reference_inputs_elementwise_binary,
+        sample_kwargs=lambda device, dtype, input: ({}, {}),
         error_inputs_func=None,
         lhs_make_tensor_kwargs=None,
         rhs_make_tensor_kwargs=None,

@@ -2439,6 +2490,8 @@ class BinaryUfuncInfo(OpInfo):
             **kwargs,
         )

+        self.sample_kwargs = sample_kwargs
+
         # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on.
         if lhs_make_tensor_kwargs is None:
             lhs_make_tensor_kwargs = {}
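The ({}, {}) default keeps ops that never override sample_kwargs unchanged: the new [0] indexing in the generators resolves to an empty dict, so their SampleInputs carry no extra kwargs, exactly as before. A quick check of that default (standalone, not the real class):

    default_sample_kwargs = lambda device, dtype, input: ({}, {})

    op_kwargs, ref_kwargs = default_sample_kwargs("cuda", None, None)
    assert op_kwargs == {} and ref_kwargs == {}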