mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix jagged NT softmax semantics (#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong) After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values` Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459 Approved by: https://github.com/soulitzer
This commit is contained in:
parent
278a0e1600
commit
6adadbaf79
|
|
@ -3182,6 +3182,22 @@ class TestNestedTensorSubclass(TestCase):
|
|||
):
|
||||
torch.split(nt, [1, 2], 1)
|
||||
|
||||
def test_softmax(self, device):
|
||||
nt = random_nt_from_dims(
|
||||
[3, None, 5], device=device, dtype=torch.float32, layout=torch.jagged)
|
||||
|
||||
# operate on dim=2
|
||||
output = nt.softmax(dim=2)
|
||||
for in_component, out_component in zip(nt.unbind(), output.unbind()):
|
||||
# dim=2 -> dim=1 after unbind
|
||||
self.assertEqual(in_component.softmax(dim=1), out_component)
|
||||
|
||||
# operate on dim=-1
|
||||
output2 = nt.softmax(dim=-1)
|
||||
self.assertEqual(output, output2)
|
||||
for in_component, out_component in zip(nt.unbind(), output2.unbind()):
|
||||
self.assertEqual(in_component.softmax(dim=-1), out_component)
|
||||
|
||||
def test_views_inherit_ragged_dim(self, device):
|
||||
# view
|
||||
nt = random_nt_from_dims(
|
||||
|
|
|
|||
|
|
@ -447,9 +447,19 @@ register_jagged_func(
|
|||
)(jagged_unary_pointwise)
|
||||
|
||||
|
||||
register_jagged_func(
|
||||
@register_jagged_func(
|
||||
torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
|
||||
)(jagged_unary_pointwise)
|
||||
)
|
||||
def _softmax_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
inp = new_kwargs.pop("input")
|
||||
dim = new_kwargs["dim"]
|
||||
new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "softmax")
|
||||
|
||||
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
|
||||
|
||||
|
||||
@register_jagged_func(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user