mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This fixes BC-breaking changes introduced by https://github.com/pytorch/pytorch/pull/91499
Make enum accept both `min` and `amin` values
Reinstante testing
To reiterate
454361435c/torch/masked/_ops.py (L786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93091
Approved by: https://github.com/ngimel
This commit is contained in:
parent
7fade4f771
commit
661800a2cf
|
|
@ -7,11 +7,11 @@ namespace at { namespace native {
|
|||
enum ReductionType {MAX, MEAN, MIN, SUM, PROD};
|
||||
|
||||
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
|
||||
if (reduce == "amax") {
|
||||
if (reduce == "max" || reduce == "amax") {
|
||||
return ReductionType::MAX;
|
||||
} else if (reduce == "mean") {
|
||||
return ReductionType::MEAN;
|
||||
} else if (reduce == "amin") {
|
||||
} else if (reduce == "min" || reduce == "amin") {
|
||||
return ReductionType::MIN;
|
||||
} else if (reduce == "sum") {
|
||||
return ReductionType::SUM;
|
||||
|
|
|
|||
|
|
@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
|
|||
)
|
||||
|
||||
|
||||
reductions = ["amax", "mean", "amin", "sum", "prod"]
|
||||
reductions = ["max", "mean", "min", "sum", "prod"]
|
||||
|
||||
|
||||
def get_default_value(initial_value, reduction):
|
||||
if initial_value is not None:
|
||||
return initial_value
|
||||
if reduction == "amax":
|
||||
if reduction == "max":
|
||||
return -float("Inf")
|
||||
elif reduction == "mean":
|
||||
return float("nan")
|
||||
elif reduction == "amin":
|
||||
elif reduction == "min":
|
||||
return float("Inf")
|
||||
elif reduction == "sum":
|
||||
return 0.0
|
||||
|
|
@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
|
|||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "amax":
|
||||
if reduction == "max":
|
||||
expected_result = [1, float("nan"), 5, default_value]
|
||||
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
|
||||
elif reduction == "mean":
|
||||
expected_result = [1, float("nan"), 4.666, default_value]
|
||||
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
|
||||
elif reduction == "amin":
|
||||
elif reduction == "min":
|
||||
if initial is not None:
|
||||
initial_value = 1000 # some high number
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
|
|
|
|||
|
|
@ -6260,7 +6260,7 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode=
|
|||
((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
|
||||
)
|
||||
|
||||
reductions = ["amax", "mean", "amin", "sum", "prod"]
|
||||
reductions = ["max", "mean", "min", "sum", "prod"]
|
||||
for args, reduce, initial in product(test_cases, reductions, [1, 2]):
|
||||
inp_shape, dim, lengths, unsafe = args
|
||||
lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user