Summary: Modify the existing `sum` operator in PyTorch, invoked by `torch.sum`, to allow reductions along the ragged dimension of a nested tensor. This diff enables PyTorch users to invoke `torch.sum` on a nested tensor with `dim=1`, where `ragged_idx=1`.

Functions modified in `caffe2/torch/nested/_internal/ops.py`:

- `sum_dim_IntList()`: The function assumes that `ragged_idx=1`; in the case that `dim=1` as well, where `dim` is the dimension along which we reduce, this diff applies the approach from the PyTorch benchmark found in D58423489. Specifically, this diff pads a nested tensor, e.g. of logical shape `(B, *, M)`, using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), then reduces across the `*` dimension (`dim == 1`) to produce a `(B, M)` output tensor (see the sketches at the end of this description).
- `_wrap_jagged_dims()`: This diff adds special handling to allow the case where `dim` contains `1` but not `0`, while continuing to disallow the case where `dim` contains `0` but not `1`. As part of this change, I created a helper function, `_get_condition_for_invalid_jagged_reductions()`, which makes it clearer which conditions apply to which operators. Specifically, operators that support jagged reductions are listed at the top of the file in `SUPPORTED_JAGGED_REDUCTIONS` and have a different set of conditions to test, since reducing along `dim == 1` without `dim == 0` is now possible.

Functions modified in `caffe2/test/test_nestedtensor.py`:

- `test_sum_int_DimList()`: This diff adds special handling in the `sum` unit test to allow the case where `dim` contains `1` but not `0`, while continuing to disallow the case where `dim` contains `0` but not `1`.
- `test_sum_int_DimList_ragged_dim_1()`: This diff adds a new unit test which verifies the correctness and feasibility of reducing along the jagged dimension of a nested tensor.

Notes:

- This diff only adds functionality for the case in which we reduce along the ragged dimension alone. Cases in which we reduce along both the ragged and another dimension, like `dim == (1, 2)`, are not permitted, as this set of diffs focuses primarily on the former.
- The `sum` operator is the only operator that uses the function `_wrap_jagged_dims()`; all other operators use `_wrap_jagged_dim()`. I would like to look into why this is the case and whether we can consolidate the two!
- I also modified some of the comments in the `sum` function and its unit tests for clarity.

Test Plan: Verify that the existing (`test_sum_int_DimList`) and new (`test_sum_int_DimList_ragged_dim_1`) unit tests pass via the following command:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_sum_int_DimList
```

Differential Revision: D59571209

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130425
Approved by: https://github.com/davidberard98
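For illustration, here is a minimal usage sketch of the behavior this diff enables. It assumes a PyTorch build that includes this change and uses hypothetical shapes; it is not taken from the diff itself.

```
import torch

# Two components with different lengths along the ragged dimension; M = 4 features.
a = torch.randn(3, 4)
b = torch.randn(5, 4)

# Jagged-layout nested tensor of logical shape (B, *, M) with B = 2 and ragged_idx = 1.
nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)

# With this change, reducing over the ragged dimension yields a dense (B, M) tensor.
out = torch.sum(nt, dim=1)

# Each output row equals the sum of the corresponding component over its own length.
torch.testing.assert_close(out, torch.stack([a.sum(dim=0), b.sum(dim=0)]))
```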
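The pad-then-reduce strategy described for `sum_dim_IntList()` can also be sketched with plain dense ops, since zero padding leaves a sum unchanged. This is only a conceptual illustration of the reduction, not the actual kernel path (which uses `torch.ops.aten._jagged_to_padded_dense_forward`).

```
import torch

a = torch.randn(3, 4)
b = torch.randn(5, 4)

# Reference: reduce each component along its own ragged dimension, then stack -> (B, M).
expected = torch.stack([a.sum(dim=0), b.sum(dim=0)])

# Pad every component with zeros up to the maximum ragged length, producing a dense
# (B, max_len, M) tensor, then reduce across the padded dimension (dim == 1).
max_len = max(t.shape[0] for t in (a, b))
padded = torch.stack(
    [torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in (a, b)]
)
result = padded.sum(dim=1)

torch.testing.assert_close(result, expected)
```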