Fix jagged NT softmax semantics (#119459)

Before: `softmax` dispatched through `jagged_unary_pointwise()`, which is wrong since softmax normalizes over an entire dim rather than acting elementwise
After: the `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values` buffer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459
Approved by: https://github.com/soulitzer
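
Why pointwise dispatch is wrong here: a jagged NT of shape `(B, ragged, D)` packs its components into a 2D `_values` buffer, so a dim index on the outer NT is off by one relative to `_values`. A minimal standalone sketch of the remapping, using plain tensors only (no NT internals; `components`/`values` are illustrative names, not the actual storage):

import torch

# A jagged NT with logical shape (B=2, ragged, D=3) can be packed into a
# values buffer of shape (sum_of_lengths, D). Softmax over the NT's dim=2
# must run over dim=1 of the packed values: the batch and ragged dims of
# the NT collapse together into values dim 0, shifting later dims down.
components = [torch.randn(2, 3), torch.randn(4, 3)]
values = torch.cat(components, dim=0)  # shape (6, 3)
out_values = values.softmax(dim=1)     # NT dim=2 -> values dim=1

# Matches per-component softmax, i.e. the semantics nt.softmax(dim=2)
# should have.
offset = 0
for c in components:
    torch.testing.assert_close(out_values[offset:offset + len(c)], c.softmax(dim=1))
    offset += len(c)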
Joel Schlosser 2024-02-08 11:46:53 -05:00 committed by PyTorch MergeBot
parent 278a0e1600
commit 6adadbaf79
2 changed files with 28 additions and 2 deletions

test/test_nestedtensor.py

@@ -3182,6 +3182,22 @@ class TestNestedTensorSubclass(TestCase):
         ):
             torch.split(nt, [1, 2], 1)
 
+    def test_softmax(self, device):
+        nt = random_nt_from_dims(
+            [3, None, 5], device=device, dtype=torch.float32, layout=torch.jagged)
+
+        # operate on dim=2
+        output = nt.softmax(dim=2)
+        for in_component, out_component in zip(nt.unbind(), output.unbind()):
+            # dim=2 -> dim=1 after unbind
+            self.assertEqual(in_component.softmax(dim=1), out_component)
+
+        # operate on dim=-1
+        output2 = nt.softmax(dim=-1)
+        self.assertEqual(output, output2)
+        for in_component, out_component in zip(nt.unbind(), output2.unbind()):
+            self.assertEqual(in_component.softmax(dim=-1), out_component)
+
     def test_views_inherit_ragged_dim(self, device):
         # view
         nt = random_nt_from_dims(
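
For reference, roughly the same check using the public jagged-layout constructor instead of the test-suite helper `random_nt_from_dims` (component shapes below are arbitrary):

import torch

# Build a jagged NT of logical shape (3, ragged, 5) from variable-length
# components, then verify softmax over the last dim against unbind().
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(4, 5), torch.randn(3, 5)],
    layout=torch.jagged,
)
out = nt.softmax(dim=2)
for in_c, out_c in zip(nt.unbind(), out.unbind()):
    # unbind() drops the batch dim, so NT dim=2 becomes component dim=1
    torch.testing.assert_close(in_c.softmax(dim=1), out_c)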

torch/nested/_internal/ops.py

@@ -447,9 +447,19 @@ register_jagged_func(
 )(jagged_unary_pointwise)
 
 
-register_jagged_func(
+@register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
-)(jagged_unary_pointwise)
+)
+def _softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "softmax")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
 @register_jagged_func(
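
The heavy lifting is in `_wrap_jagged_dim`, which canonicalizes the outer NT dim and shifts it onto `_values`. A hypothetical sketch of that remapping (the real helper lives in `torch/nested/_internal/ops.py`; the name, bounds check, and error text below are illustrative, not the actual implementation):

def wrap_jagged_dim(ndim: int, dim: int, op_name: str) -> int:
    # Canonicalize a possibly-negative dim against the outer NT's ndim,
    # e.g. dim=-1 with ndim=3 wraps to 2.
    if not -ndim <= dim < ndim:
        raise IndexError(f"{op_name}(): dim {dim} out of range for ndim {ndim}")
    wrapped = dim % ndim
    # The NT's batch dim (0) and ragged dim (1) collapse into dim 0 of the
    # packed _values buffer, so operating over them is unsupported here.
    if wrapped < 2:
        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1")
    # Every remaining outer dim maps to the values dim one position lower;
    # this is why _softmax_default passes len(inp._size), the NT's ndim.
    return wrapped - 1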