[Inductor] Subgraph check output strides (#153755)

Make sure the output strides of a subgraph are consistent with those of the original gm. Without checking strides, the subgraph could produce NaNs: a reinterpret tensor was placed on the subgraph's output even though that output itself was not contiguous.
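A minimal sketch of the failure mode (illustrative only, not code from this PR): a (256, 1096) buffer allocated with the padded stride (1152, 1) leaves uninitialized storage gaps in each row, and reinterpreting that storage as contiguous pulls the gap values into the visible data.

import torch

# Illustration only: shape (256, 1096) with padded stride (1152, 1) leaves
# 56 uninitialized storage elements per row.
padded = torch.empty_strided((256, 1096), (1152, 1), dtype=torch.float32)
padded.fill_(1.0)  # only the logical (strided) elements are written

# Reinterpreting the same storage with contiguous strides (1096, 1) exposes
# the uninitialized gap elements; this is how a reinterpret tensor on a
# non-contiguous subgraph output can surface garbage or NaN values.
reinterpreted = padded.as_strided((256, 1096), (1096, 1))
print(torch.equal(reinterpreted, torch.ones_like(reinterpreted)))  # generally False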

Differential Revision: [D74691119](https://our.internmc.facebook.com/intern/diff/D74691119/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153755
Approved by: https://github.com/eellison
ghstack dependencies: #153754
PaulZhang12 2025-05-19 22:58:34 -07:00 committed by PyTorch MergeBot
parent 63e5d46478
commit a7c01d7f13
5 changed files with 82 additions and 5 deletions

View File

@@ -8,6 +8,7 @@ import tempfile
import unittest
from typing import Callable, Optional
from unittest import mock
from unittest.mock import MagicMock
import torch
from torch import multiprocessing as mp, nn
@@ -1042,6 +1043,49 @@ class TestMaxAutotune(TestCase):
            rtol=1e-2,
        )

    @skipIfXpu
    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
    @unittest.skipIf(
        config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
    )
    @config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    )
    def test_max_autotune_decompose_k_output_stride(self):
        def f(a, b):
            a = a.transpose(0, 1)
            return a @ b

        a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16)
        b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16)
        b = b[:, :1096]

        # Force only decomposeK choice
        with mock.patch(
            "torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
        ) as base_mm_mock, mock.patch(
            "torch._inductor.kernel.mm.use_decompose_k_choice"
        ) as decompose_mock:
            mm_configs_mock = MagicMock()
            mm_configs_mock.return_value = []
            base_mm_mock.return_value = mm_configs_mock
            decompose_mock.return_value = True

            compiled_f = torch.compile(f)
            out, code = run_and_get_code(compiled_f, a, b)

            # Output stride equal to original gm output stride
            # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
            self.assertEqual(out.stride(), (1096, 1))

            FileCheck().check_not("extern_kernels.bmm_dtype").check(
                "decompose_k"
            ).check(" empty_strided_cuda((256, 1096), (1096, 1), torch.bfloat16)").run(
                code[0]
            )


class TestMaxAutotunePrecompile(TestCase):
    def test_precompilation_threads(self):

View File

@@ -9,6 +9,7 @@ from torch._inductor.ir import (
    add_symbolic_shapes_for_inputs_to_subgraph,
    Buffer,
    get_free_symbols,
    gm_original_output_strides,
    ir_node_to_tensor,
    Layout,
)
@@ -57,6 +58,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
        import torch._inductor.config as inductor_config
        from torch._inductor.graph import GraphLowering

        gm_original_output_strides(self.gm)
        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
@@ -77,6 +79,17 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
            int(V.graph.sizevars.shape_env.size_hint(sym_var)) for sym_var in sym_inputs
        ]

        if len(sym_inputs) == 0:
            # Sanity check that args are same layout as example inputs
            # Only do it if there are no symbolic inputs, otherwise
            # the dynamic dim will be realized to the same size as args
            for ar, example_inp in zip(args, self.example_inputs):
                # Sanity check that args are same layout as example inputs
                if isinstance(ar, torch.Tensor):
                    assert isinstance(example_inp, torch.Tensor)
                    assert ar.shape == example_inp.shape
                    assert ar.stride() == example_inp.stride()

        if len(sym_inputs) == 0:
            # Sanity check that args are same layout as example inputs
            # Only do it if there are no symbolic inputs, otherwise

View File

@@ -219,7 +219,13 @@ def get_static_input_idxs(num_fixed: int) -> list[int]:
def record_original_output_strides(gm: GraphModule) -> None:
    output_node = gm.graph.find_nodes(op="output")[0]
    output_strides = []
    for output in output_node.args[0]:
    if not isinstance(output_node.args[0], torch.fx.Node):
        output_node_args = output_node.args[0]
    else:
        output_node_args = output_node.args

    for output in output_node_args:
        if (
            isinstance(output, torch.fx.Node)
            and (val := output.meta.get("val")) is not None
View File

@@ -187,7 +187,12 @@ def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
    if "user_visible_output_idxs" not in output_node.meta:
        return ret

    for idx, node in enumerate(output_node.args[0]):
    if not isinstance(output_node.args[0], torch.fx.Node):
        output_node_args = output_node.args[0]
    else:
        output_node_args = output_node.args

    for idx, node in enumerate(output_node_args):
        if idx in output_node.meta["user_visible_output_idxs"]:
            ret[node] = output_node.meta["original_output_strides"][idx]
    return ret

View File

@@ -490,6 +490,16 @@ def try_match_insignificant_strides(
    return TensorBox(ReinterpretView(data=storage, layout=new_layout))


def gm_original_output_strides(gm: torch.fx.GraphModule) -> None:
    output_node = gm.graph.find_nodes(op="output")[0]
    output_node.meta["user_visible_output_idxs"] = [
        idx for idx, _ in enumerate(output_node.args)
    ]

    from torch._inductor.compile_fx import record_original_output_strides
    record_original_output_strides(gm)


def add_symbolic_shapes_for_inputs_to_subgraph(
    inputs: list[Buffer], subgraph: GraphLowering
) -> list[Expr]:
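A rough usage sketch of the new helper (my own example, not part of the diff): it stamps the gm's output node with user_visible_output_idxs and original_output_strides so that a later GraphLowering of the subgraph can match the original output layout. I assume a gm produced by make_fx here, since it populates the node.meta["val"] entries that record_original_output_strides reads.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.ir import gm_original_output_strides


def f(x):
    return x.t()  # non-contiguous (transposed) output


# make_fx populates node.meta["val"], which record_original_output_strides reads.
gm = make_fx(f, tracing_mode="fake")(torch.randn(4, 8))
gm_original_output_strides(gm)

output_node = gm.graph.find_nodes(op="output")[0]
print(output_node.meta["user_visible_output_idxs"])  # [0]
print(output_node.meta["original_output_strides"])   # [(1, 8)] for the 8x4 transposed result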
@@ -6095,9 +6105,8 @@ class SubgraphBuffer(ExternKernel):
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

        self.subgraph = V.graph.make_subgraph(
            self.gm, self.example_inputs, subgraph_name
        )
        gm_original_output_strides(self.gm)
        self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name)

        sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph(
            self.inputs, self.subgraph