support unbacked softmax / logsoftmax (#162216)

### DDE (data-dependent expression errors)

```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*u0, 0) (unhinted: Eq(3*u0, 0)).  (Size-like symbols: u0)

Caused by: (_decomp/decompositions.py:1185 in _softmax)
```

```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, 0) (unhinted: Eq(u0, 0)).  (Size-like symbols: u0)

Caused by: logsoft = torch.nn.functional.log_softmax(nz, dim=0)  # test/inductor/test_unbacked_symints.py:573 in fn (_decomp/decompositions.py:1212 in _log_softmax)
```

```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 0) (unhinted: Ne(u0, 0)).  (Size-like symbols: u0)

Caused by: (_refs/__init__.py:2218 in _reduction)
```
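
All three failures come from the same pattern: `nonzero()` produces a tensor whose leading dimension is an unbacked symbol `u0`, and the softmax / log_softmax decompositions then branch on `x.numel() == 0`, an expression that cannot be guarded on. A minimal sketch of the repro, adapted from the test added in this PR (the input size and dtype here are illustrative):

```python
import torch
import torch._dynamo

# Needed so nonzero() traces through with an unbacked output size u0.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

def fn(x):
    nz = x.nonzero().float()                             # shape (u0, 1), u0 unbacked
    soft = torch.softmax(nz, dim=0)                       # decomp branches on nz.numel() == 0
    logsoft = torch.nn.functional.log_softmax(nz, dim=0)
    return soft * logsoft

x = torch.randint(low=0, high=2, size=(32,), dtype=torch.int8)
out = torch.compile(fn, fullgraph=True)(x)                # previously: GuardOnDataDependentSymNode
```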

### Cannot convert symbols to int
```
  File "torch/_inductor/lowering.py", line 7160, in prepare_softmax_online
    and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "orch/_inductor/sizevars.py", line 591, in size_hint
    return int(out)
           ^^^^^^^^
  File "sympy/core/expr.py", line 342, in __int__
    raise TypeError("Cannot convert symbols to int")
```
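
`size_hint` ultimately calls `int()` on the sympy expression for the reduction numel, which has no concrete value when the size is unbacked. A rough illustration of that failure mode:

```python
import sympy

u0 = sympy.Symbol("u0", integer=True, positive=True)  # stands in for the unbacked size
try:
    int(u0)                                            # no hint available for u0
except TypeError as e:
    print(e)                                           # Cannot convert symbols to int
```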

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162216
Approved by: https://github.com/laithsakka, https://github.com/eellison
Author: Colin Peppler
Date: 2025-09-17 17:48:45 -07:00
Committed by: PyTorch MergeBot
Parent: 1f21f8544c
Commit: 3c8b90542c

4 changed files with 32 additions and 10 deletions

#### test/inductor/test_unbacked_symints.py

```diff
@@ -489,6 +489,22 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_softmax(self, device):
+        def fn(x):
+            nz = x.nonzero().float()
+            soft = torch.softmax(nz, dim=0)
+            logsoft = torch.nn.functional.log_softmax(nz, dim=0)
+            return soft * logsoft
+
+        example_inputs = (
+            torch.randint(low=0, high=2, size=(32,), device=device, dtype=torch.int8),
+        )
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
+
     @skipGPUIf(not HAS_GPU, "requires gpu and triton")
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
```

#### torch/_decomp/decompositions.py

```diff
@@ -1173,6 +1173,8 @@ def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 @register_decomposition(aten._softmax)
 @out_wrapper()
 def _softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager softmax returns a contiguous tensor. Ensure that decomp also returns
     # a contiguous tensor.
     x = x.contiguous()
@@ -1182,7 +1184,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         unnormalized = torch.exp(x)
     else:
         x_max = torch.amax(x, dim, keepdim=True)
@@ -1196,6 +1198,8 @@
 @register_decomposition(aten._log_softmax)
 @out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
     x = x.contiguous()
@@ -1205,7 +1209,7 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         shifted = x
     else:
         x_max = torch.amax(x, dim, keepdim=True)
```
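
`guard_or_false` evaluates the condition when it can be decided and otherwise assumes `False`, so the empty-input fast path is simply skipped for unbacked sizes instead of raising a data-dependent guard error. A rough sketch of the semantics (the explicit `ShapeEnv` / `create_unbacked_symint` usage is illustrative, not the decomposition's actual call path):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

env = ShapeEnv()
u0 = env.create_unbacked_symint()   # an unbacked SymInt, like nonzero()'s output size

# Eq(u0, 0) cannot be decided at trace time; rather than guarding (and failing),
# guard_or_false falls back to False, so the general non-empty branch is taken.
print(guard_or_false(u0 == 0))      # expected: False
```

If `u0` does turn out to be 0 at runtime, the reduction's zero-size check (now a `torch._check` runtime assert, see the `torch/_refs/__init__.py` change below) fires instead of a compile-time guard.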

#### torch/_inductor/lowering.py

```diff
@@ -7220,9 +7220,8 @@ def prepare_softmax_online(x, dim):
         reduction_numel=rnumel,
     )
 
-    if (
-        num_split == 1
-        and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
+    if num_split == 1 and V.graph.sizevars.statically_known_geq(
+        rnumel, config.unroll_reductions_threshold
     ):
         max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
             input_node=x, num_output=2, reduction_hint=hint, **kwargs
```
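
Unlike `size_hint`, which needs a concrete value (and raised `Cannot convert symbols to int` for an unbacked `rnumel`), `statically_known_geq` only asks whether the inequality can be proven from what is already known and answers `False` otherwise, so the online-softmax split path is conservatively skipped for unbacked reduction sizes. A rough sketch of this style of query using the public `statically_known_true` helper (inductor's `sizevars.statically_known_geq` is what the lowering actually calls):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, statically_known_true

env = ShapeEnv()
u0 = env.create_unbacked_symint()

# No hint is requested; when u0 >= 8 cannot be proven, the answer is simply False
# (no guard is installed and nothing tries to convert u0 to an int).
print(statically_known_true(u0 >= 8))   # expected: False
```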

#### torch/_refs/__init__.py

```diff
@@ -2259,11 +2259,14 @@ def _reduction(
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
-        if not valid_shape:
-            raise RuntimeError(
-                "reducing over zero-size dimension for reduction operation without identity"
-            )
+        from torch.fx.experimental.symbolic_shapes import sym_and
+
+        valid_shape = a.ndim == 0 or sym_and(*(a.shape[i] > 0 for i in dims))
+        torch._check(
+            valid_shape,
+            lambda: "reducing over zero-size dimension for reduction operation without identity",
+        )
+
     computation_dtype, result_dtype = utils.reduction_dtypes(
         a, output_dtype_kind, dtype
     )
```
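
Replacing the eager `RuntimeError` with `torch._check` means the zero-size-dimension condition is recorded as a deferred runtime assert when it involves unbacked symbols, rather than being evaluated (and failing) at trace time. A small end-to-end sketch of that pattern in user code (the input, reduction, and message are illustrative):

```python
import torch
import torch._dynamo

torch._dynamo.config.capture_dynamic_output_shape_ops = True

def fn(x):
    nz = x.nonzero().float()                  # leading dim is an unbacked u0
    # Data-dependent condition: becomes a runtime assert under torch.compile
    # instead of a GuardOnDataDependentSymNode error at compile time.
    torch._check(nz.shape[0] > 0, lambda: "expected at least one nonzero element")
    return nz.amax(dim=0)                     # reduction without identity over u0

x = torch.tensor([0.0, 1.0, 0.0, 2.0])
out = torch.compile(fn, fullgraph=True)(x)
```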