[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and include it in the error messages. The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new signature. [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py) now extracts the operator name from the FX graph and passes it into `codegen_size_asserts` and `codegen_alignment_asserts`, so the assertions emitted in generated Triton code carry the op name for easier debugging.

Unit tests were added in [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py):
- Verified that both successful and failing assertion cases include the operator name.
- Verified that the generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel, https://github.com/shunting314
This commit is contained in:
parent cc96febb97
commit 10c3e6ec43
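As a quick illustration of the new behavior, here is a minimal sketch that mirrors the unit tests added in this PR (the string `torch.ops.dummy.op_name` is just the placeholder those tests use):

```python
import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

tensor = torch.empty((16, 32))

# The trailing op_name argument is optional and purely informational.
assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
assert_alignment(tensor, 16, "torch.ops.dummy.op_name")

# On failure, the AssertionError message now names the offending op.
try:
    assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
except AssertionError as e:
    print(e)  # message contains "Error in op: torch.ops.dummy.op_name"
```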
@@ -715,6 +715,13 @@ class TestFunctionalAutograd(MultiThreadedTestCase):
         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
+            assert_keywords = ["assert_size_stride", "assert_alignment"]
+            filtered_lines = [
+                line
+                for line in code.splitlines()
+                if not any(assert_key in line for assert_key in assert_keywords)
+            ]
+            code = "\n".join(filtered_lines)
             FileCheck().check_count(
                 "_c10d_functional.all_to_all_single.default", 1, exactly=True
             ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
@@ -231,6 +231,14 @@ class TestPatternMatcherBase(TestCase):
                 torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
                 *clone_inputs,
             )
+            assert_keywords = ["assert_size_stride", "assert_alignment"]
+            filtered_lines = [
+                line
+                for line in source_code.splitlines()
+                if not any(assert_key in line for assert_key in assert_keywords)
+            ]
+            source_code = "\n".join(filtered_lines)
+
             for op in include_ops:
                 self.assertIn(op, source_code)
             if num_include_ops is not None:
@@ -30,6 +30,7 @@ import torch
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
+from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1409,7 +1410,14 @@ class CommonTemplate:
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        self.assertEqual(
+        assert_keywords = ["assert_size_stride", "assert_alignment"]
+        filtered_lines = [
+            line
+            for line in code.splitlines()
+            if not any(assert_key in line for assert_key in assert_keywords)
+        ]
+        code = "\n".join(filtered_lines)
+        self.assertGreaterEqual(
             code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
         )
@@ -11923,6 +11931,98 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
             check_lowp=False,
         )
 
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_size_stride_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and len(code) > 0 and "assert_size_stride(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_size_stride\s*\(\s*[^,]+,\s*\([^\)]*\),\s*\([^\)]*\),\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_size_stride: {e}")
+                    print(code[0])
+                    raise e
+            else:
+                print("Skipping: No assert_size_stride found.")
+
+    @requires_gpu()
+    @skip_if_not_triton
+    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
+    @config.patch(implicit_fallbacks=True)
+    def test_generated_code_has_alignment_assert(self):
+        def foo(x):
+            return 3 * x
+
+        def foo_meta(x):
+            return torch.empty_like(x)
+
+        define_custom_op_for_test("foo", foo, foo_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.foo(a)
+            return b
+
+        a = torch.randn((16, 32), device=self.device)
+
+        _, code = run_and_get_code(
+            torch.compile(fn),
+            a,
+        )
+        if not is_dynamic_shape_enabled():
+            if code and len(code) > 0 and "assert_alignment(" in code[0]:
+                try:
+                    FileCheck().check_regex(
+                        r"assert_alignment\s*\(\s*[^,]+,\s*[^,]+,\s*'[^']+'\s*\)"
+                    ).run(code[0])
+                except Exception as e:
+                    print(f"Failed regex match for assert_alignment: {e}")
+                    print(code[0])
+                    raise e
+            else:
+                print("Skipping: No assert_alignment found.")
+
+    def test_assert_size_stride_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_size_stride_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_pass(self):
+        tensor = torch.empty((16, 32))
+        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
+
+    def test_assert_alignment_op_name_fail(self):
+        tensor = torch.empty((16, 32))
+        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
+            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
+
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13056,12 +13156,12 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         code = run_and_get_triton_code(f, x)
 
         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1))"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1)"
             ).run(code)
 
     @requires_cuda
@@ -176,6 +176,12 @@ def assert_size_stride(
     item: torch.Tensor,
     size: torch.types._size,
     stride: torch.types._size,
+    op_name: str | None = None,
 ): ...
+def assert_alignment(
+    item: torch.Tensor,
+    alignment: int,
+    op_name: str | None = None,
+): ...
 def check_obj_id(obj: object, expected: int) -> bool: ...
 def check_type_id(obj: object, expected: int) -> bool: ...
@@ -5818,6 +5818,17 @@ class ExternKernel(InputsKernel):
         ]
         return kwargs
 
+    def get_op_name(self) -> str:
+        if self.fx_node is not None:
+            target = self.fx_node.target
+            op_namespace = getattr(target, "__module__", "unknown_namespace")
+            op_namespace = op_namespace.replace("._ops.", ".ops.")
+            op_namespace = op_namespace.rsplit(".", 1)[0]
+            op_name = f"{op_namespace}.{target}"
+        else:
+            op_name = "unknown_op"
+        return op_name
+
     def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def]
         if config.size_asserts and not V.graph.cpp_wrapper:
             # comparing strides for 0 size tensor is tricky. Ignore them for now.
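For illustration only, a standalone sketch of the name derivation performed by `get_op_name` above; the example module string `torch._ops.aten` and target string `aten.relu.default` are assumptions for this sketch, not values taken from the diff:

```python
# Hypothetical inputs shaped like fx_node.target.__module__ and str(fx_node.target).
def derive_op_name(target_module: str, target_str: str) -> str:
    op_namespace = target_module.replace("._ops.", ".ops.")  # "torch._ops.aten" -> "torch.ops.aten"
    op_namespace = op_namespace.rsplit(".", 1)[0]            # "torch.ops.aten"  -> "torch.ops"
    return f"{op_namespace}.{target_str}"


print(derive_op_name("torch._ops.aten", "aten.relu.default"))  # torch.ops.aten.relu.default
```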
@@ -5825,19 +5836,24 @@ class ExternKernel(InputsKernel):
                 return
             size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
             stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
+
+            op_name = self.get_op_name()
             wrapper.writeline(
-                f"assert_size_stride({self.get_name()}, {size}, {stride})"
+                f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
             )
 
     def codegen_alignment_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def]
         if config.alignment_asserts and not V.graph.cpp_wrapper:
             name = self.get_name()
             aligned = name not in V.graph.unaligned_buffers
+            op_name = self.get_op_name()
             if aligned:
-                wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
+                wrapper.writeline(
+                    f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
+                )
             else:
-                wrapper.writeline(f"# buffer {name} is assumed to be not aligned")
+                wrapper.writeline(
+                    f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
+                )
 
     def get_group_stride(self): # type: ignore[no-untyped-def]
         """
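With the codegen above, the assertions in Inductor's generated wrapper code end with the repr of the op name. A hand-written approximation follows; the buffer, shapes, alignment value, and op string are illustrative, not taken from a real compilation:

```python
import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

buf1 = torch.empty_strided((16, 32), (32, 1))  # stand-in for an Inductor-allocated buffer
assert_size_stride(buf1, (16, 32), (32, 1), 'torch.ops.aten.relu.default')
assert_alignment(buf1, 16, 'torch.ops.aten.relu.default')  # 16 stands in for GPU_ALIGN_BYTES
```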
@@ -844,21 +844,38 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   PyObject* item = nullptr;
   PyObject* size = nullptr;
   PyObject* stride = nullptr;
-  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
-    PyErr_SetString(PyExc_TypeError, "expected tuple()");
+    std::stringstream msg;
+    msg << "expected tuple()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
-    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
+    std::stringstream msg;
+    msg << "wrong number of dimensions" << ndim;
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }
@@ -887,6 +904,9 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   }
 
   if (num_errors) {
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
     msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
     msg << "\nUse torch.library.opcheck to test your custom op.";
     msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";
@@ -904,15 +924,27 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   */
   PyObject* item = nullptr;
   unsigned long alignment = 0;
-  if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
+  const char* op_name = nullptr;
+
+  if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
+    std::stringstream msg;
+    msg << "expected Tensor()";
+    if (op_name) {
+      msg << " for op: " << op_name;
+    }
+    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
     return nullptr;
   }
   if (alignment == 0) {
-    PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
+    std::stringstream msg;
+    msg << "alignment cannot be 0";
+    if (op_name) {
+      msg << " in op: " << op_name;
+    }
+    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
     return nullptr;
   }
@@ -922,7 +954,10 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   size_t itemsize = tensor.itemsize();
   if (storage_offset * itemsize % alignment != 0) {
     std::stringstream msg;
-    msg << "Expect the tensor to be " << alignment
+    if (op_name) {
+      msg << "\nError in op: " << op_name;
+    }
+    msg << "\nExpect the tensor to be " << alignment
         << " bytes aligned. Fail due to storage_offset=" << storage_offset
         << " itemsize=" << itemsize;
     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());