From 3c8b90542ce67be89f235ef5af62da373aa0f09b Mon Sep 17 00:00:00 2001
From: Colin Peppler
Date: Wed, 17 Sep 2025 17:48:45 -0700
Subject: [PATCH] support unbacked softmax / logsoftmax (#162216)

### DDE

Unbacked output sizes (e.g. from `nonzero()`) hit hard data-dependent guards in the softmax decompositions and in `_reduction`:

```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*u0, 0) (unhinted: Eq(3*u0, 0)). (Size-like symbols: u0)

Caused by: (_decomp/decompositions.py:1185 in _softmax)
```

```
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, 0) (unhinted: Eq(u0, 0)). (Size-like symbols: u0)

Caused by: logsoft = torch.nn.functional.log_softmax(nz, dim=0)  # test/inductor/test_unbacked_symints.py:573 in fn (_decomp/decompositions.py:1212 in _log_softmax)
```

```
GuardOnDataDependentSymNode: Could not guard on data-dependent expression Ne(u0, 0) (unhinted: Ne(u0, 0)). (Size-like symbols: u0)

Caused by: (_refs/__init__.py:2218 in _reduction)
```

### Cannot convert symbols to int

`prepare_softmax_online` asks for a concrete size hint, which an unbacked `rnumel` does not have:

```
  File "torch/_inductor/lowering.py", line 7160, in prepare_softmax_online
    and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/sizevars.py", line 591, in size_hint
    return int(out)
           ^^^^^^^^
  File "sympy/core/expr.py", line 342, in __int__
    raise TypeError("Cannot convert symbols to int")
```

### Fix

Replace the hard `numel() == 0` guards in `_softmax` / `_log_softmax` with `guard_or_false`, turn the zero-size-dimension check in `_reduction` into a symbolic `torch._check` built with `sym_and`, and use `statically_known_geq` instead of `size_hint` in `prepare_softmax_online`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162216
Approved by: https://github.com/laithsakka, https://github.com/eellison
---
 test/inductor/test_unbacked_symints.py | 16 ++++++++++++++++
 torch/_decomp/decompositions.py        |  8 ++++++--
 torch/_inductor/lowering.py            |  5 ++---
 torch/_refs/__init__.py                | 13 ++++++++-----
 4 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py
index cc9c1251523..b41886a03dd 100644
--- a/test/inductor/test_unbacked_symints.py
+++ b/test/inductor/test_unbacked_symints.py
@@ -489,6 +489,22 @@ class TestUnbackedSymints(InductorTestCase):
         expected = fn(*example_inputs)
         torch.testing.assert_close(actual, expected)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_softmax(self, device):
+        def fn(x):
+            nz = x.nonzero().float()
+            soft = torch.softmax(nz, dim=0)
+            logsoft = torch.nn.functional.log_softmax(nz, dim=0)
+            return soft * logsoft
+
+        example_inputs = (
+            torch.randint(low=0, high=2, size=(32,), device=device, dtype=torch.int8),
+        )
+        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+        expected = fn(*example_inputs)
+        torch.testing.assert_close(actual, expected)
+
     @skipGPUIf(not HAS_GPU, "requires gpu and triton")
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 2a00c57419d..637db6192b1 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1173,6 +1173,8 @@ def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 @register_decomposition(aten._softmax)
 @out_wrapper()
 def _softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager softmax returns a contiguous tensor. Ensure that decomp also returns
     # a contiguous tensor.
     x = x.contiguous()
@@ -1182,7 +1184,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         unnormalized = torch.exp(x)
     else:
         x_max = torch.amax(x, dim, keepdim=True)
@@ -1196,6 +1198,8 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
 @register_decomposition(aten._log_softmax)
 @out_wrapper(exact_dtype=True)
 def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
     # returns a contiguous tensor.
     x = x.contiguous()
@@ -1205,7 +1209,7 @@ def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
     x = x.to(computation_dtype)
-    if x.numel() == 0:
+    if guard_or_false(x.numel() == 0):
         shifted = x
     else:
         x_max = torch.amax(x, dim, keepdim=True)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 6c5e8ad1ca8..d3d24bfef77 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -7220,9 +7220,8 @@ def prepare_softmax_online(x, dim):
         reduction_numel=rnumel,
     )
 
-    if (
-        num_split == 1
-        and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
+    if num_split == 1 and V.graph.sizevars.statically_known_geq(
+        rnumel, config.unroll_reductions_threshold
     ):
         max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
             input_node=x, num_output=2, reduction_hint=hint, **kwargs
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 18455b51941..0155c689249 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2259,11 +2259,14 @@ def _reduction(
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
-        if not valid_shape:
-            raise RuntimeError(
-                "reducing over zero-size dimension for reduction operation without identity"
-            )
+        from torch.fx.experimental.symbolic_shapes import sym_and
+
+        valid_shape = a.ndim == 0 or sym_and(*(a.shape[i] > 0 for i in dims))
+        torch._check(
+            valid_shape,
+            lambda: "reducing over zero-size dimension for reduction operation without identity",
+        )
+
     computation_dtype, result_dtype = utils.reduction_dtypes(
         a, output_dtype_kind, dtype
     )
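
For context, a minimal sketch of the `guard_or_false` pattern the decomposition changes rely on. `guard_or_false` and `torch.testing.assert_close` are real PyTorch APIs (the former is the helper imported in the diff above); `softmax_like` is a hypothetical stand-in for the decomposition, not the actual `_softmax` code.

```
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false


def softmax_like(x: torch.Tensor, dim: int) -> torch.Tensor:
    # A plain `if x.numel() == 0:` installs a guard on Eq(u0, 0) and raises
    # GuardOnDataDependentSymNode when numel() involves an unbacked symbol
    # (e.g. the output size of nonzero()). guard_or_false guards on the
    # condition when it can be decided and falls back to False when it is
    # data-dependent, so tracing proceeds down the general branch.
    if guard_or_false(x.numel() == 0):
        unnormalized = torch.exp(x)  # empty input: exp(x) is empty too
    else:
        x_max = torch.amax(x, dim, keepdim=True)  # shift for stability
        unnormalized = torch.exp(x - x_max)
    return unnormalized / torch.sum(unnormalized, dim, keepdim=True)


# Eager sanity check (concrete shapes, so guard_or_false just sees a bool):
x = torch.randn(8)
torch.testing.assert_close(softmax_like(x, dim=0), torch.softmax(x, dim=0))
```

`statically_known_geq` plays the analogous role on the Inductor side: it answers `rnumel >= config.unroll_reductions_threshold` only when that is provable at compile time and returns False otherwise, rather than calling `int()` on a symbolic expression the way `size_hint` does.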