mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53727 This is the first diff to add native support for segment reduction in PyTorch. It provides functionality similar to torch.scatter or "numpy.ufunc.reduceat". This diff mainly focuses on the API layer, to make sure future improvements will not cause backward-compatibility issues. Once the API is settled, here are the next steps I am planning: - Add support for other major reduction types (e.g. min, sum) for 1D tensors - Add CUDA support - Backward support - Documentation for the op - Perf optimizations and a benchmark util - Support for multi-dimensional tensors (on data and lengths) (not high priority) - Support for 'indices' (not high priority) Test Plan: Added unit test Reviewed By: ngimel Differential Revision: D26952075 fbshipit-source-id: 8040ec96def3013e7240cf675d499ee424437560
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import torch
|
|
from torch.testing._internal.common_device_type import (
|
|
instantiate_device_type_tests,
|
|
onlyCPU,
|
|
dtypes,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
run_tests,
|
|
)
|
|
|
|
|
|
class TestSegmentReductions(TestCase):
|
|
@onlyCPU
|
|
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
|
def test_max_simple_1d(self, device, dtype):
|
|
lengths = torch.tensor([1, 2, 3], device=device)
|
|
data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
|
|
expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
|
|
actual_result = torch.segment_reduce(
|
|
data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
|
|
)
|
|
self.assertEqual(
|
|
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
|
)
|
|
actual_result = torch.segment_reduce(
|
|
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
|
|
)
|
|
self.assertEqual(
|
|
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
|
)
|
|
|
|
|
|
# Build device-type-specific variants of the generic TestSegmentReductions
# class. Passing this module's globals() lets the helper place the generated
# classes into this module's namespace — presumably so the test runner can
# discover them; verify against torch.testing._internal.common_device_type.
instantiate_device_type_tests(TestSegmentReductions, globals())

# Run the tests only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|