diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index bbec7ab9bbe..8fa9be1966b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -12895,6 +12895,68 @@ class CommonTemplate:
         self.common(fn, (1, x))
         self.common(fn, (2, x))
+    def test_unaligned_input(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024 + 16, device=self.device)[1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_unaligned_input_2d(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_alignment_without_custom_op(self):
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = (3 * a)[1:-15]
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op(self):
+        def slice1d(x):
+            return (3 * x)[1:-15]
+
+        def slice1d_meta(x):
+            return torch.empty_like(x)[1:-15]
+
+        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice1d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op_2d(self):
+        def slice2d(x):
+            return (3 * x)[..., 1:-15]
+
+        def slice2d_meta(x):
+            return torch.empty_like(x)[..., 1:-15]
+
+        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice2d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 7fce40e869e..417e215d4f5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1397,6 +1397,8 @@ class KernelArgs:
         return self._lookup("out_ptr", self.output_buffers, name)
 
     def make_inplace(self, input_name: str, output_name: str) -> None:
+        if input_name in V.graph.unaligned_buffers:
+            V.graph.unaligned_buffers.add(output_name)
         assert output_name not in self.inplace_buffers
         if input_name in self.inplace_buffers:
             buf = self.inplace_buffers[input_name]
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 2d5f6a55b4c..ddd4ec51551 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -122,9 +122,14 @@ def signature_to_meta(
 
 def is_unaligned_buffer(arg: TensorArg):
     buf_name = arg.buffer
+    if buf_name in V.graph.unaligned_buffers:
+        return True
+
     if buf_name in V.graph.graph_inputs:
         # See Note: [Input Alignment handling in Inductor]
-        return buf_name not in V.graph.aligned_inputs
+        # For graph inputs that are not recorded in V.graph.unaligned_buffers,
+        # we know for sure the tensor is aligned.
+        return False
 
     if buf_name in V.graph.constants:
         # all constants are assumed to be aligned
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index d7de4b4f24a..c10831bd827 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -78,12 +78,13 @@ if TYPE_CHECKING:
 
 pexpr = PythonPrinter().doprint
 
 
-ReuseKey = tuple[torch.device, torch.dtype, str]
+ReuseKey = tuple[torch.device, torch.dtype, str, bool]
 BufferLike = Union[ir.Buffer, WorkspaceArg]
 
 
 def buffer_reuse_key(node: BufferLike) -> ReuseKey:
     storage_size = V.graph.get_allocation_storage_size(node)
+    alignment = node.get_name() not in V.graph.unaligned_buffers
     return (
         node.get_device_or_error(),
         node.get_dtype(),
@@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey:
         # for s0 for s1, just because they happen to share the same
         # size hint
         sympy_str(V.graph.sizevars.simplify(storage_size)),
+        alignment,
     )
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c210af25c16..27b77d199f0 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -500,6 +500,9 @@ fallback_random = False
 # automatically create fallbacks when encountering an unhandled op
 implicit_fallbacks = True
+assume_unaligned_fallback_output = (
+    os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
+)
 
 # fuse even in cases without common reads
 aggressive_fusion = False
 
@@ -1129,7 +1132,7 @@ class triton:
     )  # type: ignore[assignment]
 
     # hint to Triton when arguments are divisible by 16
-    divisible_by_16 = True
+    divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"
 
     # Minimum R0_BLOCK to be used for a TritonSplitScanKernel
     # NOTE: This also indirectly controls the size of workspace buffer required
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 38f359c5f25..bfe4d1e960b 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -430,7 +430,10 @@ class GraphLowering(torch.fx.Interpreter):
         self.get_backend_features = functools.lru_cache(None)(get_backend_features)
 
         self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
-        self.aligned_inputs: OrderedSet[str] = OrderedSet()
+        # Track the buffers that we know are unaligned.
+        # These can either be graph inputs or the outputs of fallback
+        # kernels.
+        self.unaligned_buffers: OrderedSet[str] = OrderedSet()
         self.no_fuse_buffer_names = OrderedSet[str]()
         self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
 
@@ -1116,8 +1119,8 @@ class GraphLowering(torch.fx.Interpreter):
         # expensive and cause recompiles; Instead, we're generating code
         # based on the alignment of the example input without guarding.
         with maybe_get_suppress_shape_guards_ctx():
-            if should_assume_input_aligned(example):
-                self.aligned_inputs.add(target)
+            if not should_assume_input_aligned(example):
+                self.unaligned_buffers.add(target)
         return tensor
 
     def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 84069bbdf82..a312ea3ca11 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -95,6 +95,7 @@ from .utils import (
     sympy_index_symbol_with_prefix,
     sympy_product,
     sympy_subs,
+    tensor_is_aligned,
 )
 from .virtualized import ops, OpsValue, V
 
@@ -6996,11 +6997,16 @@ class FallbackKernel(ExternKernelAlloc):
                 for key, val in output.items()
             }
         elif isinstance(output, torch.Tensor):
-            return MultiOutput(
+            buf = MultiOutput(
                 cls.tensor_to_layout(output),
                 packed,
                 indices,
             )
+            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                output
+            ):
+                V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
+            return buf
         elif isinstance(output, int):
             return output
         elif isinstance(output, torch.SymInt):
@@ -8051,6 +8057,11 @@ class _CollectiveKernel(FallbackKernel):
                 )
                 for i, tensor in enumerate(example_output)
             ]
+            for buf, tensor in zip(packed.outputs, example_output):
+                if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                    tensor
+                ):
+                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
             return packed.outputs
         else:
             packed = cls(
@@ -8060,6 +8071,10 @@ class _CollectiveKernel(FallbackKernel):
                 non_tensor_args,
                 unflatten_args,
             )
+            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                example_output
+            ):
+                V.graph.unaligned_buffers.add(packed.name)  # type: ignore[arg-type]
             packed.outputs = [packed]
             return packed
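
Reviewer note (not part of the diff): the property the new tests exercise is that a sliced view can live in aligned storage yet start at an unaligned offset. That is why unaligned-ness must be tracked per buffer and threaded through codegen, and why ReuseKey gains a bool, so the allocator never recycles an unaligned region for a buffer that a Triton kernel was compiled to assume is aligned. Below is a minimal sketch of the underlying check; the real one is tensor_is_aligned in torch/_inductor/utils.py, while the helper name and the 16-byte threshold here are assumptions for illustration only.

import torch

def is_16_byte_aligned(t: torch.Tensor) -> bool:
    # Hypothetical stand-in for tensor_is_aligned: a view is unaligned when
    # its storage offset, measured in bytes, is not a multiple of 16, even
    # though the underlying allocation itself is aligned.
    return (t.storage_offset() * t.element_size()) % 16 == 0

base = torch.randn(1024 + 16)    # fresh allocations start aligned
view = base[1:-15]               # same storage, offset by one fp32 element
print(is_16_byte_aligned(base))  # True
print(is_16_byte_aligned(view))  # False: a 4-byte offset breaks 16-byte alignment

This mirrors the slicing pattern in test_unaligned_input, where torch.randn(1024 + 16)[1:-15] produces exactly such an offset view.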