[cond] support mismatched output in inductor (#147567)
In this PR, we extract FallbackKernel's `codegen_unbacked_symbol_defs` into a `codegen_unbacked_symbol_defs_for_outputs` method on the wrapper. With it, HOPs can support the case where a subgraph returns a tensor with unbacked symints. This PR only does it for cond; follow-up PRs will cover the others (e.g. while_loop). Pull Request resolved: https://github.com/pytorch/pytorch/pull/147567 Approved by: https://github.com/jansel
This commit is contained in:
parent d765077004
commit 2d2f60bdda
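As a concrete illustration of what this enables, here is a minimal sketch adapted from the MismatchedOutputSize test model added in this diff. The torch.compile driver at the end is illustrative and not part of the change; `p` is unused here, mirroring the test model (the AOTI harness prepends predicates to the inputs):

import torch

class MismatchedOutputSize(torch.nn.Module):
    def forward(self, p, x, y, z):
        a = y.shape[0]
        b = z.shape[0]

        def true_fn(x):
            # drops the first two rows -> size (s0 - 2, 20)
            return (x + a)[2:].sin()

        def false_fn(x):
            # keeps only the first two rows -> size (2, 20)
            return (x + b * z)[:2].cos()

        # The two branches return different sizes, so the cond output's
        # first dimension becomes an unbacked symint that inductor must
        # now define in the generated wrapper code.
        return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

# illustrative driver (not part of the diff)
m = torch.compile(MismatchedOutputSize(), dynamic=True)
x, y, z = torch.randn(10, 20), torch.randn(10, 20), torch.randn(10, 20)
out = m(torch.tensor(True), x, y, z)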
@@ -1362,6 +1362,33 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @common_utils.parametrize("dynamic", [False, True])
+    def test_cond_mismatched_branch_output(self, dynamic):
+        inputs = (
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+        )
+        dynamic_shapes = None
+        if dynamic:
+            # Note the minimum has to be 4 because the model
+            # is slicing over the first dim with [2:]; if the first
+            # dim is 2 or 3, the slicing will be 0/1 specialized,
+            # causing a constraint violation error.
+            dim0_a = Dim("s0", min=4, max=1024)
+            dim0_b = Dim("s1", min=4, max=1024)
+            dynamic_shapes = {
+                "p": {},
+                "x": {0: dim0_a, 1: None},
+                "y": {0: dim0_b, 1: None},
+                "z": {0: dim0_a, 1: None},
+            }
+        self.check_model_with_multiple_inputs(
+            CondModels.MismatchedOutputSize(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     def test_cond_symint_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
@@ -141,6 +141,12 @@ CPU_TEST_FAILURES = {
     ),
     # same issue as https://github.com/pytorch/pytorch/issues/122990
     "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True),
+    "test_cond_mismatched_branch_output_dynamic_True": fail_stack_allocation(
+        is_skip=True
+    ),
+    "test_cond_mismatched_branch_output_dynamic_False": fail_stack_allocation(
+        is_skip=True
+    ),
     # https://github.com/pytorch/pytorch/issues/122991
     "test_runtime_checks_complex": fail_stack_allocation(is_skip=True),
     "test_runtime_checks_fp8": fail_stack_allocation(is_skip=True),
@@ -196,6 +196,19 @@ class CondModels:
 
         return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
 
+    class MismatchedOutputSize(torch.nn.Module):
+        def forward(self, p, x, y, z):
+            a = y.shape[0]
+            b = z.shape[0]
+
+            def true_fn(x):
+                return (x + a)[2:].sin()
+
+            def false_fn(x):
+                return (x + b * z)[:2].cos()
+
+            return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
+
 
 class CondTests(TestCase):
     def _run_test(
@@ -643,6 +656,21 @@ class CondTests(TestCase):
         self.assertEqual(counters["pre_grad"], 11)
         self.assertEqual(counters["post_grad"], 11)
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    @parametrize("dynamic", [True, False])
+    def test_cond_mismatched_branch_output_size(self, device, dynamic):
+        self._run_test(
+            model=CondModels.MismatchedOutputSize(),
+            inputs=(
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+            ),
+            device=device,
+            dynamic=dynamic,
+        )
+
 
 class WhileLoopModels:
     class Simple(torch.nn.Module):
@@ -20,12 +20,19 @@ from sympy import Expr
 import torch
 import torch._ops
 import torch.utils._pytree as pytree
 from torch import dtype as torch_dtype
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.codegen.debug_utils import DebugPrinterManager
 from torch._inductor.codegen.multi_kernel import MultiKernelState
 from torch._inductor.runtime.runtime_utils import cache_dir
-from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
+from torch.fx.experimental.symbolic_shapes import (
+    CallMethodKey,
+    ConvertIntKey,
+    DivideByKey,
+    resolve_unbacked_bindings,
+    SymTypes,
+)
 from torch.fx.node import _get_qualified_name
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.singleton_int import SingletonInt
@@ -2487,6 +2494,86 @@ class PythonWrapperCodegen(CodeGen):
         self.unbacked_symbol_decls.add(name)
         return self.declare + name
 
+    def codegen_unbacked_symbol_defs_for_outputs(
+        self,
+        output_name: str,
+        outputs: Any,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+    ) -> None:
+        unbacked_bindings = resolve_unbacked_bindings(
+            V.graph.sizevars.shape_env, unbacked_bindings
+        )
+
+        if not unbacked_bindings:
+            return
+
+        # This code is designed to generate code expressions from symbolic paths (keypaths)
+        # associated with certain symbols (unbacked bindings). These keypaths describe how
+        # to access the unbacked symbol in a structured way.
+        # For example, we might want to generate "u0 = outs[0].stride(1)", where s = u0, and the keypath
+        # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey(1)].
+        for s, keypath in unbacked_bindings.items():
+            # `go` recursively constructs a code expression by processing each element of
+            # the keypath, building up the expression incrementally.
+            # For example, given output name "outs" and keypath [SequenceKey(0), CallMethodKey("stride"), SequenceKey(1)],
+            # it generates "outs[0]" from SequenceKey(0), then recursively calls go("outs[0]", [CallMethodKey("stride"), ...]).
+            def go(expr: str, keypath: pytree.KeyPath):
+                if keypath == ():
+                    return expr
+
+                if (
+                    len(keypath) >= 2
+                    and isinstance(keypath[0], CallMethodKey)
+                    and isinstance(keypath[1], pytree.SequenceKey)
+                ):
+                    return go(
+                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
+                    )
+                elif isinstance(keypath[0], CallMethodKey):
+                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
+                elif isinstance(keypath[0], pytree.SequenceKey):
+                    return (
+                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
+                        if V.graph.cpp_wrapper
+                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
+                    )
+                elif isinstance(keypath[0], DivideByKey):
+                    # TODO: need to assert divisibility
+                    # TODO: this is invalid C++ codegen
+                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
+                else:
+                    raise AssertionError(f"unrecognized keypath {keypath}")
+
+            # `go_outer` manages the top-level logic for generating the final expression.
+            # It handles special cases for C++ code generation and adjusts
+            # the keypath based on the context (e.g., single vs. multiple outputs).
+            def go_outer():  # type: ignore[no-untyped-def]
+                if V.graph.cpp_wrapper:
+                    # Special handling for the top level buffer access,
+                    # because self.get_name() is actually never bound; the
+                    # individual output arguments are bound by
+                    # generate_c_shim_fallback_kernel
+                    if len(outputs) == 1:
+                        out = outputs[0]
+                        # When the fallback kernel returns a list consisting of a single tensor,
+                        # the output is represented as a MultiOutput with non-empty indices.
+                        # In this case, we strip the first key path away.
+                        return go(
+                            outputs[0].get_name(),
+                            keypath[1:]
+                            if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
+                            else keypath,
+                        )
+                    else:
+                        assert isinstance(keypath[0], pytree.SequenceKey)
+                        return go(outputs[keypath[0].idx].get_name(), keypath[1:])
+                else:
+                    return go(output_name, keypath)
+
+            self.writeline(
+                f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}"
+            )
+
     def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs):
         # TODO (desertfire) - This function is the old way of supporting
         # subgraph codegen by inlining subgraphs in the output code. For python
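To make the keypath translation concrete, below is a small standalone sketch of how `go` builds an access expression. The Key classes are simplified stand-ins for torch's CallMethodKey and pytree.SequenceKey, and only the Python-wrapper path is shown:

from dataclasses import dataclass

@dataclass
class SequenceKey:
    idx: int

@dataclass
class CallMethodKey:
    name: str

def go(expr: str, keypath: tuple) -> str:
    if not keypath:
        return expr
    if (
        len(keypath) >= 2
        and isinstance(keypath[0], CallMethodKey)
        and isinstance(keypath[1], SequenceKey)
    ):
        # a method call that takes an index argument, e.g. ".stride(1)"
        return go(f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:])
    if isinstance(keypath[0], CallMethodKey):
        return go(f"{expr}.{keypath[0].name}()", keypath[1:])
    if isinstance(keypath[0], SequenceKey):
        return go(f"{expr}[{keypath[0].idx}]", keypath[1:])
    raise AssertionError(f"unrecognized keypath {keypath}")

# reproduces the "u0 = outs[0].stride(1)" example from the comment above
print(go("outs", (SequenceKey(0), CallMethodKey("stride"), SequenceKey(1))))
# -> outs[0].stride(1)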
@@ -47,9 +47,7 @@ from torch._prims_common import (
 )
 from torch._subclasses.fake_tensor import get_schema_info
 from torch.fx.experimental.symbolic_shapes import (
-    CallMethodKey,
     compute_unbacked_bindings,
-    DivideByKey,
     free_unbacked_symbols,
     rebind_unbacked,
     resolve_unbacked_bindings,
@@ -6488,72 +6486,10 @@ class FallbackKernel(ExternKernelAlloc):
             handle_aliasing_and_mutation(info, arg)
 
     def codegen_unbacked_symbol_defs(self, wrapper) -> None:  # type: ignore[no-untyped-def]
-        if not hasattr(self, "unbacked_bindings"):
-            return
-
-        unbacked_bindings = resolve_unbacked_bindings(
-            V.graph.sizevars.shape_env, self.unbacked_bindings
+        return wrapper.codegen_unbacked_symbol_defs_for_outputs(
+            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
         )
-
-        if not unbacked_bindings:
-            return
-
-        for s, keypath in unbacked_bindings.items():
-
-            def go(expr, keypath):  # type: ignore[no-untyped-def]
-                if keypath == ():
-                    return expr
-
-                if (
-                    len(keypath) >= 2
-                    and isinstance(keypath[0], CallMethodKey)
-                    and isinstance(keypath[1], pytree.SequenceKey)
-                ):
-                    return go(
-                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
-                    )
-                elif isinstance(keypath[0], CallMethodKey):
-                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
-                elif isinstance(keypath[0], pytree.SequenceKey):
-                    return (
-                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
-                        if V.graph.cpp_wrapper
-                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
-                    )
-                elif isinstance(keypath[0], DivideByKey):
-                    # TODO: need to assert divisibility
-                    # TODO: this is invalid C++ codegen
-                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
-                else:
-                    raise AssertionError(f"unrecognized keypath {keypath}")
-
-            def go_outer():  # type: ignore[no-untyped-def]
-                if V.graph.cpp_wrapper:
-                    # Special handling for the top level buffer access,
-                    # because self.get_name() is actually never bound; the
-                    # individual output arguments are bound by
-                    # generate_c_shim_fallback_kernel
-                    if len(self.outputs) == 1:
-                        out = self.outputs[0]
-                        # When the fallback kernel returns a list consisting of a single tensor,
-                        # the output is represented as a MultiOutput with non-empty indices.
-                        # In this case, we strip the first key path away.
-                        return go(
-                            self.outputs[0].get_name(),
-                            keypath[1:]
-                            if isinstance(out, MultiOutput) and len(out.indices) != 0
-                            else keypath,
-                        )
-                    else:
-                        assert isinstance(keypath[0], pytree.SequenceKey)
-                        return go(self.outputs[keypath[0].idx].get_name(), keypath[1:])
-                else:
-                    return go(self.get_name(), keypath)
-
-            wrapper.writeline(
-                f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}"
-            )
 
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         if unbacked_bindings := getattr(self, "unbacked_bindings", None):
             resolved = resolve_unbacked_bindings(
@@ -7319,6 +7255,7 @@ class Conditional(ExternKernel):
         true_subgraph: Subgraph,
         false_subgraph: Subgraph,
         layout: MultiOutputLayout,
+        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
     ) -> None:
         self.predicate = predicate
         self.operands = operands
@@ -7333,6 +7270,8 @@
             inputs=tensor_args,
             constant_args=sym_args,
         )
+        if unbacked_bindings is not None:
+            self.unbacked_bindings = unbacked_bindings
 
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)
@@ -7374,8 +7313,6 @@
         # make sure true and false outputs are structurally equivalent
         assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
         for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
-            assert to.get_size() == fo.get_size(), (i, to, fo)
-            assert to.get_stride() == fo.get_stride(), (i, to, fo)
             assert to.get_device() == fo.get_device(), (i, to, fo)
             assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
             assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
@@ -7385,6 +7322,10 @@
             for o in [predicate] + operands
             if not isinstance(o, ShapeAsConstantBuffer)
         )
+        unbacked_bindings = resolve_unbacked_bindings(
+            V.graph.sizevars.shape_env,
+            V.graph.current_node.meta.get("unbacked_bindings", None),
+        )
         assert device is not None, "cannot determine device"
         conditional = Conditional(
             predicate=predicate,
@@ -7392,15 +7333,21 @@
             true_subgraph=true_fn,
             false_subgraph=false_fn,
             layout=MultiOutputLayout(device=device),
+            unbacked_bindings=unbacked_bindings,
         )
 
+        def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
+            if isinstance(s, int):
+                return s
+            return s.node.expr
+
         outputs = [
             MultiOutput(
                 FixedLayout(
                     device=output.get_device(),
                     dtype=output.get_dtype(),
-                    size=output.get_size(),
-                    stride=output.get_stride(),
+                    size=[_maybe_expr(sz) for sz in merged_output.size()],
+                    stride=[_maybe_expr(sz) for sz in merged_output.stride()],
                     offset=output.get_layout().offset,
                 ),
                 conditional,
@@ -7408,14 +7355,29 @@
             )
             # as the true and false outputs are equivalent,
             # we can use either of them here as a "template"
-            for i, output in enumerate(true_outputs)
+            for i, (output, merged_output) in enumerate(
+                zip(true_outputs, V.graph.current_node.meta["val"])
+            )
         ]
 
-        conditional.outputs = outputs
+        conditional.outputs = outputs  # type: ignore[assignment]
         return outputs
 
     def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         wrapper.codegen_conditional(self)
+        wrapper.codegen_unbacked_symbol_defs_for_outputs(
+            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
+        )
+
+    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
+        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
+            resolved = resolve_unbacked_bindings(
+                V.graph.sizevars.shape_env, unbacked_bindings
+            )
+            assert resolved is not None
+            return resolved.keys()  # type: ignore[return-value]
+        else:
+            return OrderedSet()
 
 
 def _split_by_sym_type(
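To observe the effect end to end, one hedged recipe is run_and_get_code, the helper inductor's own tests use to capture generated wrapper code. The model below is a minimal cond with mismatched branch outputs; the exact buffer and symbol names (buf0, u0, ...) in the captured code are illustrative and vary by trace:

import torch
from torch._inductor.utils import run_and_get_code

def f(p, x):
    def true_fn(x):
        return x[2:].sin()

    def false_fn(x):
        return x[:2].cos()

    # branch outputs have different first dimensions
    return torch.cond(p, true_fn, false_fn, (x,))

compiled = torch.compile(f, dynamic=True)
_, codes = run_and_get_code(compiled, torch.tensor(True), torch.randn(10, 20))
# expect a wrapper line defining the unbacked output size, e.g. "u0 = ...size(0)"
print(codes[0])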