support unbacked softmax / logsoftmax (#162216)
### DDE
```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*u0, 0) (unhinted: Eq(3*u0, 0)). (Size-like symbols: u0)
Caused by: (_decomp/decompositions.py:1185 in _softmax)
```
```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, 0) (unhinted: Eq(u0, 0)). (Size-like symbols: u0)
Caused by: logsoft = torch.nn.functional.log_softmax(nz, dim=0) # test/inductor/test_unbacked_symints.py:573 in fn (_decomp/decompositions.py:1212 in _log_softmax)
```
```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 0) (unhinted: Ne(u0, 0)). (Size-like symbols: u0)
Caused by: (_refs/__init__.py:2218 in _reduction)
```
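All three guards come from the zero-numel fast paths in the `_softmax` / `_log_softmax` decompositions and from the identity check in `_reduction`. A minimal sketch of the failing pattern, mirroring the new test added below (the config flag and input shape here are illustrative, not taken from the traces):
```python
import torch

# nonzero() has a data-dependent output length, so nz's leading dimension is the
# unbacked symint u0; the decompositions then tried to guard on Eq(u0, 0).
def fn(x):
    nz = x.nonzero().float()
    return torch.softmax(nz, dim=0) * torch.nn.functional.log_softmax(nz, dim=0)

torch._dynamo.config.capture_dynamic_output_shape_ops = True
x = torch.randint(low=0, high=2, size=(32,), dtype=torch.int8)
out = torch.compile(fn, fullgraph=True)(x)  # raised the DDE above before this patch
```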
### Cannot convert symbols to int
```
File "torch/_inductor/lowering.py", line 7160, in prepare_softmax_online
and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "orch/_inductor/sizevars.py", line 591, in size_hint
return int(out)
^^^^^^^^
File "sympy/core/expr.py", line 342, in __int__
raise TypeError("Cannot convert symbols to int")
```
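`size_hint` tries to materialize a concrete integer, which cannot work for a hint-less unbacked symbol. The `statically_known_*` / `guard_or_false` helpers instead ask the ShapeEnv whether a fact is provable and fall back to `False` when it is not, without guarding or converting. A rough standalone sketch (constructing the symbol directly from a `ShapeEnv` is an illustration only; inductor gets these symbols from the traced graph):
```python
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    guard_or_false,
    statically_known_true,
)

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # hint-less size symbol, like nonzero()'s length

# int(u0) would raise "Cannot convert symbols to int"; these helpers never try:
print(statically_known_true(u0 >= 32))  # False: not provable, and no guard is installed
print(guard_or_false(u0 == 0))          # False: assume "non-empty" without guarding
```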
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162216
Approved by: https://github.com/laithsakka, https://github.com/eellison
Parent: 1f21f8544c
Commit: 3c8b90542c
test/inductor/test_unbacked_symints.py:
```diff
@@ -489,6 +489,22 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_softmax(self, device):
+        def fn(x):
+            nz = x.nonzero().float()
+            soft = torch.softmax(nz, dim=0)
+            logsoft = torch.nn.functional.log_softmax(nz, dim=0)
+            return soft * logsoft
+
+        example_inputs = (
+            torch.randint(low=0, high=2, size=(32,), device=device, dtype=torch.int8),
+        )
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
+
     @skipGPUIf(not HAS_GPU, "requires gpu and triton")
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
```
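With the patch the test above compiles under `fullgraph=True`; from a source checkout it can be run with, e.g., `pytest test/inductor/test_unbacked_symints.py -k test_softmax` (the invocation is an assumption; the parametrized name carries a device suffix such as `_cuda`).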
torch/_decomp/decompositions.py:
```diff
@@ -1173,6 +1173,8 @@ def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 @register_decomposition(aten._softmax)
 @out_wrapper()
 def _softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager softmax returns a contiguous tensor. Ensure that decomp also returns
     # a contiguous tensor.
     x = x.contiguous()
```
```diff
@@ -1182,7 +1184,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         unnormalized = torch.exp(x)
     else:
         x_max = torch.amax(x, dim, keepdim=True)
```
```diff
@@ -1196,6 +1198,8 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
 @register_decomposition(aten._log_softmax)
 @out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
     x = x.contiguous()
```
```diff
@@ -1205,7 +1209,7 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         shifted = x
     else:
         x_max = torch.amax(x, dim, keepdim=True)
```
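With an unbacked `u0`, `guard_or_false(x.numel() == 0)` evaluates to `False`, so both decompositions always emit the general max-subtraction path without installing a guard on `u0`. An illustrative standalone mirror of the softmax branch (the final normalization is paraphrased from the decomposition, not copied from this hunk):
```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def softmax_decomp_sketch(x: torch.Tensor, dim: int) -> torch.Tensor:
    # True only when numel() == 0 is statically provable; False for unbacked sizes.
    if guard_or_false(x.numel() == 0):
        unnormalized = torch.exp(x)
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        unnormalized = torch.exp(x - x_max)
    return unnormalized / torch.sum(unnormalized, dim, keepdim=True)
```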
torch/_inductor/lowering.py:
```diff
@@ -7220,9 +7220,8 @@ def prepare_softmax_online(x, dim):
         reduction_numel=rnumel,
     )
 
-    if (
-        num_split == 1
-        and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
+    if num_split == 1 and V.graph.sizevars.statically_known_geq(
+        rnumel, config.unroll_reductions_threshold
     ):
         max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
             input_node=x, num_output=2, reduction_hint=hint, **kwargs
```
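`statically_known_geq(rnumel, threshold)` answers the same question as `size_hint(rnumel) >= threshold`, but returns `False` whenever the bound cannot be proven (in particular for an unbacked `rnumel`) instead of calling `int()` on a hint-less symbol, which is what produced the `TypeError` above.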
torch/_refs/__init__.py:
```diff
@@ -2259,11 +2259,14 @@ def _reduction(
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
-        if not valid_shape:
-            raise RuntimeError(
-                "reducing over zero-size dimension for reduction operation without identity"
-            )
+        from torch.fx.experimental.symbolic_shapes import sym_and
+
+        valid_shape = a.ndim == 0 or sym_and(*(a.shape[i] > 0 for i in dims))
+        torch._check(
+            valid_shape,
+            lambda: "reducing over zero-size dimension for reduction operation without identity",
+        )
+
     computation_dtype, result_dtype = utils.reduction_dtypes(
         a, output_dtype_kind, dtype
     )
```
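Here `sym_and` folds the per-dimension positivity conditions into one symbolic boolean, and `torch._check` records it as a runtime assertion instead of forcing a trace-time decision on `Ne(u0, 0)`. A rough illustration (building symbols from a bare `ShapeEnv` is an assumption for demonstration; in `_reduction` they come from the traced shapes):
```python
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_and

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()

# One symbolic condition covering every reduced dimension.
valid_shape = sym_and(u0 > 0, u1 > 0)

# Recorded as a deferred runtime assert rather than a data-dependent guard.
torch._check(valid_shape, lambda: "reducing over zero-size dimension without identity")
```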