From 901b02cf16b61824ef6662af4549b81898a127fd Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Mon, 7 Apr 2025 13:50:00 -0700
Subject: [PATCH] [Inductor] fix alignment assumption for fallback (#150777)

Inductor right now only works properly for fallback kernels that produce
aligned output. When Inductor creates the layout for a fallback kernel's
output, it does not add the tensor's offset to the layout
[link](https://github.com/pytorch/pytorch/blob/2a1e2b88ed7bf7d7436b741ee0c3a2297d7d7bc2/torch/_inductor/ir.py#L6935-L6941).
An unaligned output is therefore treated as aligned. Adding the offset to the
layout directly does not work, since that changes the index expression in the
generated kernel and we may end up applying the offset twice: Triton already
accounts for the offset when passing in the data_ptr. To solve this issue, we
track the names of unaligned buffers instead.
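As an illustration of the kind of misalignment being tracked (this sketch is
not part of the patch; the 16-byte granularity and the `is_aligned` helper
below are assumptions for the example, not Inductor's actual
`tensor_is_aligned` utility):

```python
import torch

ALIGNMENT = 16  # bytes; the granularity the "divisible by 16" hint to Triton cares about

def is_aligned(t: torch.Tensor) -> bool:
    # data_ptr() already includes the storage offset, so this is the address
    # the generated kernel would actually read from.
    return t.data_ptr() % ALIGNMENT == 0

base = torch.randn(1024 + 16)  # fresh storage; allocators hand out aligned memory
view = base[1:-15]             # same storage, shifted by one fp32 element (4 bytes)
print(is_aligned(base))        # typically True
print(is_aligned(view))        # False: the case unaligned_buffers needs to record
```

A fallback or custom op whose output is such a view (e.g. the slice1d op
defined in the new tests) hits exactly this case for its output buffer.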
This can potentially fix the internal issues we are debugging here:
https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/

Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777
Approved by: https://github.com/eellison, https://github.com/jansel
---
 test/inductor/test_torchinductor.py     | 62 +++++++++++++++++++++++++
 torch/_inductor/codegen/common.py       |  2 +
 torch/_inductor/codegen/triton_utils.py |  7 ++-
 torch/_inductor/codegen/wrapper.py      |  4 +-
 torch/_inductor/config.py               |  5 +-
 torch/_inductor/graph.py                |  9 ++--
 torch/_inductor/ir.py                   | 17 ++++++-
 7 files changed, 99 insertions(+), 7 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index bbec7ab9bbe..8fa9be1966b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -12895,6 +12895,68 @@ class CommonTemplate:
         self.common(fn, (1, x))
         self.common(fn, (2, x))
 
+    def test_unaligned_input(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024 + 16, device=self.device)[1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_unaligned_input_2d(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_alignment_without_custom_op(self):
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = (3 * a)[1:-15]
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op(self):
+        def slice1d(x):
+            return (3 * x)[1:-15]
+
+        def slice1d_meta(x):
+            return torch.empty_like(x)[1:-15]
+
+        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice1d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op_2d(self):
+        def slice2d(x):
+            return (3 * x)[..., 1:-15]
+
+        def slice2d_meta(x):
+            return torch.empty_like(x)[..., 1:-15]
+
+        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice2d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 7fce40e869e..417e215d4f5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1397,6 +1397,8 @@ class KernelArgs:
         return self._lookup("out_ptr", self.output_buffers, name)
 
     def make_inplace(self, input_name: str, output_name: str) -> None:
+        if input_name in V.graph.unaligned_buffers:
+            V.graph.unaligned_buffers.add(output_name)
         assert output_name not in self.inplace_buffers
         if input_name in self.inplace_buffers:
             buf = self.inplace_buffers[input_name]
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 2d5f6a55b4c..ddd4ec51551 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -122,9 +122,14 @@ def signature_to_meta(
 
 def is_unaligned_buffer(arg: TensorArg):
     buf_name = arg.buffer
+    if buf_name in V.graph.unaligned_buffers:
+        return True
+
     if buf_name in V.graph.graph_inputs:
         # See Note: [Input Alignment handling in Inductor]
-        return buf_name not in V.graph.aligned_inputs
+        # For graph inputs that are not recorded in V.graph.unaligned_buffers,
+        # we know for sure that the tensor is aligned.
+        return False
 
     if buf_name in V.graph.constants:
         # all constants are assumed to be aligned
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index d7de4b4f24a..c10831bd827 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -78,12 +78,13 @@ if TYPE_CHECKING:
 pexpr = PythonPrinter().doprint
 
 
-ReuseKey = tuple[torch.device, torch.dtype, str]
+ReuseKey = tuple[torch.device, torch.dtype, str, bool]
 BufferLike = Union[ir.Buffer, WorkspaceArg]
 
 
 def buffer_reuse_key(node: BufferLike) -> ReuseKey:
     storage_size = V.graph.get_allocation_storage_size(node)
+    alignment = node.get_name() not in V.graph.unaligned_buffers
     return (
         node.get_device_or_error(),
         node.get_dtype(),
@@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey:
         # for s0 for s1, just because they happen to share the same
         # size hint
         sympy_str(V.graph.sizevars.simplify(storage_size)),
+        alignment,
     )
 
 
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c210af25c16..27b77d199f0 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -500,6 +500,9 @@ fallback_random = False
 
 # automatically create fallbacks when encountering an unhandled op
 implicit_fallbacks = True
+assume_unaligned_fallback_output = (
+    os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
+)
 
 # fuse even in cases without common reads
 aggressive_fusion = False
@@ -1129,7 +1132,7 @@ class triton:
     )  # type: ignore[assignment]
 
     # hint to Triton when arguments are divisible by 16
-    divisible_by_16 = True
+    divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"
 
     # Minimum R0_BLOCK to be used for a TritonSplitScanKernel
     # NOTE: This also indirectly controls the size of workspace buffer required
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 38f359c5f25..bfe4d1e960b 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -430,7 +430,10 @@ class GraphLowering(torch.fx.Interpreter):
         self.get_backend_features = functools.lru_cache(None)(get_backend_features)
 
         self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
-        self.aligned_inputs: OrderedSet[str] = OrderedSet()
+        # Track the buffers that we know are unaligned.
+        # This can either be a graph input or the output of a fallback
+        # kernel.
+        self.unaligned_buffers: OrderedSet[str] = OrderedSet()
         self.no_fuse_buffer_names = OrderedSet[str]()
         self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
 
@@ -1116,8 +1119,8 @@ class GraphLowering(torch.fx.Interpreter):
         # expensive and cause recompiles; Instead, we're generating code
         # based on the alignment of the example input without guarding.
         with maybe_get_suppress_shape_guards_ctx():
-            if should_assume_input_aligned(example):
-                self.aligned_inputs.add(target)
+            if not should_assume_input_aligned(example):
+                self.unaligned_buffers.add(target)
         return tensor
 
     def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 84069bbdf82..a312ea3ca11 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -95,6 +95,7 @@ from .utils import (
     sympy_index_symbol_with_prefix,
     sympy_product,
     sympy_subs,
+    tensor_is_aligned,
 )
 
 from .virtualized import ops, OpsValue, V
@@ -6996,11 +6997,16 @@ class FallbackKernel(ExternKernelAlloc):
                     for key, val in output.items()
                 }
             elif isinstance(output, torch.Tensor):
-                return MultiOutput(
+                buf = MultiOutput(
                     cls.tensor_to_layout(output),
                     packed,
                     indices,
                 )
+                if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                    output
+                ):
+                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
+                return buf
             elif isinstance(output, int):
                 return output
             elif isinstance(output, torch.SymInt):
@@ -8051,6 +8057,11 @@ class _CollectiveKernel(FallbackKernel):
                 )
                 for i, tensor in enumerate(example_output)
             ]
+            for buf, tensor in zip(packed.outputs, example_output):
+                if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                    tensor
+                ):
+                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
            return packed.outputs
        else:
            packed = cls(
@@ -8060,6 +8071,10 @@ class _CollectiveKernel(FallbackKernel):
                non_tensor_args,
                unflatten_args,
            )
+            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                example_output
+            ):
+                V.graph.unaligned_buffers.add(packed.name)  # type: ignore[arg-type]
            packed.outputs = [packed]
            return packed