[Inductor] fix alignment assumption for fallback (#150777)
Inductor currently only works properly for fallback kernels that produce aligned output.
When Inductor creates the layout for a fallback kernel's output, it does not add the tensor's offset to the layout [link](2a1e2b88ed/torch/_inductor/ir.py (L6935-L6941)). An unaligned output is therefore treated as aligned. Adding the offset to the layout directly does not work, since that changes the index expressions in the generated kernel and we may end up applying the offset twice; Triton already accounts for the offset when the data_ptr is passed in.
To solve this issue, we track the names of unaligned buffers instead.
This can potentially fix the internal issues we are debugging here: https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/
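To make the failure mode concrete (an illustration only, not part of the change): a fallback that returns a view with a nonzero storage offset hands Inductor a data pointer that is no longer 16-byte aligned, even though the underlying allocation is. A minimal sketch, assuming the allocator aligns fresh allocations to at least 16 bytes (true for the default CPU and CUDA caching allocators):

    import torch

    base = torch.randn(1024 + 16)  # fresh allocation; base.data_ptr() is 16-byte aligned
    view = base[1:-15]             # the kind of offset view a fallback op may return

    print(base.data_ptr() % 16)    # 0
    print(view.storage_offset())   # 1 (element)
    print(view.data_ptr() % 16)    # 4 for float32: a one-element offset is 4 bytes

    # The layout Inductor builds for the fallback output does not carry this offset,
    # so the view used to be treated as 16-byte aligned. Recording its buffer name
    # as unaligned fixes that without baking the offset into index expressions.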
Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777
Approved by: https://github.com/eellison, https://github.com/jansel
This commit is contained in: parent c36d9b0d8d, commit 901b02cf16
@@ -12895,6 +12895,68 @@ class CommonTemplate:
        self.common(fn, (1, x))
        self.common(fn, (2, x))
+
+    def test_unaligned_input(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024 + 16, device=self.device)[1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_unaligned_input_2d(self):
+        def fn(x):
+            return torch.nn.functional.relu(x)
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
+        self.common(fn, (x,), check_lowp=False)
+
+    def test_alignment_without_custom_op(self):
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = (3 * a)[1:-15]
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op(self):
+        def slice1d(x):
+            return (3 * x)[1:-15]
+
+        def slice1d_meta(x):
+            return torch.empty_like(x)[1:-15]
+
+        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice1d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)
+
+    @config.patch(implicit_fallbacks=True)
+    def test_no_align_for_custom_op_2d(self):
+        def slice2d(x):
+            return (3 * x)[..., 1:-15]
+
+        def slice2d_meta(x):
+            return torch.empty_like(x)[..., 1:-15]
+
+        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)
+
+        def fn(x):
+            a = torch.nn.functional.relu(x)
+            b = torch.ops.test.slice2d(a)
+            c = torch.cos(b)
+            return c
+
+        x = torch.randn(1024, 1024 + 16, device=self.device)
+        self.common(fn, (x,), check_lowp=False)


@dataclasses.dataclass
class TestFailure:
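These tests use the test-harness helper define_custom_op_for_test, whose definition is not part of this diff. For readers who want to reproduce an unaligned fallback output outside the test suite, a rough standalone sketch of an equivalent registration via torch.library (the names and dispatch keys below are illustrative, not the helper's actual implementation):

    import torch
    from torch.library import Library

    # Register an illustrative "test::slice1d" op whose output is an offset view,
    # so Inductor sees it as a fallback with a potentially unaligned output.
    _lib = Library("test", "FRAGMENT")
    _lib.define("slice1d(Tensor x) -> Tensor")

    def _slice1d(x):
        return (3 * x)[1:-15]

    def _slice1d_meta(x):
        return torch.empty_like(x)[1:-15]

    _lib.impl("slice1d", _slice1d, "CompositeExplicitAutograd")
    _lib.impl("slice1d", _slice1d_meta, "Meta")

With implicit_fallbacks enabled, torch.ops.test.slice1d is then lowered as a fallback kernel, which is exactly the case the new alignment tracking covers.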
@@ -1397,6 +1397,8 @@ class KernelArgs:
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name: str, output_name: str) -> None:
+        if input_name in V.graph.unaligned_buffers:
+            V.graph.unaligned_buffers.add(output_name)
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
@@ -122,9 +122,14 @@ def signature_to_meta(

def is_unaligned_buffer(arg: TensorArg):
    buf_name = arg.buffer
+    if buf_name in V.graph.unaligned_buffers:
+        return True
+
    if buf_name in V.graph.graph_inputs:
        # See Note: [Input Alignment handling in Inductor]
-        return buf_name not in V.graph.aligned_inputs
+        # For graph inputs that is not recorded in V.graph.unaligned_buffers,
+        # we know for sure the tensor is aligned.
+        return False

    if buf_name in V.graph.constants:
        # all constants are assumed to be aligned
@@ -78,12 +78,13 @@ if TYPE_CHECKING:
pexpr = PythonPrinter().doprint


-ReuseKey = tuple[torch.device, torch.dtype, str]
+ReuseKey = tuple[torch.device, torch.dtype, str, bool]
BufferLike = Union[ir.Buffer, WorkspaceArg]


def buffer_reuse_key(node: BufferLike) -> ReuseKey:
    storage_size = V.graph.get_allocation_storage_size(node)
+    alignment = node.get_name() not in V.graph.unaligned_buffers
    return (
        node.get_device_or_error(),
        node.get_dtype(),
@@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey:
        # for s0 for s1, just because they happen to share the same
        # size hint
        sympy_str(V.graph.sizevars.simplify(storage_size)),
+        alignment,
    )
@@ -500,6 +500,9 @@ fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True
+assume_unaligned_fallback_output = (
+    os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
+)

# fuse even in cases without common reads
aggressive_fusion = False
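For reference, the new knob can be set through the environment variable above (it is read when torch._inductor.config is first imported) or directly on the config module; a minimal usage sketch (the compiled function and shapes are placeholders):

    import torch
    import torch._inductor.config as inductor_config

    # Treat every fallback-kernel output as potentially unaligned (most conservative).
    inductor_config.assume_unaligned_fallback_output = True

    # Or scope it to a block, mirroring how the new tests use config.patch:
    with inductor_config.patch(assume_unaligned_fallback_output=True):
        compiled = torch.compile(lambda x: torch.relu(x))
        compiled(torch.randn(1024))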
@@ -1129,7 +1132,7 @@ class triton:
    )  # type: ignore[assignment]

    # hint to Triton when arguments are divisible by 16
-    divisible_by_16 = True
+    divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"

    # Minimum R0_BLOCK to be used for a TritonSplitScanKernel
    # NOTE: This also indirectly controls the size of workspace buffer required
@@ -430,7 +430,10 @@ class GraphLowering(torch.fx.Interpreter):
        self.get_backend_features = functools.lru_cache(None)(get_backend_features)

        self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
-        self.aligned_inputs: OrderedSet[str] = OrderedSet()
+        # Track the buffers that we know is unaligned
+        # This can either be a graph input or the output of fallback
+        # kernels.
+        self.unaligned_buffers: OrderedSet[str] = OrderedSet()
        self.no_fuse_buffer_names = OrderedSet[str]()

        self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
@@ -1116,8 +1119,8 @@ class GraphLowering(torch.fx.Interpreter):
        # expensive and cause recompiles; Instead, we're generating code
        # based on the alignment of the example input without guarding.
        with maybe_get_suppress_shape_guards_ctx():
-            if should_assume_input_aligned(example):
-                self.aligned_inputs.add(target)
+            if not should_assume_input_aligned(example):
+                self.unaligned_buffers.add(target)
        return tensor

    def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
@@ -95,6 +95,7 @@ from .utils import (
    sympy_index_symbol_with_prefix,
    sympy_product,
    sympy_subs,
+    tensor_is_aligned,
)
from .virtualized import ops, OpsValue, V

@@ -6996,11 +6997,16 @@ class FallbackKernel(ExternKernelAlloc):
                for key, val in output.items()
            }
        elif isinstance(output, torch.Tensor):
-            return MultiOutput(
+            buf = MultiOutput(
                cls.tensor_to_layout(output),
                packed,
                indices,
            )
+            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                output
+            ):
+                V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
+            return buf
        elif isinstance(output, int):
            return output
        elif isinstance(output, torch.SymInt):
@@ -8051,6 +8057,11 @@ class _CollectiveKernel(FallbackKernel):
                )
                for i, tensor in enumerate(example_output)
            ]
+            for buf, tensor in zip(packed.outputs, example_output):
+                if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                    tensor
+                ):
+                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
            return packed.outputs
        else:
            packed = cls(
@@ -8060,6 +8071,10 @@ class _CollectiveKernel(FallbackKernel):
                non_tensor_args,
                unflatten_args,
            )
+            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
+                example_output
+            ):
+                V.graph.unaligned_buffers.add(packed.name)  # type: ignore[arg-type]
            packed.outputs = [packed]
            return packed