[Inductor] fix alignment assumption for fallback (#150777)

Inductor right now only works properly for fallback kernels producing aligned output.
When Inductor creates the layout for a fallback kernel's output, it does not add the tensor offset to the layout [link](2a1e2b88ed/torch/_inductor/ir.py (L6935-L6941)). Thus an unaligned output is treated as aligned. Adding the offset to the layout directly does not work, since that changes the index expression in the generated kernel and we may apply the offset twice: Triton already accounts for the offset when the data_ptr is passed in.

To solve this issue, we track the unaligned buffer names instead.

This can potentially fix the internal issues we are debugging here: https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/

Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777
Approved by: https://github.com/eellison, https://github.com/jansel
This commit is contained in:
Shunting Zhang 2025-04-07 13:50:00 -07:00 committed by PyTorch MergeBot
parent c36d9b0d8d
commit 901b02cf16
7 changed files with 99 additions and 7 deletions

View File

@ -12895,6 +12895,68 @@ class CommonTemplate:
self.common(fn, (1, x))
self.common(fn, (2, x))
def test_unaligned_input(self):
    """Run ReLU through ``self.common`` on a sliced 1-D input.

    The input is the ``[1:-15]`` slice of a ``1024 + 16`` element tensor,
    so the view starts one element past the beginning of its base tensor.
    """

    def fn(x):
        return torch.nn.functional.relu(x)

    base = torch.randn(1024 + 16, device=self.device)
    sliced = base[1:-15]
    self.common(fn, (sliced,), check_lowp=False)
def test_unaligned_input_2d(self):
    """Run ReLU through ``self.common`` on a column-sliced 2-D input.

    The input keeps columns ``[1:-15]`` of a ``1024 x (1024 + 16)`` tensor,
    so every row of the view starts one element into the base tensor's row.
    """

    def fn(x):
        return torch.nn.functional.relu(x)

    base = torch.randn(1024, 1024 + 16, device=self.device)
    sliced = base[:, 1:-15]
    self.common(fn, (sliced,), check_lowp=False)
def test_alignment_without_custom_op(self):
    """Run relu -> scale-by-3 -> slice ``[1:-15]`` -> cos through ``self.common``.

    Here the slice is taken inside the function under test with plain
    tensor indexing; no custom op is registered.
    """

    def fn(x):
        activated = torch.nn.functional.relu(x)
        scaled = 3 * activated
        # The slice drops the first element and the last 15 elements.
        return torch.cos(scaled[1:-15])

    full = torch.randn(1024 + 16, device=self.device)
    self.common(fn, (full,), check_lowp=False)
@config.patch(implicit_fallbacks=True)
def test_no_align_for_custom_op(self):
    """Run relu -> ``test.slice1d`` custom op -> cos through ``self.common``.

    The custom op returns the ``[1:-15]`` slice of ``3 * x``, so its output
    starts one element into the tensor it was sliced from. The test runs
    with ``implicit_fallbacks`` patched to ``True``.
    """

    def slice1d(x):
        tripled = 3 * x
        return tripled[1:-15]

    def slice1d_meta(x):
        # Meta implementation: same slicing applied to an empty_like tensor.
        placeholder = torch.empty_like(x)
        return placeholder[1:-15]

    define_custom_op_for_test("slice1d", slice1d, slice1d_meta)

    def fn(x):
        activated = torch.nn.functional.relu(x)
        sliced = torch.ops.test.slice1d(activated)
        return torch.cos(sliced)

    full = torch.randn(1024 + 16, device=self.device)
    self.common(fn, (full,), check_lowp=False)
@config.patch(implicit_fallbacks=True)
def test_no_align_for_custom_op_2d(self):
    """Run relu -> ``test.slice2d`` custom op -> cos through ``self.common``.

    The custom op returns ``(3 * x)[..., 1:-15]``, i.e. it slices the last
    dimension so each row of the output starts one element into the row it
    was sliced from. The test runs with ``implicit_fallbacks`` patched to
    ``True``.
    """

    def slice2d(x):
        tripled = 3 * x
        return tripled[..., 1:-15]

    def slice2d_meta(x):
        # Meta implementation: same slicing applied to an empty_like tensor.
        placeholder = torch.empty_like(x)
        return placeholder[..., 1:-15]

    define_custom_op_for_test("slice2d", slice2d, slice2d_meta)

    def fn(x):
        activated = torch.nn.functional.relu(x)
        sliced = torch.ops.test.slice2d(activated)
        return torch.cos(sliced)

    full = torch.randn(1024, 1024 + 16, device=self.device)
    self.common(fn, (full,), check_lowp=False)
@dataclasses.dataclass
class TestFailure:

View File

@ -1397,6 +1397,8 @@ class KernelArgs:
return self._lookup("out_ptr", self.output_buffers, name)
def make_inplace(self, input_name: str, output_name: str) -> None:
if input_name in V.graph.unaligned_buffers:
V.graph.unaligned_buffers.add(output_name)
assert output_name not in self.inplace_buffers
if input_name in self.inplace_buffers:
buf = self.inplace_buffers[input_name]

View File

@ -122,9 +122,14 @@ def signature_to_meta(
def is_unaligned_buffer(arg: TensorArg):
buf_name = arg.buffer
if buf_name in V.graph.unaligned_buffers:
return True
if buf_name in V.graph.graph_inputs:
# See Note: [Input Alignment handling in Inductor]
return buf_name not in V.graph.aligned_inputs
# For graph inputs that are not recorded in V.graph.unaligned_buffers,
# we know for sure the tensor is aligned.
return False
if buf_name in V.graph.constants:
# all constants are assumed to be aligned

View File

@ -78,12 +78,13 @@ if TYPE_CHECKING:
pexpr = PythonPrinter().doprint
ReuseKey = tuple[torch.device, torch.dtype, str]
ReuseKey = tuple[torch.device, torch.dtype, str, bool]
BufferLike = Union[ir.Buffer, WorkspaceArg]
def buffer_reuse_key(node: BufferLike) -> ReuseKey:
storage_size = V.graph.get_allocation_storage_size(node)
alignment = node.get_name() not in V.graph.unaligned_buffers
return (
node.get_device_or_error(),
node.get_dtype(),
@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey:
# for s0 for s1, just because they happen to share the same
# size hint
sympy_str(V.graph.sizevars.simplify(storage_size)),
alignment,
)

View File

@ -500,6 +500,9 @@ fallback_random = False
# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True
assume_unaligned_fallback_output = (
os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
)
# fuse even in cases without common reads
aggressive_fusion = False
@ -1129,7 +1132,7 @@ class triton:
) # type: ignore[assignment]
# hint to Triton when arguments are divisible by 16
divisible_by_16 = True
divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"
# Minimum R0_BLOCK to be used for a TritonSplitScanKernel
# NOTE: This also indirectly controls the size of workspace buffer required

View File

@ -430,7 +430,10 @@ class GraphLowering(torch.fx.Interpreter):
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
self.aligned_inputs: OrderedSet[str] = OrderedSet()
# Track the buffers that we know are unaligned.
# These can be either graph inputs or the outputs of fallback
# kernels.
self.unaligned_buffers: OrderedSet[str] = OrderedSet()
self.no_fuse_buffer_names = OrderedSet[str]()
self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
@ -1116,8 +1119,8 @@ class GraphLowering(torch.fx.Interpreter):
# expensive and cause recompiles; Instead, we're generating code
# based on the alignment of the example input without guarding.
with maybe_get_suppress_shape_guards_ctx():
if should_assume_input_aligned(example):
self.aligned_inputs.add(target)
if not should_assume_input_aligned(example):
self.unaligned_buffers.add(target)
return tensor
def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override]

