[Inductor] Subgraph check output strides (#153755)
Make sure the output strides of the subgraph are consistent with those of the original gm. Without checking strides, the subgraph could produce NaNs when a reinterpret_tensor was applied to a subgraph output that was not itself contiguous.
Differential Revision: [D74691119](https://our.internmc.facebook.com/intern/diff/D74691119/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153755
Approved by: https://github.com/eellison
ghstack dependencies: #153754
commit a7c01d7f13 (parent 63e5d46478)
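The core of the issue: when a subgraph choice such as decompose_k lowers a matmul whose input is a non-contiguous slice, the buffer it produces must keep the strides the original graph's output would have had; otherwise the reinterpret_tensor placed on the subgraph output reads the storage with the wrong layout and can surface NaNs. A minimal eager-mode sketch (not part of the commit) of the stride relationship, using the shapes from the new test; CPU float32 is used here so the snippet runs without a GPU, whereas the test uses CUDA bfloat16:

import torch

a = torch.randn(32768, 256)
b = torch.randn(32768, 1152)[:, :1096]  # non-contiguous slice: stride (1152, 1)

out = a.transpose(0, 1) @ b  # eager matmul allocates a fresh, contiguous result

print(b.stride())    # (1152, 1) -- inherited from the 1152-column storage
print(out.stride())  # (1096, 1) -- the stride the subgraph's output must reproduce

The diff below records the original graph's output strides (gm_original_output_strides) and makes the subgraph lowering honor them.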
@@ -8,6 +8,7 @@ import tempfile
 import unittest
 from typing import Callable, Optional
 from unittest import mock
+from unittest.mock import MagicMock
 
 import torch
 from torch import multiprocessing as mp, nn

@@ -1042,6 +1043,49 @@ class TestMaxAutotune(TestCase):
             rtol=1e-2,
         )
 
+    @skipIfXpu
+    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+    @unittest.skipIf(
+        config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
+    )
+    @config.patch(
+        max_autotune=True,
+        max_autotune_gemm_backends="TRITON",
+        autotune_fallback_to_aten=False,
+    )
+    def test_max_autotune_decompose_k_output_stride(self):
+        def f(a, b):
+            a = a.transpose(0, 1)
+            return a @ b
+
+        a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16)
+        b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16)
+
+        b = b[:, :1096]
+
+        # Force only decomposeK choice
+        with mock.patch(
+            "torch._inductor.kernel.mm.V.choices.get_base_mm_configs"
+        ) as base_mm_mock, mock.patch(
+            "torch._inductor.kernel.mm.use_decompose_k_choice"
+        ) as decompose_mock:
+            mm_configs_mock = MagicMock()
+            mm_configs_mock.return_value = []
+            base_mm_mock.return_value = mm_configs_mock
+            decompose_mock.return_value = True
+            compiled_f = torch.compile(f)
+            out, code = run_and_get_code(compiled_f, a, b)
+
+        # Output stride equal to original gm output stride
+        # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
+        self.assertEqual(out.stride(), (1096, 1))
+
+        FileCheck().check_not("extern_kernels.bmm_dtype").check(
+            "decompose_k"
+        ).check(" empty_strided_cuda((256, 1096), (1096, 1), torch.bfloat16)").run(
+            code[0]
+        )
+
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):

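A CPU-only distillation (not part of the commit) of the property the test asserts: the compiled function's output strides should match the eager output strides. The real test additionally forces the decompose_k choice via mock.patch, runs on CUDA with bfloat16, and checks the generated code with FileCheck; the smaller shapes below are hypothetical stand-ins.

import torch

def f(a, b):
    return a.transpose(0, 1) @ b

a = torch.randn(512, 64)
b = torch.randn(512, 96)[:, :80]  # non-contiguous slice, analogous to b[:, :1096] in the test

eager_out = f(a, b)
compiled_out = torch.compile(f)(a, b)

# Output strides of the compiled graph should match the original (eager) graph.
assert compiled_out.stride() == eager_out.stride() == (80, 1)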
@@ -9,6 +9,7 @@ from torch._inductor.ir import (
     add_symbolic_shapes_for_inputs_to_subgraph,
     Buffer,
     get_free_symbols,
+    gm_original_output_strides,
     ir_node_to_tensor,
     Layout,
 )

@@ -57,6 +58,7 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
         import torch._inductor.config as inductor_config
         from torch._inductor.graph import GraphLowering
 
+        gm_original_output_strides(self.gm)
         bm_graph_lowering = GraphLowering(
             gm=self.gm,
             example_inputs=self.example_inputs,

@@ -77,6 +79,17 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
             int(V.graph.sizevars.shape_env.size_hint(sym_var)) for sym_var in sym_inputs
         ]
 
+        if len(sym_inputs) == 0:
+            # Sanity check that args are same layout as example inputs
+            # Only do it if there are no symbolic inputs, otherwise
+            # the dynamic dim will be realized to the same size as args
+            for ar, example_inp in zip(args, self.example_inputs):
+                # Sanity check that args are same layout as example inputs
+                if isinstance(ar, torch.Tensor):
+                    assert isinstance(example_inp, torch.Tensor)
+                    assert ar.shape == example_inp.shape
+                    assert ar.stride() == example_inp.stride()
+
         if len(sym_inputs) == 0:
             # Sanity check that args are same layout as example inputs
             # Only do it if there are no symbolic inputs, otherwise

@@ -219,7 +219,13 @@ def get_static_input_idxs(num_fixed: int) -> list[int]:
 def record_original_output_strides(gm: GraphModule) -> None:
     output_node = gm.graph.find_nodes(op="output")[0]
     output_strides = []
-    for output in output_node.args[0]:
+
+    if not isinstance(output_node.args[0], torch.fx.Node):
+        output_node_args = output_node.args[0]
+    else:
+        output_node_args = output_node.args
+
+    for output in output_node_args:
         if (
             isinstance(output, torch.fx.Node)
             and (val := output.meta.get("val")) is not None

@@ -187,7 +187,12 @@ def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
     if "user_visible_output_idxs" not in output_node.meta:
         return ret
 
-    for idx, node in enumerate(output_node.args[0]):
+    if not isinstance(output_node.args[0], torch.fx.Node):
+        output_node_args = output_node.args[0]
+    else:
+        output_node_args = output_node.args
+
+    for idx, node in enumerate(output_node_args):
         if idx in output_node.meta["user_visible_output_idxs"]:
             ret[node] = output_node.meta["original_output_strides"][idx]
     return ret

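Both of the preceding hunks add the same guard: output_node.args[0] may itself be a single torch.fx.Node rather than a tuple or list of Nodes, in which case iterating output_node.args (a one-element tuple) is the correct fallback. An illustrative sketch (not part of the commit) of when each case shows up for symbolically traced graphs:

import torch
import torch.fx

def single(x):
    return x + 1

def multiple(x):
    return x + 1, x * 2

single_out = torch.fx.symbolic_trace(single).graph.find_nodes(op="output")[0]
multi_out = torch.fx.symbolic_trace(multiple).graph.find_nodes(op="output")[0]

print(type(single_out.args[0]))  # <class 'torch.fx.node.Node'> -- a bare Node
print(type(multi_out.args[0]))   # <class 'tuple'> of Nodes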
@@ -490,6 +490,16 @@ def try_match_insignificant_strides(
     return TensorBox(ReinterpretView(data=storage, layout=new_layout))
 
 
+def gm_original_output_strides(gm: torch.fx.GraphModule) -> None:
+    output_node = gm.graph.find_nodes(op="output")[0]
+    output_node.meta["user_visible_output_idxs"] = [
+        idx for idx, _ in enumerate(output_node.args)
+    ]
+    from torch._inductor.compile_fx import record_original_output_strides
+
+    record_original_output_strides(gm)
+
+
 def add_symbolic_shapes_for_inputs_to_subgraph(
     inputs: list[Buffer], subgraph: GraphLowering
 ) -> list[Expr]:

@@ -6095,9 +6105,8 @@ class SubgraphBuffer(ExternKernel):
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)
 
-        self.subgraph = V.graph.make_subgraph(
-            self.gm, self.example_inputs, subgraph_name
-        )
+        gm_original_output_strides(self.gm)
+        self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name)
 
         sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph(
             self.inputs, self.subgraph