[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)

Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and include it in their error messages.
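
A minimal usage sketch of the new Python-facing signatures (the op-name string is just an illustrative placeholder; only the trailing `op_name` argument is new):

```python
import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

t = torch.empty((16, 32))

# Matching size/stride and alignment: op_name is simply carried along.
assert_size_stride(t, (16, 32), (32, 1), "torch.ops.test.foo.default")
assert_alignment(t, 16, "torch.ops.test.foo.default")

# A mismatch now names the offending op in the AssertionError.
try:
    assert_size_stride(t, (32, 64), (32, 1), "torch.ops.test.foo.default")
except AssertionError as e:
    print(e)  # message contains "Error in op: torch.ops.test.foo.default"
```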

The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new signatures.

[inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py) now extracts the operator name from the FX graph and passes it to `codegen_size_asserts` and `codegen_alignment_asserts`, so the assertions emitted in the generated Triton wrapper code include the op name for easier debugging.
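
As a sketch of how to see the new assertions, one can compile a function that calls a custom op (which Inductor handles as a fallback kernel) and grep the generated wrapper code. The `mylib::double` op, the GPU device, and the exact output lines below are illustrative assumptions, not output guaranteed by this PR:

```python
import torch
from torch._inductor.utils import run_and_get_code

# Hypothetical custom op, used only to trigger an Inductor fallback kernel.
@torch.library.custom_op("mylib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@double.register_fake
def _(x):
    return torch.empty_like(x)

def fn(x):
    return double(torch.relu(x))

# Run on a GPU to mirror the PR's tests (Triton wrapper, default Python wrapper codegen).
x = torch.randn(16, 32, device="cuda")
_, code = run_and_get_code(torch.compile(fn), x)

# Expect lines along these lines (illustrative):
#   assert_size_stride(buf1, (16, 32), (32, 1), 'torch.ops.mylib.double.default')
#   assert_alignment(buf1, 16, 'torch.ops.mylib.double.default')
print("\n".join(l for l in code[0].splitlines() if "assert_" in l))
```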

Added unit tests in [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py).
- Verified that both passing and failing assertion cases include the operator name.
- Verified that the generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel, https://github.com/shunting314
karthickai 2025-06-03 19:21:12 +00:00 committed by PyTorch MergeBot
parent cc96febb97
commit 10c3e6ec43
6 changed files with 189 additions and 17 deletions

View File

@@ -715,6 +715,13 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
_, codes = run_and_get_code(run_with_backward)
for code in codes:
assert_keywords = ["assert_size_stride", "assert_alignment"]
filtered_lines = [
line
for line in code.splitlines()
if not any(assert_key in line for assert_key in assert_keywords)
]
code = "\n".join(filtered_lines)
FileCheck().check_count(
"_c10d_functional.all_to_all_single.default", 1, exactly=True
).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(

View File

@@ -231,6 +231,14 @@ class TestPatternMatcherBase(TestCase):
torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
*clone_inputs,
)
assert_keywords = ["assert_size_stride", "assert_alignment"]
filtered_lines = [
line
for line in source_code.splitlines()
if not any(assert_key in line for assert_key in assert_keywords)
]
source_code = "\n".join(filtered_lines)
for op in include_ops:
self.assertIn(op, source_code)
if num_include_ops is not None:

View File

@@ -30,6 +30,7 @@ import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.aoti_eager
import torch.nn as nn
from torch._C._dynamo.guards import assert_alignment, assert_size_stride
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.device_interface import get_interface_for_device
@@ -1409,7 +1410,14 @@ class CommonTemplate:
)
_, code = run_and_get_code(fn, x, y)
code = " ".join(code)
self.assertEqual(
assert_keywords = ["assert_size_stride", "assert_alignment"]
filtered_lines = [
line
for line in code.splitlines()
if not any(assert_key in line for assert_key in assert_keywords)
]
code = "\n".join(filtered_lines)
self.assertGreaterEqual(
code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
)
@@ -11923,6 +11931,98 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=False,
)
@requires_gpu()
@skip_if_not_triton
@skip_if_cpp_wrapper("skip cpp_wrapper tests")
@config.patch(implicit_fallbacks=True)
def test_generated_code_has_size_stride_assert(self):
def foo(x):
return 3 * x
def foo_meta(x):
return torch.empty_like(x)
define_custom_op_for_test("foo", foo, foo_meta)
def fn(x):
a = torch.nn.functional.relu(x)
b = torch.ops.test.foo(a)
return b
a = torch.randn((16, 32), device=self.device)
_, code = run_and_get_code(
torch.compile(fn),
a,
)
if not is_dynamic_shape_enabled():
if code and len(code) > 0 and "assert_size_stride(" in code[0]:
try:
FileCheck().check_regex(
r"assert_size_stride\s*\(\s*[^,]+,\s*\([^\)]*\),\s*\([^\)]*\),\s*'[^']+'\s*\)"
).run(code[0])
except Exception as e:
print(f"Failed regex match for assert_size_stride: {e}")
print(code[0])
raise e
else:
print("Skipping: No assert_size_stride found.")
@requires_gpu()
@skip_if_not_triton
@skip_if_cpp_wrapper("skip cpp_wrapper tests")
@config.patch(implicit_fallbacks=True)
def test_generated_code_has_alignment_assert(self):
def foo(x):
return 3 * x
def foo_meta(x):
return torch.empty_like(x)
define_custom_op_for_test("foo", foo, foo_meta)
def fn(x):
a = torch.nn.functional.relu(x)
b = torch.ops.test.foo(a)
return b
a = torch.randn((16, 32), device=self.device)
_, code = run_and_get_code(
torch.compile(fn),
a,
)
if not is_dynamic_shape_enabled():
if code and len(code) > 0 and "assert_alignment(" in code[0]:
try:
FileCheck().check_regex(
r"assert_alignment\s*\(\s*[^,]+,\s*[^,]+,\s*'[^']+'\s*\)"
).run(code[0])
except Exception as e:
print(f"Failed regex match for assert_alignment: {e}")
print(code[0])
raise e
else:
print("Skipping: No assert_alignment found.")
def test_assert_size_stride_op_name_pass(self):
tensor = torch.empty((16, 32))
assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
def test_assert_size_stride_op_name_fail(self):
tensor = torch.empty((16, 32))
with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
def test_assert_alignment_op_name_pass(self):
tensor = torch.empty((16, 32))
assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
def test_assert_alignment_op_name_fail(self):
tensor = torch.empty((16, 32))
with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_custom_op_unbacked_symints(self):
@@ -13056,12 +13156,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
code = run_and_get_triton_code(f, x)
if is_dynamic_shape_enabled():
FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
"assert_size_stride(buf2, (s77, s27), (s27, 1))"
FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
"assert_size_stride(buf2, (s77, s27), (s27, 1)"
).run(code)
else:
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
"assert_size_stride(buf2, (16, 32), (32, 1))"
FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
"assert_size_stride(buf2, (16, 32), (32, 1)"
).run(code)
@requires_cuda

View File

@@ -176,6 +176,12 @@ def assert_size_stride(
item: torch.Tensor,
size: torch.types._size,
stride: torch.types._size,
op_name: str | None = None,
): ...
def assert_alignment(
item: torch.Tensor,
alignment: int,
op_name: str | None = None,
): ...
def check_obj_id(obj: object, expected: int) -> bool: ...
def check_type_id(obj: object, expected: int) -> bool: ...

View File

@@ -5818,6 +5818,17 @@ class ExternKernel(InputsKernel):
]
return kwargs
def get_op_name(self) -> str:
if self.fx_node is not None:
target = self.fx_node.target
op_namespace = getattr(target, "__module__", "unknown_namespace")
op_namespace = op_namespace.replace("._ops.", ".ops.")
op_namespace = op_namespace.rsplit(".", 1)[0]
op_name = f"{op_namespace}.{target}"
else:
op_name = "unknown_op"
return op_name
def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def]
if config.size_asserts and not V.graph.cpp_wrapper:
# comparing strides for 0 size tensor is tricky. Ignore them for now.
@@ -5825,19 +5836,24 @@ class ExternKernel(InputsKernel):
return
size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
op_name = self.get_op_name()
wrapper.writeline(
f"assert_size_stride({self.get_name()}, {size}, {stride})"
f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
)
def codegen_alignment_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def]
if config.alignment_asserts and not V.graph.cpp_wrapper:
name = self.get_name()
aligned = name not in V.graph.unaligned_buffers
op_name = self.get_op_name()
if aligned:
wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
wrapper.writeline(
f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
)
else:
wrapper.writeline(f"# buffer {name} is assumed to be not aligned")
wrapper.writeline(
f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
)
def get_group_stride(self): # type: ignore[no-untyped-def]
"""

View File

@@ -844,21 +844,38 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
PyObject* item = nullptr;
PyObject* size = nullptr;
PyObject* stride = nullptr;
if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
const char* op_name = nullptr;
if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
return nullptr;
}
if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
std::stringstream msg;
msg << "expected Tensor()";
if (op_name) {
msg << " for op: " << op_name;
}
PyErr_SetString(PyExc_TypeError, msg.str().c_str());
return nullptr;
}
if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
PyErr_SetString(PyExc_TypeError, "expected tuple()");
std::stringstream msg;
msg << "expected tuple()";
if (op_name) {
msg << " for op: " << op_name;
}
PyErr_SetString(PyExc_TypeError, msg.str().c_str());
return nullptr;
}
at::Tensor tensor = THPVariable_Unpack(item);
int64_t ndim = tensor.ndimension();
if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
std::stringstream msg;
msg << "wrong number of dimensions" << ndim;
if (op_name) {
msg << " for op: " << op_name;
}
PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
return nullptr;
}
@@ -887,6 +904,9 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
}
if (num_errors) {
if (op_name) {
msg << "\nError in op: " << op_name;
}
msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
msg << "\nUse torch.library.opcheck to test your custom op.";
msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";
@@ -904,15 +924,27 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
*/
PyObject* item = nullptr;
unsigned long alignment = 0;
if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
const char* op_name = nullptr;
if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
return nullptr;
}
if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
std::stringstream msg;
msg << "expected Tensor()";
if (op_name) {
msg << " for op: " << op_name;
}
PyErr_SetString(PyExc_TypeError, msg.str().c_str());
return nullptr;
}
if (alignment == 0) {
PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
std::stringstream msg;
msg << "alignment cannot be 0";
if (op_name) {
msg << " in op: " << op_name;
}
PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
return nullptr;
}
@@ -922,7 +954,10 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
size_t itemsize = tensor.itemsize();
if (storage_offset * itemsize % alignment != 0) {
std::stringstream msg;
msg << "Expect the tensor to be " << alignment
if (op_name) {
msg << "\nError in op: " << op_name;
}
msg << "\nExpect the tensor to be " << alignment
<< " bytes aligned. Fail due to storage_offset=" << storage_offset
<< " itemsize=" << itemsize;
PyErr_SetString(PyExc_AssertionError, msg.str().c_str());