View File

@ -95,6 +95,7 @@ from .utils import (
sympy_index_symbol_with_prefix,
sympy_product,
sympy_subs,
tensor_is_aligned,
)
from .virtualized import ops, OpsValue, V
@ -6996,11 +6997,16 @@ class FallbackKernel(ExternKernelAlloc):
for key, val in output.items()
}
elif isinstance(output, torch.Tensor):
return MultiOutput(
buf = MultiOutput(
cls.tensor_to_layout(output),
packed,
indices,
)
if config.assume_unaligned_fallback_output or not tensor_is_aligned(
output
):
V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type]
return buf
elif isinstance(output, int):
return output
elif isinstance(output, torch.SymInt):
@ -8051,6 +8057,11 @@ class _CollectiveKernel(FallbackKernel):
)
for i, tensor in enumerate(example_output)
]
for buf, tensor in zip(packed.outputs, example_output):
if config.assume_unaligned_fallback_output or not tensor_is_aligned(
tensor
):
V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type]
return packed.outputs
else:
packed = cls(
@ -8060,6 +8071,10 @@ class _CollectiveKernel(FallbackKernel):
non_tensor_args,
unflatten_args,
)
if config.assume_unaligned_fallback_output or not tensor_is_aligned(
example_output
):
V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type]
packed.outputs = [packed]
return packed