support unbacked softmax / logsoftmax (#162216)
### DDE (data-dependent expression errors)
```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*u0, 0) (unhinted: Eq(3*u0, 0)). (Size-like symbols: u0)
Caused by: (_decomp/decompositions.py:1185 in _softmax)
```
```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, 0) (unhinted: Eq(u0, 0)). (Size-like symbols: u0)
Caused by: logsoft = torch.nn.functional.log_softmax(nz, dim=0) # test/inductor/test_unbacked_symints.py:573 in fn (_decomp/decompositions.py:1212 in _log_softmax)
```
```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 0) (unhinted: Ne(u0, 0)). (Size-like symbols: u0)
Caused by: (_refs/__init__.py:2218 in _reduction)
```
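Both guards come from the same pattern: `nonzero()` produces a tensor whose leading dimension is an unbacked symbol (`u0`), and the softmax / log_softmax decompositions then branch on `x.numel() == 0`, which cannot be decided for that symbol. A minimal repro sketch, essentially the new test added below, assuming a build without this patch:

```python
import torch

@torch.compile(fullgraph=True)
def fn(x):
    nz = x.nonzero().float()          # leading size of nz is an unbacked SymInt (u0)
    return torch.softmax(nz, dim=0)   # decomposition branches on nz.numel() == 0

with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
    # Before this patch this raised GuardOnDataDependentSymNode; with it,
    # compilation succeeds and matches eager.
    fn(torch.randint(low=0, high=2, size=(32,), dtype=torch.int8))
```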
### Cannot convert symbols to int
```
File "torch/_inductor/lowering.py", line 7160, in prepare_softmax_online
and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "orch/_inductor/sizevars.py", line 591, in size_hint
return int(out)
^^^^^^^^
File "sympy/core/expr.py", line 342, in __int__
raise TypeError("Cannot convert symbols to int")
```
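This failure is downstream of the same unbacked size: `size_hint` tries to turn the reduction numel into a concrete Python int, but for an unbacked size the expression still contains a free sympy symbol, which sympy refuses to cast. A minimal sketch of just the sympy part (`u0` standing in for the unbacked size symbol):

```python
import sympy

u0 = sympy.Symbol("u0", integer=True, nonnegative=True)
rnumel = 3 * u0          # the reduction numel stays symbolic

try:
    int(rnumel)          # sympy raises: Cannot convert symbols to int
except TypeError as err:
    print(err)
```

The lowering change below avoids asking for a hint at all and instead queries `statically_known_geq`, which only answers `True` when the inequality is provable.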
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162216
Approved by: https://github.com/laithsakka, https://github.com/eellison
This commit is contained in:
parent 1f21f8544c
commit 3c8b90542c
`test/inductor/test_unbacked_symints.py`

```diff
@@ -489,6 +489,22 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_softmax(self, device):
+        def fn(x):
+            nz = x.nonzero().float()
+            soft = torch.softmax(nz, dim=0)
+            logsoft = torch.nn.functional.log_softmax(nz, dim=0)
+            return soft * logsoft
+
+        example_inputs = (
+            torch.randint(low=0, high=2, size=(32,), device=device, dtype=torch.int8),
+        )
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
+
     @skipGPUIf(not HAS_GPU, "requires gpu and triton")
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
```
`torch/_decomp/decompositions.py`

```diff
@@ -1173,6 +1173,8 @@ def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 @register_decomposition(aten._softmax)
 @out_wrapper()
 def _softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager softmax returns a contiguous tensor. Ensure that decomp also returns
     # a contiguous tensor.
     x = x.contiguous()
@@ -1182,7 +1184,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         unnormalized = torch.exp(x)
     else:
         x_max = torch.amax(x, dim, keepdim=True)
@@ -1196,6 +1198,8 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
 @register_decomposition(aten._log_softmax)
 @out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
     x = x.contiguous()
@@ -1205,7 +1209,7 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         shifted = x
     else:
         x_max = torch.amax(x, dim, keepdim=True)
```
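The key change in both decompositions is replacing the eager `if x.numel() == 0:` with `guard_or_false(...)`: roughly, when the condition is statically decidable it returns the real answer, and otherwise it assumes `False` without installing a guard, so the general (amax/exp) branch is traced for unbacked sizes instead of raising `GuardOnDataDependentSymNode`. A small sketch of that behavior, using the internal `ShapeEnv` API directly (not a stable interface):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()   # like the row count of nonzero()'s output

# bool(u0 == 0) would raise GuardOnDataDependentSymNode; guard_or_false falls
# back to False, so the non-empty branch is chosen and no guard on u0 is added.
print(guard_or_false(u0 == 0))            # False
```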
`torch/_inductor/lowering.py`

```diff
@@ -7220,9 +7220,8 @@ def prepare_softmax_online(x, dim):
         reduction_numel=rnumel,
     )
 
-    if (
-        num_split == 1
-        and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
+    if num_split == 1 and V.graph.sizevars.statically_known_geq(
+        rnumel, config.unroll_reductions_threshold
     ):
         max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
             input_node=x, num_output=2, reduction_hint=hint, **kwargs
```
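In the Inductor lowering, `size_hint` requires a concrete value, which an unbacked reduction numel does not have; `statically_known_geq` instead answers `True` only when the inequality is provable from already-known facts and `False` otherwise, never guarding, so unbacked reductions simply fall through to the general online-softmax path. The same kind of query is exposed as a free function; a small sketch (again an internal, unstable API):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, statically_known_true

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

# Not provable for an unbacked size, so the answer is False and no guard is installed.
print(statically_known_true(3 * u0 >= 8))   # False
```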
`torch/_refs/__init__.py`

```diff
@@ -2259,11 +2259,14 @@ def _reduction(
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
-        if not valid_shape:
-            raise RuntimeError(
-                "reducing over zero-size dimension for reduction operation without identity"
-            )
+        from torch.fx.experimental.symbolic_shapes import sym_and
+
+        valid_shape = a.ndim == 0 or sym_and(*(a.shape[i] > 0 for i in dims))
+        torch._check(
+            valid_shape,
+            lambda: "reducing over zero-size dimension for reduction operation without identity",
+        )
+
     computation_dtype, result_dtype = utils.reduction_dtypes(
         a, output_dtype_kind, dtype
     )
```
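Finally, `_reduction` used `builtins.all(...)` over the reduced sizes, which calls `bool()` on each symbolic size and therefore guards (the `Ne(u0, 0)` failure above). The replacement builds one symbolic condition with `sym_and` and hands it to `torch._check`, which records a runtime assert instead of a trace-time guard. A small sketch of the symbolic side (internal, unstable API):

```python
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_and

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()

# builtins.all((u0 > 0, u1 > 0)) would force a bool() on each term and guard;
# sym_and keeps the conjunction symbolic so torch._check can defer it.
valid_shape = sym_and(u0 > 0, u1 > 0)
print(type(valid_shape).__name__)   # SymBool
```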