[cond] support mismatched output in inductor (#147567)

In this PR, we extract FallbackKernel's `codegen_unbacked_symbol_defs` into a `codegen_unbacked_symbol_defs_for_outputs` method on the wrapper. With it, HOPs can support the case where a subgraph returns a tensor with unbacked symints. This PR only does this for cond; follow-up PRs will cover the others (e.g. while_loop) as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147567
Approved by: https://github.com/jansel
Yidi Wu 2025-02-26 14:28:37 -08:00 committed by PyTorch MergeBot
parent d765077004
commit 2d2f60bdda
5 changed files with 183 additions and 73 deletions
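
For context, a minimal sketch of the pattern this change enables, mirroring the `MismatchedOutputSize` model added in the tests below. The module name and shapes are illustrative, not part of the diff, and running it assumes a build that includes this PR: the two branches of `torch.cond` return tensors whose first dimension differs, so the output size becomes an unbacked symint that inductor must bind at runtime.

    import torch

    class MismatchedCond(torch.nn.Module):  # hypothetical module, for illustration
        def forward(self, x):
            def true_fn(x):
                return x[2:].sin()  # shape (8, 20) for a (10, 20) input

            def false_fn(x):
                return x[:2].cos()  # shape (2, 20)

            # Reduce to a scalar so downstream shapes don't depend on the branch.
            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,)).sum()

    compiled = torch.compile(MismatchedCond(), fullgraph=True)
    out = compiled(torch.randn(10, 20))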


@@ -1362,6 +1362,33 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @common_utils.parametrize("dynamic", [False, True])
+    def test_cond_mismatched_branch_output(self, dynamic):
+        inputs = (
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+        )
+        dynamic_shapes = None
+        if dynamic:
+            # Note the minimum has to be 4 because the model
+            # is slicing over the first dim with [2:]; if the first
+            # dim is 2 or 3, the slicing will be 0/1 specialized,
+            # causing a constraint violation error.
+            dim0_a = Dim("s0", min=4, max=1024)
+            dim0_b = Dim("s1", min=4, max=1024)
+            dynamic_shapes = {
+                "p": {},
+                "x": {0: dim0_a, 1: None},
+                "y": {0: dim0_b, 1: None},
+                "z": {0: dim0_a, 1: None},
+            }
+        self.check_model_with_multiple_inputs(
+            CondModels.MismatchedOutputSize(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     def test_cond_symint_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
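
As a side note on the min=4 comment in the test above, here is a small sketch of the 0/1-specialization pitfall it refers to; the `Slices` module is made up for illustration, while `Dim` and `torch.export.export` are the real export APIs these tests build on.

    import torch
    from torch.export import Dim

    class Slices(torch.nn.Module):  # hypothetical module, for illustration
        def forward(self, x):
            return x[2:].sin()  # output dim 0 has size s0 - 2

    # With min=4 the sliced size s0 - 2 is always >= 2, so it can never be
    # specialized to 0 or 1; min=2 or min=3 would allow sizes 0/1 and trigger
    # the constraint violation described in the test comment.
    dim0 = Dim("s0", min=4, max=1024)
    ep = torch.export.export(
        Slices(), (torch.randn(10, 20),), dynamic_shapes={"x": {0: dim0}}
    )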


@@ -141,6 +141,12 @@ CPU_TEST_FAILURES = {
     ),
     # same issue as https://github.com/pytorch/pytorch/issues/122990
     "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True),
+    "test_cond_mismatched_branch_output_dynamic_True": fail_stack_allocation(
+        is_skip=True
+    ),
+    "test_cond_mismatched_branch_output_dynamic_False": fail_stack_allocation(
+        is_skip=True
+    ),
     # https://github.com/pytorch/pytorch/issues/122991
     "test_runtime_checks_complex": fail_stack_allocation(is_skip=True),
     "test_runtime_checks_fp8": fail_stack_allocation(is_skip=True),


@@ -196,6 +196,19 @@ class CondModels:
             return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
 
+    class MismatchedOutputSize(torch.nn.Module):
+        def forward(self, p, x, y, z):
+            a = y.shape[0]
+            b = z.shape[0]
+
+            def true_fn(x):
+                return (x + a)[2:].sin()
+
+            def false_fn(x):
+                return (x + b * z)[:2].cos()
+
+            return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
+
 
 class CondTests(TestCase):
     def _run_test(
@@ -643,6 +656,21 @@ class CondTests(TestCase):
         self.assertEqual(counters["pre_grad"], 11)
         self.assertEqual(counters["post_grad"], 11)
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    @parametrize("dynamic", [True, False])
+    def test_cond_mismatched_branch_output_size(self, device, dynamic):
+        self._run_test(
+            model=CondModels.MismatchedOutputSize(),
+            inputs=(
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+            ),
+            device=device,
+            dynamic=dynamic,
+        )
+
 
 class WhileLoopModels:
     class Simple(torch.nn.Module):

@@ -20,12 +20,19 @@ from sympy import Expr
 import torch
 import torch._ops
 import torch.utils._pytree as pytree
 from torch import dtype as torch_dtype
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.codegen.debug_utils import DebugPrinterManager
 from torch._inductor.codegen.multi_kernel import MultiKernelState
 from torch._inductor.runtime.runtime_utils import cache_dir
-from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
+from torch.fx.experimental.symbolic_shapes import (
+    CallMethodKey,
+    ConvertIntKey,
+    DivideByKey,
+    resolve_unbacked_bindings,
+    SymTypes,
+)
 from torch.fx.node import _get_qualified_name
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.singleton_int import SingletonInt
@@ -2487,6 +2494,86 @@ class PythonWrapperCodegen(CodeGen):
             self.unbacked_symbol_decls.add(name)
             return self.declare + name
 
+    def codegen_unbacked_symbol_defs_for_outputs(
+        self,
+        output_name: str,
+        outputs: Any,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+    ) -> None:
+        unbacked_bindings = resolve_unbacked_bindings(
+            V.graph.sizevars.shape_env, unbacked_bindings
+        )
+
+        if not unbacked_bindings:
+            return
+
+        # This code generates code expressions from symbolic paths (keypaths)
+        # associated with certain symbols (unbacked bindings). These keypaths describe
+        # how to access the unbacked symbol in a structured way.
+        # For example, we might want to generate "u0 = outs[0].stride(1)", where s = u0 and the keypath
+        # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey(1)].
+        for s, keypath in unbacked_bindings.items():
+            # `go` recursively constructs a code expression by processing each element of
+            # the keypath, building up the expression incrementally.
+            # For example, given output name "outs" and keypath [SequenceKey(0), CallMethodKey("stride"), SequenceKey(1)],
+            # it generates "outs[0]" from SequenceKey(0), then recursively calls
+            # go("outs[0]", [CallMethodKey("stride"), ...]).
+            def go(expr: str, keypath: pytree.KeyPath):
+                if keypath == ():
+                    return expr
+
+                if (
+                    len(keypath) >= 2
+                    and isinstance(keypath[0], CallMethodKey)
+                    and isinstance(keypath[1], pytree.SequenceKey)
+                ):
+                    return go(
+                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
+                    )
+                elif isinstance(keypath[0], CallMethodKey):
+                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
+                elif isinstance(keypath[0], pytree.SequenceKey):
+                    return (
+                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
+                        if V.graph.cpp_wrapper
+                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
+                    )
+                elif isinstance(keypath[0], DivideByKey):
+                    # TODO: need to assert divisibility
+                    # TODO: this is invalid C++ codegen
+                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
+                else:
+                    raise AssertionError(f"unrecognized keypath {keypath}")
+
+            # `go_outer` manages the top-level logic for generating the final expression.
+            # It handles special cases for C++ code generation and adjusts
+            # the keypath based on the context (e.g., single vs. multiple outputs).
+            def go_outer():  # type: ignore[no-untyped-def]
+                if V.graph.cpp_wrapper:
+                    # Special handling for the top level buffer access,
+                    # because self.get_name() is actually never bound; the
+                    # individual output arguments are bound by
+                    # generate_c_shim_fallback_kernel
+                    if len(outputs) == 1:
+                        out = outputs[0]
+                        # When a fallback kernel returns a list consisting of a single tensor,
+                        # the output is represented as a MultiOutput with non-empty indices.
+                        # In this case, we strip the first key path away.
+                        return go(
+                            outputs[0].get_name(),
+                            keypath[1:]
+                            if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
+                            else keypath,
+                        )
+                    else:
+                        assert isinstance(keypath[0], pytree.SequenceKey)
+                        return go(outputs[keypath[0].idx].get_name(), keypath[1:])
+                else:
+                    return go(output_name, keypath)
+
+            self.writeline(
+                f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}"
+            )
+
     def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs):
         # TODO (desertfire) - This function is the old way of supporting
         # subgraph codegen by inlining subgraphs in the output code. For python
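
To make the keypath walk concrete, here is a standalone, simplified re-implementation of the `go` recursion above, covering only the Python-wrapper branch; the key classes mimic the real `SequenceKey`/`CallMethodKey`/`DivideByKey` but are redefined locally so the sketch runs on its own.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SequenceKey:
        idx: int

    @dataclass(frozen=True)
    class CallMethodKey:
        name: str

    @dataclass(frozen=True)
    class DivideByKey:
        divisor: int

    def go(expr, keypath):
        if not keypath:
            return expr
        head, rest = keypath[0], keypath[1:]
        if isinstance(head, CallMethodKey) and rest and isinstance(rest[0], SequenceKey):
            # Method call with an index argument, e.g. ".size(1)"
            return go(f"{expr}.{head.name}({rest[0].idx})", rest[1:])
        if isinstance(head, CallMethodKey):
            return go(f"{expr}.{head.name}()", rest)
        if isinstance(head, SequenceKey):
            return go(f"{expr}[{head.idx}]", rest)
        if isinstance(head, DivideByKey):
            return go(f"{expr}.__floordiv__({head.divisor})", rest)
        raise AssertionError(f"unrecognized keypath {keypath}")

    # Binding u0 to the second size of the first output of a buffer "buf3":
    print(go("buf3", (SequenceKey(0), CallMethodKey("size"), SequenceKey(1))))
    # -> buf3[0].size(1)

In the C++ wrapper the `SequenceKey` step renders as `std::get<idx>(...)` instead, and `go_outer` binds from the individual output arguments rather than the top-level buffer name.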


@@ -47,9 +47,7 @@ from torch._prims_common import (
 )
 from torch._subclasses.fake_tensor import get_schema_info
 from torch.fx.experimental.symbolic_shapes import (
-    CallMethodKey,
     compute_unbacked_bindings,
-    DivideByKey,
     free_unbacked_symbols,
     rebind_unbacked,
     resolve_unbacked_bindings,
@@ -6488,72 +6486,10 @@ class FallbackKernel(ExternKernelAlloc):
                 handle_aliasing_and_mutation(info, arg)
 
     def codegen_unbacked_symbol_defs(self, wrapper) -> None:  # type: ignore[no-untyped-def]
-        if not hasattr(self, "unbacked_bindings"):
-            return
-
-        unbacked_bindings = resolve_unbacked_bindings(
-            V.graph.sizevars.shape_env, self.unbacked_bindings
+        return wrapper.codegen_unbacked_symbol_defs_for_outputs(
+            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
         )
-
-        if not unbacked_bindings:
-            return
-
-        for s, keypath in unbacked_bindings.items():
-
-            def go(expr, keypath):  # type: ignore[no-untyped-def]
-                if keypath == ():
-                    return expr
-
-                if (
-                    len(keypath) >= 2
-                    and isinstance(keypath[0], CallMethodKey)
-                    and isinstance(keypath[1], pytree.SequenceKey)
-                ):
-                    return go(
-                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
-                    )
-                elif isinstance(keypath[0], CallMethodKey):
-                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
-                elif isinstance(keypath[0], pytree.SequenceKey):
-                    return (
-                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
-                        if V.graph.cpp_wrapper
-                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
-                    )
-                elif isinstance(keypath[0], DivideByKey):
-                    # TODO: need to assert divisibility
-                    # TODO: this is invalid C++ codegen
-                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
-                else:
-                    raise AssertionError(f"unrecognized keypath {keypath}")
-
-            def go_outer():  # type: ignore[no-untyped-def]
-                if V.graph.cpp_wrapper:
-                    # Special handling for the top level buffer access,
-                    # because self.get_name() is actually never bound; the
-                    # individual output arguments are bound by
-                    # generate_c_shim_fallback_kernel
-                    if len(self.outputs) == 1:
-                        out = self.outputs[0]
-                        # When fallback kernel returns a list consisting of a single tensor,
-                        # the output is represented as a MultiOutput with non empty indices.
-                        # In this case, we strip the first key path away.
-                        return go(
-                            self.outputs[0].get_name(),
-                            keypath[1:]
-                            if isinstance(out, MultiOutput) and len(out.indices) != 0
-                            else keypath,
-                        )
-                    else:
-                        assert isinstance(keypath[0], pytree.SequenceKey)
-                        return go(self.outputs[keypath[0].idx].get_name(), keypath[1:])
-                else:
-                    return go(self.get_name(), keypath)
-
-            wrapper.writeline(
-                f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}"
-            )
 
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         if unbacked_bindings := getattr(self, "unbacked_bindings", None):
             resolved = resolve_unbacked_bindings(
@@ -7319,6 +7255,7 @@
         true_subgraph: Subgraph,
         false_subgraph: Subgraph,
         layout: MultiOutputLayout,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
     ) -> None:
         self.predicate = predicate
         self.operands = operands
@@ -7333,6 +7270,8 @@
             inputs=tensor_args,
             constant_args=sym_args,
         )
+        if unbacked_bindings is not None:
+            self.unbacked_bindings = unbacked_bindings
 
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)
@@ -7374,8 +7313,6 @@
         # make sure true and false outputs are structurally equivalent
         assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
         for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
-            assert to.get_size() == fo.get_size(), (i, to, fo)
-            assert to.get_stride() == fo.get_stride(), (i, to, fo)
             assert to.get_device() == fo.get_device(), (i, to, fo)
             assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
             assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
@@ -7385,6 +7322,10 @@
             for o in [predicate] + operands
             if not isinstance(o, ShapeAsConstantBuffer)
         )
+        unbacked_bindings = resolve_unbacked_bindings(
+            V.graph.sizevars.shape_env,
+            V.graph.current_node.meta.get("unbacked_bindings", None),
+        )
         assert device is not None, "cannot determine device"
         conditional = Conditional(
             predicate=predicate,
@@ -7392,15 +7333,21 @@
             true_subgraph=true_fn,
             false_subgraph=false_fn,
             layout=MultiOutputLayout(device=device),
+            unbacked_bindings=unbacked_bindings,
         )
 
+        def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
+            if isinstance(s, int):
+                return s
+            return s.node.expr
+
         outputs = [
             MultiOutput(
                 FixedLayout(
                     device=output.get_device(),
                     dtype=output.get_dtype(),
-                    size=output.get_size(),
-                    stride=output.get_stride(),
+                    size=[_maybe_expr(sz) for sz in merged_output.size()],
+                    stride=[_maybe_expr(sz) for sz in merged_output.stride()],
                     offset=output.get_layout().offset,
                 ),
                 conditional,
@@ -7408,14 +7355,29 @@
             )
             # as the true and false outputs are equivalent,
             # we can use either of them here as a "template"
-            for i, output in enumerate(true_outputs)
+            for i, (output, merged_output) in enumerate(
+                zip(true_outputs, V.graph.current_node.meta["val"])
+            )
         ]
-        conditional.outputs = outputs
+        conditional.outputs = outputs  # type: ignore[assignment]
         return outputs
 
     def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         wrapper.codegen_conditional(self)
+        wrapper.codegen_unbacked_symbol_defs_for_outputs(
+            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
+        )
+
+    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
+        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
+            resolved = resolve_unbacked_bindings(
+                V.graph.sizevars.shape_env, unbacked_bindings
+            )
+            assert resolved is not None
+            return resolved.keys()  # type: ignore[return-value]
+        else:
+            return OrderedSet()
+
 
 def _split_by_sym_type(
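
Putting the ir.py and wrapper.py pieces together, the Python wrapper for a cond with mismatched branch outputs would now contain a binding along the lines of the following sketch; the buffer and symbol names are invented for illustration and not taken from actual generated output.

    # Hypothetical generated wrapper fragment (buf2/u0 are invented names):
    #   buf2 = <outputs of the conditional>  # e.g. a list with one tensor
    #   u0 = buf2[0].size(0)                 # emitted by codegen_unbacked_symbol_defs_for_outputs
    # Downstream code can then allocate and check shapes in terms of u0, even
    # though the two branches produced different sizes (s0 - 2 vs. 2 in the
    # MismatchedOutputSize test model).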