Fix BC-breaking change introduced by #91499 (#93091)

This fixes BC-breaking changes introduced by https://github.com/pytorch/pytorch/pull/91499
Make enum accept both `min` and `amin` values
Reinstante testing

To reiterate
454361435c/torch/masked/_ops.py (L786)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93091
Approved by: https://github.com/ngimel
This commit is contained in:
Nikita Shulga 2023-01-27 03:58:32 +00:00 committed by PyTorch MergeBot
parent 7fade4f771
commit 661800a2cf
3 changed files with 8 additions and 8 deletions

View File

@ -7,11 +7,11 @@ namespace at { namespace native {
enum ReductionType {MAX, MEAN, MIN, SUM, PROD};
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "amax") {
if (reduce == "max" || reduce == "amax") {
return ReductionType::MAX;
} else if (reduce == "mean") {
return ReductionType::MEAN;
} else if (reduce == "amin") {
} else if (reduce == "min" || reduce == "amin") {
return ReductionType::MIN;
} else if (reduce == "sum") {
return ReductionType::SUM;

View File

@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
)
reductions = ["amax", "mean", "amin", "sum", "prod"]
reductions = ["max", "mean", "min", "sum", "prod"]
def get_default_value(initial_value, reduction):
if initial_value is not None:
return initial_value
if reduction == "amax":
if reduction == "max":
return -float("Inf")
elif reduction == "mean":
return float("nan")
elif reduction == "amin":
elif reduction == "min":
return float("Inf")
elif reduction == "sum":
return 0.0
@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "amax":
if reduction == "max":
expected_result = [1, float("nan"), 5, default_value]
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
elif reduction == "mean":
expected_result = [1, float("nan"), 4.666, default_value]
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
elif reduction == "amin":
elif reduction == "min":
if initial is not None:
initial_value = 1000 # some high number
default_value = get_default_value(initial_value, reduction)

View File

@ -6260,7 +6260,7 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode=
((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
)
reductions = ["amax", "mean", "amin", "sum", "prod"]
reductions = ["max", "mean", "min", "sum", "prod"]
for args, reduce, initial in product(test_cases, reductions, [1, 2]):
inp_shape, dim, lengths, unsafe = args
lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)