Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[Inductor] FX backend via Wrapper IR (#146942)"
This reverts commit a7691140a0. Reverted https://github.com/pytorch/pytorch/pull/146942 on behalf of https://github.com/malfet due to: Looks like it indeed breaks lint, see a7691140a0/1 ([comment](https://github.com/pytorch/pytorch/pull/146942#issuecomment-2852192778))
This commit is contained in:
parent
a7691140a0
commit
99dac7005f
@@ -1,417 +0,0 @@
-# Owner(s): ["module: inductor"]
-
-"""
-Test the FX IR backend.
-"""
-
-import itertools
-import operator
-import unittest
-from typing import Callable, Optional
-
-import sympy
-
-import torch
-import torch._inductor.codegen.common as common
-import torch.utils._pytree as pytree
-from torch._dynamo.exc import BackendCompilerFailed
-from torch._dynamo.utils import same
-from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
-from torch._inductor import config
-from torch._inductor.codegen.common import register_backend_for_device
-from torch._inductor.codegen.cpp import CppScheduling
-from torch._inductor.codegen.triton import TritonScheduling
-from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
-from torch._inductor.select_algorithm import extern_kernels
-from torch._inductor.test_case import TestCase as InductorTestCase
-from torch.testing._internal.inductor_utils import (
-    GPU_TYPE,
-    HAS_GPU,
-    requires_gpu,
-    TRITON_HAS_CPU,
-)
-
-
-@requires_gpu()
-@config.patch(
-    compile_threads=1,
-    alignment_asserts=False,
-    size_asserts=False,
-    scalar_asserts=False,
-    nan_asserts=False,
-)
-class FxirTestCase(InductorTestCase):
-    device = GPU_TYPE
-
-    def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
-        return len(gm.graph.find_nodes(op="call_function", target=target))
-
-    def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule:
-        gms = []
-
-        orig_generate = FxConverter.generate
-
-        def generate(self) -> torch.fx.GraphModule:
-            nonlocal gms
-            gm = orig_generate(self)
-            gms.append(gm)
-            return gm
-
-        with unittest.mock.patch.object(
-            torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
-        ):
-            opt(*args)
-
-        return gms
-
-    def _compile_and_check(
-        self,
-        func,
-        args,
-        expected_num_triton_kernels: int = 1,
-        metadata_only: bool = False,
-        compile_kwargs: Optional[dict] = None,
-    ):
-        if compile_kwargs is None:
-            compile_kwargs = {}
-
-        opt = torch.compile(func, **compile_kwargs)
-
-        # Get the FX graph from the backend.
-        gms = self._run_and_capture_graphs(opt, args)
-
-        # Check the code for triton kernels.
-        num_kernels = sum(
-            self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
-        )
-        self.assertEqual(num_kernels, expected_num_triton_kernels)
-
-        # Check accuracy.
-        result = opt(*args)
-        ref = func(*args)
-        if metadata_only:
-            # When we only want to check metadata, fill in zeros for tensor data.
-            ref, result = tuple(
-                pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
-            )
-
-        self.assertTrue(same(ref, result))
-
-        return gms
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Register the FX backend.
-        register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
-
-    def test_basic(self):
-        args = [torch.randn(8, device=self.device) for _ in range(2)]
-        self._compile_and_check(torch.add, args)
-
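The setUpClass registration above was the entire entry point to the reverted feature: point Inductor's wrapper codegen for a device at WrapperFxCodegen, and every torch.compile call on that device yields a wrapper GraphModule. A minimal sketch of that flow, assuming a CUDA build and the pre-revert torch._inductor.codegen.wrapper_fxir module:

    # Sketch only: these imports existed before this revert and are removed by it.
    import torch
    from torch._inductor.codegen.common import register_backend_for_device
    from torch._inductor.codegen.triton import TritonScheduling
    from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

    # Route Inductor's wrapper codegen for CUDA through the FX backend.
    register_backend_for_device("cuda", TritonScheduling, WrapperFxCodegen)

    # Compiled functions now produce a wrapper GraphModule instead of Python source.
    opt = torch.compile(torch.add)
    out = opt(torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))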
-    def test_multiple_kernels(self):
-        def foo(x, y):
-            return x.sum() + y.sum()
-
-        args = [torch.randn(length, device=self.device) for length in [517, 1029]]
-        self._compile_and_check(foo, args, expected_num_triton_kernels=2)
-
-    def test_free(self):
-        """
-        Test a program that frees a buffer which is no longer in use.
-        """
-
-        def foo(x, y, z):
-            w = x.sum() + y
-            return z.sum() + w.sum()
-
-        args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
-        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)
-
-        # Check the generated code for frees.
-        num_frees = gm.code.count("= None")
-        self.assertGreater(num_frees, 0)
-
-    def test_extern(self):
-        """
-        Test a program that calls an extern kernel.
-        """
-
-        def foo(x, y):
-            return x @ y + y.sum()
-
-        args = [
-            torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
-        ]
-        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
-
-        # Check for the extern kernel
-        num_extern = self._count_ops(gm, extern_kernels.addmm)
-        self.assertEqual(num_extern, 1)
-
-    def test_fallback(self):
-        """
-        Test a program that calls an aten fallback.
-        """
-
-        length = 8
-
-        def foo(x):
-            return x + torch.randn(1, device=self.device)
-
-        args = (torch.randn(length, device=self.device),)
-
-        # Since the program has a random output, just check metadata.
-        # Don't check for an exact value.
-        (gm,) = self._compile_and_check(
-            foo, args, expected_num_triton_kernels=2, metadata_only=True
-        )
-
-        # Check for the fallback kernel.
-        num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out)
-        self.assertEqual(num_fallback, 1)
-
-    def test_cat_inputs(self):
-        """
-        Test concatenation of graph inputs.
-        """
-
-        def foo(x, y):
-            return torch.cat((x, y)) + 1
-
-        args = [torch.randn(8, device=self.device) for _ in range(2)]
-        self._compile_and_check(foo, args, expected_num_triton_kernels=1)
-
-    def test_cat_to_alloc(self):
-        """
-        Test concatenation that's optimized out to an allocation.
-        """
-        length = 8
-
-        def foo(x):
-            y, z = tuple(
-                torch.arange(length // 2, device=self.device) for _ in range(2)
-            )
-            return x + torch.cat((y, z))
-
-        args = [torch.randn(length, device=self.device)]
-        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
-
-        # Expect a single allocation, even though eager mode would use 2.
-        num_allocs = self._count_ops(gm, torch.empty_strided)
-        self.assertEqual(num_allocs, 1)
-
-    def test_cat_reinterpret_view(self):
-        """
-        Test torch.cat using ReinterpretView.
-        """
-        length = 8
-
-        def foo(x):
-            y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
-            return x + torch.cat((y, z))
-
-        args = [torch.randn(length, device=self.device)]
-
-        # Since this test generates random numbers, check metadata only.
-        (gm,) = self._compile_and_check(
-            foo, args, expected_num_triton_kernels=3, metadata_only=True
-        )
-
-        # Check for as_strided. We map ReinterpretView to this.
-        num_as_strided = self._count_ops(gm, torch.as_strided)
-        self.assertEqual(num_as_strided, 2)
-
-    def test_reshape_output(self):
-        """
-        Test reshaping the output, which maps to a ReinterpretView.
-        """
-
-        def foo(x, y):
-            return torch.reshape(x + y, (8,))
-
-        args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
-        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
-
-        # Check for as_strided. We map ReinterpretView to this.
-        num_as_strided = self._count_ops(gm, torch.as_strided)
-        self.assertEqual(num_as_strided, 1)
-
-    def test_extern_multi_output(self):
-        """
-        Test an extern kernel with multiple outputs.
-        Also test a graph with multiple outputs.
-        """
-
-        def foo(x):
-            top, idx = torch.topk(x, 2)
-            return top + 1, idx * 2
-
-        args = [torch.randn(8, device=self.device)]
-        (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
-
-        # Check for multiple kernel outputs via getitems.
-        num_getitems = self._count_ops(gm, operator.getitem)
-        self.assertEqual(num_getitems, 2)
-
-        # Check for multiple graph outputs.
-        output_node = gm.graph.find_nodes(op="output")[0]
-        self.assertEqual(len(output_node.args[0]), 2)
-
-    def test_duplicate_input(self):
-        """
-        Test duplicated inputs. This will collapse into a single input in the GM.
-        """
-
-        args = [torch.randn(4, device=self.device)] * 2
-        (gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)
-
-        num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
-        self.assertEqual(num_placeholders, 1)
-
-    def test_backward(self):
-        """
-        Test a program with a backward pass.
-        """
-
-        x = torch.ones(5, device=self.device)  # input tensor
-        y = torch.zeros(3, device=self.device)  # expected output
-        w = torch.randn(5, 3, requires_grad=True, device=self.device)
-        b = torch.randn(3, requires_grad=True, device=self.device)
-
-        def foo(x, y):
-            z = torch.matmul(x, w) + b
-            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
-            loss.backward()
-            return w.grad, b.grad
-
-        # Expect separate forward and backward graphs.
-        (forward_gm, backward_gm) = self._compile_and_check(
-            foo, (x, y), expected_num_triton_kernels=3
-        )
-
-    def test_custom_compiler(self):
-        """
-        Test a derived backend with a custom compiler.
-        """
-        offset = 1
-
-        class CustomWrapperCodegen(WrapperFxCodegen):
-            def compile_graph(self, gm):
-                def compiled_fn(*args):
-                    # Adds an offset to the program's outputs.
-                    outputs = gm(*args)
-                    return pytree.tree_map(lambda x: x + 1, outputs)
-
-                return compiled_fn
-
-        args = [torch.randn(8, device=self.device) for _ in range(2)]
-        custom_backend = common.DeviceCodegen(
-            TritonScheduling, CustomWrapperCodegen, None
-        )
-        with unittest.mock.patch.dict(
-            common.device_codegens, {self.device: custom_backend}
-        ):
-            func = torch.add
-            opt = torch.compile(func)
-            result = opt(*args)
-
-        # Check the output is offset from eager mode.
-        ref = func(*args)
-        self.assertFalse(same(result, ref))
-        self.assertNotEqual(offset, 0)
-        self.assertTrue(same(result - offset, ref))
-
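test_custom_compiler exercises compile_graph, the intended extension hook: it receives the wrapper GraphModule and may return any callable. A minimal sketch of a derived backend, assuming the pre-revert module; the subclass name is illustrative:

    from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen  # pre-revert

    class PrintingWrapperCodegen(WrapperFxCodegen):  # illustrative name
        def compile_graph(self, gm):
            # Inspect the wrapper graph, then fall back to the default
            # eager interpreter that the base class would have returned.
            print(gm.graph)
            return gm.forward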
-    def test_dynamic_shapes_and_strides(self):
-        """
-        Test a graph with dynamic shapes and strides.
-        """
-
-        static_dims = (8, 8)
-
-        def get_input():
-            full_size = (16, 8)
-            full = torch.randn(full_size, device=self.device)
-            view = torch.as_strided(full, static_dims, full.stride())
-            return view
-
-        func = torch.add
-        args = [get_input() for _ in range(2)]
-        (gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
-
-        # Check for a symbolic output shape.
-        (empty_strided,) = gm.graph.find_nodes(
-            op="call_function", target=torch.empty_strided
-        )
-        example_tensor = empty_strided.meta["val"]
-        symbolic_dims = example_tensor.shape
-        self.assertEqual(len(symbolic_dims), len(static_dims))
-
-        # Check for symbolic output strides.
-        (stride, one) = example_tensor.stride()
-        self.assertEqual(one, sympy.S.One)
-
-        # Find the size symbols, and check for corresponding placeholders defining them.
-        for symbol in itertools.chain(symbolic_dims, [stride]):
-            self.assertTrue(isinstance(symbol, torch.SymInt))
-            (placeholder,) = [
-                node
-                for node in gm.graph.find_nodes(op="placeholder")
-                if node.name == str(symbol)
-            ]
-            self.assertEqual(placeholder.meta["val"], symbol)
-
-    @config.patch({"trace.enabled": True})
-    @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
-    def test_debug(self, mock_output_code):
-        # Compile in debug mode.
-        args = [torch.randn(11, device=self.device) for _ in range(2)]
-        self._compile_and_check(torch.sub, args)
-
-        # Check the output code for a Triton kernel call.
-        mock_output_code.assert_called_once()
-        (output_filename,) = mock_output_code.call_args.args
-        with open(output_filename) as f:
-            output_code = f.read()
-        self.assertIn("triton_kernel_wrapper_mutation", output_code)
-
-    @torch._inductor.config.patch("graph_partition", True)
-    def test_subgraph_raises(self):
-        """
-        Test a model with subgraphs. This is not yet supported, so check that we get the
-        expected exception.
-        """
-
-        def foo(cond, x):
-            return torch.cond(cond, torch.cos, torch.sin, [x])
-
-        cond = torch.tensor([True], device=self.device)
-        x = torch.ones([2, 3], device=self.device)
-
-        with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
-            self._compile_and_check(foo, [cond, x])
-
-    def test_cpp_raises(self):
-        """
-        Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
-        that we get the expected exception.
-        """
-
-        def foo(x, y):
-            return x + y * 5
-
-        device = torch.device("cpu")
-        args = [torch.randn(5, device=device) for _ in range(2)]
-
-        cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
-        with unittest.mock.patch.dict(
-            common.device_codegens, {device.type: cpp_backend}
-        ), self.assertRaisesRegex(BackendCompilerFailed, "Triton"):
-            self._compile_and_check(foo, args)
-
-
-if __name__ == "__main__":
-    from torch._inductor.test_case import run_tests
-
-    if HAS_GPU or TRITON_HAS_CPU:
-        run_tests(needs="filelock")
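_run_and_capture_graphs above relies on a general trick: temporarily patching a compiler entry point with a recording wrapper to snapshot intermediate artifacts without dedicated hooks. A self-contained sketch of the pattern; the Compiler class is a hypothetical stand-in for FxConverter:

    import unittest.mock

    class Compiler:  # hypothetical stand-in for FxConverter
        def generate(self):
            return "artifact"

    captured = []
    orig = Compiler.generate

    def recording_generate(self):
        result = orig(self)
        captured.append(result)  # record every artifact the compiler produces
        return result

    with unittest.mock.patch.object(Compiler, "generate", recording_generate):
        Compiler().generate()

    assert captured == ["artifact"]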
@@ -1750,27 +1750,6 @@ class TracingTritonHOPifier(TritonHOPifier):
         # normalize to tuple
         return tuple(grid)
 
-    def store_non_graphable_args(
-        self,
-        combined_args: dict[str, Any],
-    ) -> tuple[dict, int]:
-        """
-        Some args cannot be stored in the FX graph.
-        Put them in the side table.
-        """
-
-        def is_graphable(val: Any) -> bool:
-            return isinstance(val, (fx.node.base_types, fx.Node))
-
-        non_graphable_args = {
-            k: v for k, v in combined_args.items() if not is_graphable(v)
-        }
-        graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
-
-        constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
-
-        return graphable_args, constant_args_idx
-
     def call_HOP(
         self,
         variable: "TraceableTritonKernelWrapper",
@@ -1781,8 +1760,15 @@
         assert tx is None
         assert isinstance(variable, TraceableTritonKernelWrapper)
 
-        graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)
+        def is_graphable(val: Any) -> bool:
+            return isinstance(val, fx.node.base_types)
+
+        non_graphable_args = {
+            k: v for k, v in combined_args.items() if not is_graphable(v)
+        }
+        graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
+
+        constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
 
         assert isinstance(variable.kernel_idx, int)
         return triton_kernel_wrapper_mutation(
             kernel_idx=variable.kernel_idx,
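Both versions of the argument handling above implement the same split: only fx.node.base_types values may be embedded in an FX graph, so everything else is parked in a side table and referenced by integer index. A toy sketch of that split; the list-based side table and the simplified type check stand in for the real kernel_side_table and fx.node.base_types:

    from typing import Any

    SIDE_TABLE: list[dict[str, Any]] = []  # toy stand-in for kernel_side_table

    def is_graphable(val: Any) -> bool:
        # The real check uses fx.node.base_types; primitives qualify.
        return isinstance(val, (int, float, bool, str, type(None)))

    def split_args(combined: dict[str, Any]) -> tuple[dict[str, Any], int]:
        graphable = {k: v for k, v in combined.items() if is_graphable(v)}
        constants = {k: v for k, v in combined.items() if not is_graphable(v)}
        SIDE_TABLE.append(constants)           # store non-graphable args by index...
        return graphable, len(SIDE_TABLE) - 1  # ...and return that index

    kwargs, idx = split_args({"n": 8, "cfg": object()})
    assert "cfg" in SIDE_TABLE[idx] and kwargs == {"n": 8}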
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import atexit
 import contextlib
 import dataclasses
 import enum
@@ -9,11 +8,8 @@ import itertools
 import logging
 import math
 import operator
-import os
 import re
-import tempfile
 import typing
-from abc import ABC, abstractmethod
 from enum import auto, Enum
 from itertools import chain
 from typing import (
@@ -64,8 +60,6 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
 if TYPE_CHECKING:
     from collections.abc import Iterator, MutableMapping, Sequence
 
-    from torch.fx import GraphModule
-
     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@@ -89,38 +83,6 @@ def data_type_logger(msg: str) -> None:
     schedule_log.debug("Data type propagation: %s", msg)
 
 
-@dataclasses.dataclass
-class FileBackedGraphModule:
-    """
-    Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
-    map back to a GraphModule instead of Python source.
-    """
-
-    gm: GraphModule
-    compiled_fn: Callable[..., Any]
-
-    def __post_init__(self) -> None:
-        # Write the code to a file for compatibility with debugging utilities.
-        # The file is deleted upon program termination.
-        self.tempfile = tempfile.NamedTemporaryFile(
-            mode="w+", suffix=".py", delete=False
-        )
-        atexit.register(os.remove, self.tempfile.name)
-        with self.tempfile as f:
-            f.write(self.value)
-
-    @property
-    def __file__(self) -> str:
-        return self.tempfile.name
-
-    def call(self, args: list[Any]) -> Any:
-        return self.compiled_fn(*args)
-
-    @property
-    def value(self) -> str:
-        return self.gm.code
-
-
 class WorkspaceZeroMode(enum.Enum):
     UNINITIALIZED = 0
     ZERO_ON_CALL = 1  # kernel may leave workspace dirty
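The FileBackedGraphModule removed above exists so that debugging utilities expecting a module-like __file__ keep working for graph-based output. The underlying mechanism is small; a self-contained sketch, independent of Inductor:

    import atexit
    import os
    import tempfile

    def back_with_file(source: str) -> str:
        # Persist generated source so debuggers/loggers can point at a real path.
        f = tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False)
        atexit.register(os.remove, f.name)  # clean up at interpreter exit
        with f:
            f.write(source)
        return f.name

    path = back_with_file("print('generated code')")
    assert os.path.exists(path)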
@@ -141,22 +103,8 @@
         return WorkspaceZeroMode.UNINITIALIZED
 
 
-class CodegenSymbol(ABC):
-    """
-    An IR object possibly corresponding to a variable in the wrapper code.
-    """
-
-    @abstractmethod
-    def get_name(self) -> str:
-        pass
-
-    @abstractmethod
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
-        pass
-
-
 @ir_dataclass(frozen=True)
-class WorkspaceArg(CodegenSymbol):
+class WorkspaceArg:
     """A temporary buffer used for a single kernel, then discarded.
 
     Not registered as a traditional buffer since there are no users,
@@ -219,9 +167,6 @@ class WorkspaceArg(CodegenSymbol):
     def get_dtype(self) -> torch.dtype:
         return self.dtype
 
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
-        return self.get_layout().get_example()
-
     def get_layout(self) -> FixedLayout:
         from ..ir import FixedLayout
 
@@ -240,9 +185,6 @@ class WorkspaceArg(CodegenSymbol):
     maybe_get_output_spec = get_layout
     maybe_get_layout = get_layout
 
-    def get_offset(self) -> sympy.Expr:
-        return sympy.S.Zero
-
     def get_size(self) -> list[sympy.Expr]:
         return [self.count]
 
@@ -74,7 +74,6 @@ if TYPE_CHECKING:
     import triton
 
     from ..graph import GraphLowering
-    from .wrapper_fxir import FxConverter
 
 
 log = logging.getLogger(__name__)
@@ -84,7 +83,6 @@ pexpr = PythonPrinter().doprint
 
 ReuseKey = tuple[torch.device, torch.dtype, str, bool]
 BufferLike = Union[ir.Buffer, WorkspaceArg]
-FxConversionFunc = Callable[["WrapperLine"], None]
 
 
 def buffer_reuse_key(node: BufferLike) -> ReuseKey:
@@ -351,8 +349,7 @@ class MemoryPlanningState:
 
 
 class WrapperLine:
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
+    pass
 
 
 @dataclasses.dataclass
@@ -367,9 +364,6 @@ class EnterSubgraphLine(WrapperLine):
         self.wrapper.push_codegened_graph(self.graph)
         code.do_indent()
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_enter_subgraph
-
 
 @dataclasses.dataclass
 class CommentLine(WrapperLine):
@@ -378,10 +372,6 @@ class CommentLine(WrapperLine):
     def codegen(self, code: IndentedBuffer) -> None:
         code.writeline(self.line)
 
-    @staticmethod
-    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_comment
-
 
 @dataclasses.dataclass
 class ExitSubgraphLine(WrapperLine):
@@ -394,9 +384,6 @@ class ExitSubgraphLine(WrapperLine):
         self.wrapper.pop_codegened_graph()
         code.do_unindent()
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_exit_subgraph
-
 
 @dataclasses.dataclass
 class EnterDeviceContextManagerLine(WrapperLine):
@@ -432,18 +419,12 @@ class EnterDeviceContextManagerLine(WrapperLine):
             code.do_indent()
             code.writeline(V.graph.device_ops.set_device(self.device_idx))
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_enter_device_context_manager
-
 
 class ExitDeviceContextManagerLine(WrapperLine):
     def codegen(self, code: IndentedBuffer) -> None:
         if not V.graph.cpp_wrapper:
             code.do_unindent()
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_exit_device_context_manager
-
 
 @dataclasses.dataclass
 class ExternKernelAllocLine(WrapperLine):
@@ -455,9 +436,6 @@ class ExternKernelAllocLine(WrapperLine):
         args = [*node.codegen_args(), *node.codegen_kwargs()]
         self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_extern_kernel_alloc
-
 
 @dataclasses.dataclass
 class ExternKernelOutLine(WrapperLine):
@@ -488,9 +466,6 @@ class ExternKernelOutLine(WrapperLine):
             device,
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_extern_kernel_out
-
 
 @dataclasses.dataclass
 class FreeLine(WrapperLine):
@@ -501,9 +476,6 @@ class FreeLine(WrapperLine):
         assert self.node.get_name() not in V.graph.removed_buffers
         code.writeline(self.wrapper.make_buffer_free(self.node))
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_free
-
 
 @dataclasses.dataclass
 class KernelCallLine(WrapperLine):
@@ -533,9 +505,6 @@ class KernelCallLine(WrapperLine):
             original_fxnode_name=self.original_fxnode_name,
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_kernel_call
-
 
 @dataclasses.dataclass
 class KernelDefinitionLine(WrapperLine):
@@ -555,9 +524,6 @@ class KernelDefinitionLine(WrapperLine):
             cpp_definition=self.cpp_definition,
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_kernel_definition
-
 
 @dataclasses.dataclass
 class MemoryPlanningLine(WrapperLine):
@@ -614,9 +580,6 @@ class AllocateLine(MemoryPlanningLine):
             line = self.wrapper.make_buffer_allocation(self.node)
             code.writeline(line)
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_allocate
-
 
 @dataclasses.dataclass
 class FreeIfNotReusedLine(MemoryPlanningLine):
@@ -640,9 +603,6 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
         if not self.is_reused:
            code.writeline(self.wrapper.make_buffer_free(self.node))
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_free_if_not_reused
-
 
 @dataclasses.dataclass
 class ReinterpretLine(MemoryPlanningLine):
@@ -660,9 +620,6 @@ class ReinterpretLine(MemoryPlanningLine):
             self.reused_as.get_name(), self.layout.view
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_reinterpret
-
 
 @dataclasses.dataclass
 class ReuseLine(MemoryPlanningLine):
@@ -684,13 +641,9 @@ class ReuseLine(MemoryPlanningLine):
             self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_reuse
-
 
 class NullLine(MemoryPlanningLine):
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_null
+    pass
 
 
 @dataclasses.dataclass
@@ -764,9 +717,6 @@ class CommBufferAllocateLine(CommBufferLine):
                 f"Unsupported comm buffer type: {comm_buffer_type}"
             )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_comm_buffer_allocate
-
 
 @dataclasses.dataclass
 class CommBufferFreeLine(CommBufferLine):
@@ -774,9 +724,6 @@ class CommBufferFreeLine(CommBufferLine):
         line = self.wrapper.make_buffer_free(self.node)
         code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_comm_buffer_free
-
 
 @dataclasses.dataclass
 class MultiOutputLine(WrapperLine):
@@ -813,22 +760,6 @@ class MultiOutputLine(WrapperLine):
             f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
         )
 
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_multi_output
-
-
-@dataclasses.dataclass
-class SymbolicCallArgLine(WrapperLine):
-    wrapper: PythonWrapperCodegen
-    arg: SymbolicCallArg
-    graph: GraphLowering
-
-    def codegen(self, code: IndentedBuffer) -> None:
-        self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
-
-    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
-        return converter._generate_symbolic_call_arg
-
 
 @dataclasses.dataclass
 class SymbolicCallArgLine(WrapperLine):
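Every codegen_fx method removed above follows one double-dispatch pattern: a Wrapper IR line does not convert itself; it returns the FxConverter method that knows how, and the converter's main loop invokes line.codegen_fx(self)(line). A minimal sketch of that shape with hypothetical class names:

    class Converter:
        def _generate_alloc(self, line):
            print(f"allocating {line.name}")

    class AllocLine:
        def __init__(self, name):
            self.name = name

        def codegen_fx(self, converter):
            # Return the bound converter method; the caller invokes it with the line.
            return converter._generate_alloc

    conv = Converter()
    for line in [AllocLine("buf0"), AllocLine("buf1")]:
        line.codegen_fx(conv)(line)  # same shape as FxConverter.generate()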
@@ -1,596 +0,0 @@
-import dataclasses
-import operator
-import textwrap
-from collections import Counter
-from typing import Any, Callable, Optional, Union
-
-import sympy
-
-import torch
-from torch._higher_order_ops.triton_kernel_wrap import (
-    TraceableTritonKernelWrapper,
-    tracing_triton_hopifier_singleton,
-    triton_kernel_wrapper_mutation,
-)
-from torch._inductor.codecache import PyCodeCache
-from torch._inductor.runtime.triton_heuristics import CachingAutotuner
-from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
-from torch._inductor.virtualized import V
-from torch._library.triton import wrap_triton
-from torch.fx import GraphModule
-
-from .. import ir
-from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
-from .common import (
-    CodegenSymbol,
-    FileBackedGraphModule,
-    WorkspaceArg,
-    WorkspaceZeroMode,
-)
-from .wrapper import (
-    AllocateLine,
-    BufferLike,
-    CommBufferAllocateLine,
-    CommBufferFreeLine,
-    CommentLine,
-    EnterDeviceContextManagerLine,
-    EnterSubgraphLine,
-    ExitDeviceContextManagerLine,
-    ExitSubgraphLine,
-    ExternKernelAllocLine,
-    ExternKernelOutLine,
-    FreeIfNotReusedLine,
-    FreeLine,
-    KernelCallLine,
-    KernelDefinitionLine,
-    Line,
-    MultiOutputLine,
-    NullLine,
-    PythonWrapperCodegen,
-    ReinterpretLine,
-    ReuseLine,
-    SymbolicCallArg,
-    SymbolicCallArgLine,
-    WrapperLine,
-)
-
-
-aten = torch.ops.aten
-
-
-@dataclasses.dataclass
-class SymbolBuffer(CodegenSymbol):
-    """
-    Represents a sympy.Symbol graph input.
-    """
-
-    symbol: sympy.Symbol
-
-    def get_name(self) -> str:
-        return str(self.symbol)
-
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
-        return self.symbol
-
-
-CodegenBuffer = Union[BufferLike, SymbolBuffer]
-
-
-@dataclasses.dataclass
-class TritonKernel:
-    """
-    Stores metadata about Triton kernels for use in FX.
-    """
-
-    tuner: CachingAutotuner
-    wrapped: TraceableTritonKernelWrapper
-
-
-class WrapperFxCodegen(PythonWrapperCodegen):
-    """
-    Backend to generate wrapper code as an FX IR graph.
-    """
-
-    supports_caching = False
-
-    def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
-        self.run_wrapper_ir_passes(is_inference)
-
-        prologue = "\n".join(
-            [
-                self.imports.getvalue(),
-                self.header.getvalue(),
-            ]
-        )
-        gm = FxConverter(lines=self.lines, prologue=prologue).generate()
-        compiled_fn = self.compile_graph(gm)
-
-        return FileBackedGraphModule(gm, compiled_fn), None
-
-    def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
-        """
-        Converts the graph module into a runnable function. The default implementation
-        is simply an interpreter calling kernels in eager mode. Derived backends can
-        override this to do further compilation.
-        """
-        return gm.forward
-
-    @classmethod
-    def create(
-        cls,
-        is_subgraph: bool,
-        subgraph_name: Optional[str],
-        parent_wrapper: Optional[PythonWrapperCodegen],
-        partition_signatures: Optional[ir.GraphPartitionSignature] = None,
-    ) -> "WrapperFxCodegen":
-        if is_subgraph:
-            raise NotImplementedError(
-                "Subgraphs are not yet supported by FX conversion"
-            )
-
-        # For derived backends, this could be a subclass.
-        return cls()
-
-
-@dataclasses.dataclass
-class FxConverter:
-    """
-    Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the
-    input and output code are stored as attributes.
-    """
-
-    lines: list[Line]
-    prologue: str = ""
-
-    def __post_init__(self) -> None:
-        graph = torch.fx.Graph()
-        self.gm = GraphModule({}, graph)  # Wrapper FX IR.
-        self.buffer_to_node: dict[
-            Optional[str], torch.fx.Node
-        ] = {}  # Symbol table for codegen.
-        self.kernels: dict[str, TritonKernel] = {}  # Table to store Triton kernels.
-        self._unique_symbol_ids: Counter[str] = Counter()
-
-    def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
-        """
-        Imports a kernel from source, possibly autotuning block parameters.
-        """
-        module_code = "\n".join([self.prologue, code])
-        mod = PyCodeCache.load(module_code)
-        kernel = getattr(mod, kernel_name)
-
-        if not isinstance(kernel, CachingAutotuner):
-            raise NotImplementedError(
-                textwrap.dedent(f"""
-                Unsupported type for kernel {kernel_name}: {type(kernel)}.
-                FX conversion only supports Triton kernels.
-                """)
-            )
-
-        return kernel
-
-    def _fake_tensor(
-        self,
-        size: tuple[Any, ...],
-        stride: tuple[Any, ...],
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-    ) -> torch.Tensor:
-        with V.fake_mode:
-            return torch.empty_strided(
-                convert_shape_to_symint(size),
-                convert_shape_to_symint(stride),
-                dtype=dtype,
-                device=device,
-            )
-
-    def _create_meta_from_buffer(
-        self, node: torch.fx.Node, buffer: CodegenBuffer
-    ) -> None:
-        name = buffer.get_name()
-        assert name
-        node.name = name
-        node.meta["val"] = buffer.get_example()
-
-    def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
-        """
-        Updates the symbol table to record that an Inductor buffer maps to the result of
-        an FX node.
-        """
-        assert node not in self.buffer_to_node
-        self.buffer_to_node[buffer.get_name()] = node
-
-    def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
-        """
-        Removes the buffer from the symbol table.
-        """
-        name = buffer.get_name()
-        del self.buffer_to_node[name]
-
-    def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
-        """
-        Maps call args back to FX nodes.
-        """
-        return tuple(
-            self.buffer_to_node[arg]
-            if isinstance(arg, str)
-            else arg.inner_expr
-            if isinstance(arg, SymbolicCallArg)
-            else arg
-            for arg in args
-        )
-
-    def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
-        """
-        Extract buffer data from an IR node.
-        """
-        if isinstance(node, (ir.Buffer, WorkspaceArg)):
-            return node
-        elif isinstance(node, (ir.BaseView, ir.MutableBox)):
-            return self._get_buffer(node.data)
-        elif isinstance(node, sympy.Symbol):
-            return SymbolBuffer(node)
-        else:
-            raise NotImplementedError(f"Unable to extract buffer from node: {node}")
-
-    def _generate_graph_inputs(self) -> None:
-        """
-        Converts graph inputs to FX placeholders.
-        """
-        for ir_node in V.graph.graph_inputs.values():
-            buffer = self._get_buffer(ir_node)
-            node = self.gm.graph.placeholder(buffer.get_name())
-            self._create_meta_from_buffer(node, buffer)
-            self._record_allocation(buffer, node)
-
-    def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
-        """
-        Generates FX IR for transformations on a buffer, such as ReinterpretView.
-        Does nothing if no such transformations are present.
-        """
-
-        def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
-            if isinstance(node, (ir.Buffer, WorkspaceArg)):
-                return node
-            elif isinstance(node, ir.NoneAsConstantBuffer):
-                return None
-            elif isinstance(node, ir.StorageBox):
-                return generate_to_buffer(node.data)
-            elif isinstance(node, ir.ReinterpretView):
-                # We need to introduce a new symbol if the output is a ReinterpretView.
-                # Use a WorkspaceArg for this.
-                buffer = self._get_buffer(node.data)
-                assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
-                unique_name = self.gm.graph._graph_namespace.create_name(
-                    f"{buffer.get_name()}_view", None
-                )
-                device = buffer.get_device()
-                assert device
-                reused_as = WorkspaceArg(
-                    count=buffer.get_size(),
-                    zero_mode=WorkspaceZeroMode.UNINITIALIZED,
-                    device=device,
-                    outer_name=unique_name,
-                    dtype=buffer.get_dtype(),
-                )
-
-                # Generate FX IR for the view.
-                self._generate_reinterpret_helper(buffer, reused_as, node.layout)
-
-                return reused_as
-            else:
-                raise NotImplementedError(f"Unrecognized buffer/view node: {node}")
-
-        buffer = generate_to_buffer(node)
-        return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
-
-    def _generate_output(self) -> None:
-        """
-        Generate FX IR for graph outputs.
-        """
-        output_nodes = [
-            self._generate_buffer(node)
-            for idx, node in enumerate(V.graph.graph_outputs)
-        ]
-
-        # Single return elements don't use a tuple.
-        output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
-
-        self.gm.graph.output(output_value)
-
-    def generate(self) -> torch.fx.GraphModule:
-        """
-        Main entrypoint for FX codegen.
-        """
-        self._generate_graph_inputs()
-
-        # Generate FX IR from Wrapper IR lines.
-        for line in self.lines:
-            if isinstance(line, WrapperLine):
-                line.codegen_fx(self)(line)
-            elif isinstance(line, LineContext):
-                # Ignore line context in FX IR.
-                pass
-            else:
-                raise NotImplementedError(
-                    textwrap.dedent(
-                        f"""
-                        Found line of unrecognized type '{type(line)}':
-                            '{line}'
-
-                        FX conversion only supports Wrapper IR lines.
-                        """
-                    )
-                )
-
-        self._generate_output()
-        self.gm.recompile()
-        return self.gm
-
-    def _generate_allocate(self, line: WrapperLine) -> None:
-        assert isinstance(line, AllocateLine)
-        buffer = line.node
-        name = buffer.get_name()
-        assert name not in V.graph.removed_buffers
-
-        device = buffer.get_device()
-        dtype = buffer.get_dtype()
-        shape = convert_shape_to_symint(buffer.get_size())
-        stride = convert_shape_to_symint(buffer.get_stride())
-
-        node = self.gm.graph.call_function(
-            torch.empty_strided,
-            args=(shape, stride),
-            kwargs={"dtype": dtype, "device": device},
-        )
-        assert name
-        node.name = name
-        self._create_meta_from_buffer(node, buffer)
-        self._record_allocation(buffer, node)
-
-    def _generate_comment(self, line: WrapperLine) -> None:
-        assert isinstance(line, CommentLine)
-        # We ignore comments in FX IR.
-
-    def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
-        assert isinstance(line, EnterDeviceContextManagerLine)
-        # We ignore the device context in FX IR.
-
-    def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
-        assert isinstance(line, ExitDeviceContextManagerLine)
-        # We ignore the device context in FX IR.
-
-    def _generate_enter_subgraph(self, line: WrapperLine) -> None:
-        assert isinstance(line, EnterSubgraphLine)
-        raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
-
-    def _generate_exit_subgraph(self, line: WrapperLine) -> None:
-        assert isinstance(line, ExitSubgraphLine)
-        raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
-
-    def _generate_free(self, line: WrapperLine) -> None:
-        assert isinstance(line, FreeLine)
-
-        buf = line.node
-
-        # No need to free placeholders.
-        if self.buffer_to_node[buf.get_name()].op == "placeholder":
-            return
-
-        self._free(buf)
-
-    def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
-        assert isinstance(line, FreeIfNotReusedLine)
-        buf = line.node
-        assert buf.get_name() not in V.graph.removed_buffers
-        if not line.is_reused:
-            self._free(buf)
-
-    def _generate_line_context(self, line: WrapperLine) -> None:
-        assert isinstance(line, LineContext)
-        # We ignore line context in FX IR.
-
-    def _generate_reinterpret(self, line: WrapperLine) -> None:
-        assert isinstance(line, ReinterpretLine)
-        self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)
-
-    def _generate_reinterpret_helper(
-        self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
-    ) -> None:
-        input_node = self.buffer_to_node[input_buffer.get_name()]
-
-        # Look up output metadata.
-        name = result_buffer.get_name()
-        assert name
-        size = tuple(layout.size)
-        stride = tuple(layout.stride)
-        offset = input_buffer.get_offset() + layout.offset
-
-        # Map ReinterpretView to as_strided.
-        result_node = self.gm.graph.call_function(
-            torch.as_strided, args=(input_node, size, stride, offset)
-        )
-        result_node.name = name
-        result_node.meta["val"] = layout.get_example()
-        self._record_allocation(result_buffer, result_node)
-
-    def _generate_reuse(self, line: WrapperLine) -> None:
-        assert isinstance(line, ReuseLine)
-        old = line.node
-        new = line.reused_as
-        assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
-        assert old.get_dtype() == new.get_dtype()
-
-        old_node = self.buffer_to_node[old.get_name()]
-        result_node = old_node
-
-        # Change shape and stride.
-        size = new.get_size()
-        stride = new.get_stride()
-        offset = new.get_offset()
-        if (
-            old.get_size() != size
-            or old.get_stride() != stride
-            or old.get_offset() != offset
-        ):
-            result_node = self.gm.graph.call_function(
-                torch.as_strided, args=(old_node, size, stride, offset)
-            )
-            self._create_meta_from_buffer(result_node, new)
-
-        self._record_allocation(new, result_node)
-
-        # Free the old buffer, if we allocated a new tensor.
-        if (
-            old.get_name() not in V.graph.get_output_names()
-            and line.delete_old
-            and result_node is not old_node
-        ):
-            self._free(old)
-
-    def _generate_multi_output(self, line: WrapperLine) -> None:
-        assert isinstance(line, MultiOutputLine)
-
-        # Extract the index for tuple access.
-        inds = line.indices[0][1:]
-        assert len(inds) == 1, f"Cannot convert {inds} to an index."
-        idx = inds[0]
-
-        arg_node = self.buffer_to_node[line.arg_name]
-        node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
-        node.meta["val"] = arg_node.meta["val"][idx]
-        node.name = line.result_name
-        self.buffer_to_node[line.result_name] = node
-
-    def _generate_null(self, line: WrapperLine) -> None:
-        assert isinstance(line, NullLine)
-        # Does nothing.
-
-    def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
-        assert isinstance(line, CommBufferAllocateLine)
-        raise NotImplementedError("Comm buffer allocation is not yet supported")
-
-    def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
-        assert isinstance(line, CommBufferFreeLine)
-        self._free(line.node)
-
-    def _generate_triton_call(self, line: WrapperLine) -> None:
-        assert isinstance(line, KernelCallLine)
-
-        # Collect all kwargs, including autotuned block sizes.
-        call_args = self._lookup_args(line.call_args)
-        kernel = self.kernels[line.kernel_name]
-        tuner = kernel.tuner
-        config = tuner.compile_results[0].config
-        call_args, grid = tuner._interpret_args_grid(call_args, config)
-        call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
-        call_kwargs.update(config.kwargs)
-
-        # Convert sympy expressions to symints.
-        for name, val in call_kwargs.items():
-            if isinstance(val, sympy.Expr):
-                call_kwargs[name] = convert_to_symint(val)
-
-        # Store non-graphable kwargs in the side table.
-        (
-            call_kwargs,
-            constant_args_idx,
-        ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)
-
-        self.gm.graph.call_function(
-            triton_kernel_wrapper_mutation,
-            kwargs={
-                "kernel_idx": kernel.wrapped.kernel_idx,
-                "constant_args_idx": constant_args_idx,
-                "grid": [convert_shape_to_symint(grid)],
-                "tma_descriptor_metadata": {},
-                "kwargs": call_kwargs,
-            },
-        )
-
-    def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
-        assert isinstance(line, ExternKernelAllocLine)
-        node = line.node
-        self._generate_extern_kernel_common(node, node)
-
-    def _generate_extern_kernel_out(
-        self,
-        line: WrapperLine,
-    ) -> None:
-        assert isinstance(line, ExternKernelOutLine)
-        node = line.node
-        out_node = node.output_view if node.output_view else node
-        self._generate_extern_kernel_common(node, out_node)
-
-    def _generate_extern_kernel_common(
-        self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
-    ) -> None:
-        """
-        Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
-        """
-
-        # Get FX nodes corresponding to the call args.
-        tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
-        args = tensor_nodes + tuple(kernel.constant_args)
-
-        # Get the result buffer.
-        # Some kernels write to a pre-existing output tensor via the "out" kwarg.
-        kwargs = kernel.kwargs.copy()
-        result_buffer: Optional[str] = None
-        if isinstance(kernel, ir.ExternKernelOut):
-            kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
-        elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
-            result_buffer = kernel.get_name()
-        elif isinstance(kernel.layout, ir.NoneLayout):
-            pass
-        else:
-            raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")
-
-        # Look up the kernel function from its name.
-        kernel_name = kernel.get_kernel_name()
-        module_name, kernel_name = kernel_name.split(".", 1)
-        op = globals()[module_name]  # E.g. extern_kernels, aten, etc.
-        for subname in kernel_name.split("."):
-            op = getattr(op, subname)  # E.g. extern_kernels.addmm
-
-        fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs)
-
-        # Assign the result to the given name.
-        if result_buffer:
-            assert "out" not in kwargs, (
-                f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
-            )
-            fx_node.name = result_buffer
-            self.buffer_to_node[result_buffer] = fx_node
-
-        arg_tensors = [
-            arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
-            for arg in args
-        ]
-
-        # Run the operation to propagate metadata.
-        fx_node.meta["val"] = op(*arg_tensors, **kwargs)
-
-    def _generate_kernel_call(self, line: WrapperLine) -> None:
-        assert isinstance(line, KernelCallLine)
-        if not line.triton:
-            raise NotImplementedError("FX conversion only supports Triton kernels.")
-
-        self._generate_triton_call(line)
-
-    def _generate_kernel_definition(self, line: WrapperLine) -> None:
-        assert isinstance(line, KernelDefinitionLine)
-
-        # Generate code for the kernel.
-        kernel_code = PythonWrapperCodegen._format_kernel_definition(
-            line.kernel_name, line.kernel_body, metadata=line.metadata
-        )
-
-        # Import the module and store the JIT kernel.
-        tuner = self._import_kernel(kernel_code, line.kernel_name)
-        wrapped = wrap_triton(tuner.fn)
-        self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)
-
-    def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
-        assert isinstance(line, SymbolicCallArgLine)
-        # No need for an FX node, as we will pass the arg to kernels via a SymInt.
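FxConverter above is ultimately driving the standard torch.fx.Graph building API: placeholders for graph inputs, call_function nodes for allocations and kernels, and a final output node. A self-contained sketch of the same sequence using only public torch.fx calls:

    import torch

    graph = torch.fx.Graph()
    x = graph.placeholder("x")            # graph input, like an Inductor buffer

    # call_function nodes stand in for allocations, kernels, and extern calls.
    buf = graph.call_function(
        torch.empty_strided, args=((8,), (1,)), kwargs={"dtype": torch.float32}
    )
    y = graph.call_function(torch.add, args=(x, x))

    graph.output(y)
    gm = torch.fx.GraphModule({}, graph)  # codegens gm.code from the graph
    # (FxConverter builds gm first and calls gm.recompile() after mutating.)

    print(gm.code)
    out = gm(torch.ones(4))               # runs forward(x) -> add(x, x)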
@@ -50,7 +50,6 @@ from . import config, ir, metrics
 from .codegen.common import (
     BackendFeature,
     DeviceOpOverrides,
-    FileBackedGraphModule,
     get_backend_features,
     get_device_op_overrides,
     get_wrapper_codegen_for_device,
@@ -116,12 +115,9 @@ if TYPE_CHECKING:
     from torch._higher_order_ops.effects import _EffectType
-    from torch.fx import GraphModule
     from torch.fx.graph import Graph
 
     from .codegen.wrapper import PythonWrapperCodegen
     from .scheduler import BaseSchedulerNode
 
-    CompiledModule = Union[ModuleType, FileBackedGraphModule]
-
 from torch._inductor.codecache import output_code_log
 
 
@@ -2228,7 +2224,7 @@ class GraphLowering(torch.fx.Interpreter):
     # No-op to be patched for unit tests
     save_output_code: Optional[Callable[[str], None]] = None
 
-    def compile_to_module(self) -> CompiledModule:
+    def compile_to_module(self) -> ModuleType:
         with dynamo_timed(
             "GraphLowering.compile_to_module",
             phase_name="code_gen",
@@ -2237,41 +2233,14 @@
         ):
             return self._compile_to_module()
 
-    def _compile_to_module(self) -> CompiledModule:
+    def _compile_to_module(self) -> ModuleType:
         from .codecache import PyCodeCache
 
         # Currently, if we're here, we don't have to worry about the kernel code, which
         # is only available in AOTInductor mode.
         wrapper_code, _ = (
             self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
         )
 
-        if isinstance(wrapper_code, ValueWithLineMap):
-            mod = self._compile_to_module_lines(wrapper_code)
-        elif isinstance(wrapper_code, FileBackedGraphModule):
-            mod = wrapper_code
-        else:
-            raise NotImplementedError(
-                f"Unrecognized wrapper code type: {type(wrapper_code)}"
-            )
-
-        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
-        # TODO. Revisit this once the logging API is more mature
-        assert mod.__file__ is not None
-
-        log_module_code(mod.__file__)
-        log.debug("Output code written to: %s", mod.__file__)
-        output_code_log.info("Output code written to: %s", mod.__file__)
-        if config.benchmark_kernel:
-            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
-        V.debug.output_code(mod.__file__)
-        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
-
-        return mod
-
-    def _compile_to_module_lines(
-        self, wrapper_code: ValueWithLineMap
-    ) -> CompiledModule:
-        from .codecache import PyCodeCache
-
         if config.triton.autotune_at_compile_time:
             tuning_code = (
                 '"""\n'
@@ -2322,7 +2291,17 @@
         if config.benchmark_harness and config.profile_bandwidth_output:
             # run the inputs code gen to get the bandwidth info
             mod.benchmark_compiled_module(times=1, repeat=1)
+        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
+        # TODO. Revisit this once the logging API is more mature
+        assert mod.__file__ is not None
+
+        log_module_code(mod.__file__)
+        log.debug("Output code written to: %s", mod.__file__)
+        output_code_log.info("Output code written to: %s", mod.__file__)
+        if config.benchmark_kernel:
+            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
+        V.debug.output_code(mod.__file__)
+        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
         return mod
 
     def get_output_names(self) -> list[str]:
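The CompiledModule alias deleted above worked because both compile paths expose the same tiny surface: a __file__ attribute for logging plus a way to invoke the code. A sketch of consuming that duck-typed surface; log_compiled is an illustrative name:

    def log_compiled(mod) -> None:
        # Accepts a real module or anything module-like exposing __file__,
        # e.g. a FileBackedGraphModule-style wrapper.
        assert mod.__file__ is not None
        print(f"Output code written to: {mod.__file__}")

    import os
    log_compiled(os)  # a plain module satisfies the same protocol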
@@ -65,7 +65,6 @@ from torch.utils._sympy.symbol import SymT
 from . import config, dependencies
 from .codegen.common import (
     BackendFeature,
-    CodegenSymbol,
     get_scheduling_for_device,
     index_prevent_reordering,
 )
@@ -3424,15 +3423,6 @@ class Layout(OutputSpec):
     def get_device(self) -> torch.device:
         return self.device
 
-    def get_example(self) -> torch.Tensor:
-        with V.fake_mode:
-            return torch.empty_strided(
-                convert_shape_to_symint(self.size),
-                convert_shape_to_symint(self.stride),
-                dtype=self.dtype,
-                device=self.device,
-            )
-
     def is_contiguous(self) -> bool:
         return is_contiguous_strides_for_shape(self.stride, self.size)
 
@@ -3936,7 +3926,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
 
 
 @ir_dataclass(frozen=False)
-class Buffer(IRNode, CodegenSymbol):
+class Buffer(IRNode):
     # Name is sometimes None; e.g., ForceInPlace, where there isn't
     # a meaningful name
     name: Optional[str]
@@ -3956,11 +3946,6 @@ class Buffer(IRNode, CodegenSymbol):
         assert self.name, self
         return self.name
 
-    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
-        if isinstance(self.layout, Layout):
-            return self.layout.get_example()
-        raise NotImplementedError(type(self.layout).__name__)
-
     def get_device(self) -> Optional[torch.device]:
         return self.get_output_spec().get_device()
 
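Layout.get_example, removed above, materializes layout metadata as a tensor under the fake mode, so shape, stride, dtype, and device are captured without allocating storage. A standalone sketch using FakeTensorMode directly rather than Inductor's V.fake_mode:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        # Carries shape/stride/dtype metadata only; no real storage is allocated.
        example = torch.empty_strided((8, 8), (8, 1), dtype=torch.float32)

    print(example.shape, example.stride(), example.dtype)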
@@ -85,7 +85,7 @@ class NoTritonConfigsError(RuntimeError):
 
 
 if TYPE_CHECKING:
-    from collections.abc import Container, Hashable
+    from collections.abc import Container, Hashable, Sequence
 
     from torch._guards import CompileId
 
 
@@ -2564,15 +2564,13 @@ class GridExpr:
 
     inductor_meta: dict[str, Any]
     mode: Literal["python", "cpp"] = "python"
-    prefix: list[str] = dataclasses.field(default_factory=list)
+    prefix: Sequence[str] = ()
     x_grid: Union[str, int] = 1
     y_grid: Union[str, int] = 1
     z_grid: Union[str, int] = 1
 
-    def __post_init__(self) -> None:
-        assert self.mode in ("python", "cpp")
-        if self.mode == "python":
-            self.prefix.append("from torch.utils._sympy.functions import FloorDiv")
-
     def generate(self, meta: dict[str, int]) -> None:
         raise NotImplementedError
@@ -2585,9 +2583,7 @@
         if isinstance(numel, int) and isinstance(block, int):
             return ceildiv(numel, block)  # constant fold
         if self.mode == "python":
-            # Use FloorDiv instead of // so we can get better sympy expressions for
-            # dynamic shapes.
-            return f"-FloorDiv(({numel}), -({block}))"
+            return f"-(({numel}) // -({block}))"
         # trick above doesn't work in C++ due to rounding differences
         return f"(({numel} + ({block} - 1)) / ({block}))"
 
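The restored python branch uses the classic floor-division identity ceil(a / b) == -(a // -b) for positive b; the PR's FloorDiv variant expressed the same thing as a sympy-friendly expression, and the C++ path differs because integer division there truncates toward zero. A quick check of the identity:

    import math

    def ceildiv(a: int, b: int) -> int:
        # Negate, floor-divide, negate: floor(-a / b) == -ceil(a / b).
        return -(a // -b)

    for a in range(1, 50):
        for b in range(1, 10):
            assert ceildiv(a, b) == math.ceil(a / b)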
@@ -2670,16 +2666,12 @@
 class Grid2DWithYZOverflow(GridExpr):
     def generate(self, meta: dict[str, int]) -> None:
         self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
-        self.prefix.extend(
-            [
-                self.assign_tmp(
-                    "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
-                ),
-                self.assign_tmp(
-                    "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
-                ),
-            ]
-        )
+        self.prefix = [
+            self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
+            self.assign_tmp(
+                "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
+            ),
+        ]
         self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
         self.z_grid = "y_grid_div_"
 
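Both variants of Grid2DWithYZOverflow compute the same split: CUDA caps the y dimension of a launch grid, so an oversized y block count is factored into a y part and a z part. A plain-Python sketch of the arithmetic; 65535 is the usual CUDA limit and stands in for get_max_y_grid():

    def ceildiv(a: int, b: int) -> int:
        return -(a // -b)

    MAX_Y_GRID = 65535  # stand-in for get_max_y_grid()

    def split_y(ynumel: int, yblock: int) -> tuple[int, int]:
        y_grid_raw = ceildiv(ynumel, yblock)          # blocks needed along y
        y_grid_div = ceildiv(y_grid_raw, MAX_Y_GRID)  # factor spilled into z
        return ceildiv(y_grid_raw, y_grid_div), y_grid_div

    y, z = split_y(10_000_000, 32)
    assert y <= MAX_Y_GRID and y * z >= ceildiv(10_000_000, 32)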
@@ -436,23 +436,6 @@ def convert_shape_to_inductor(
     return [sympy.sympify(i) for i in lst]
 
 
-def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
-    """
-    Like convert_shape_to_symint, but operates on a single expression.
-    """
-    from .virtualized import V
-
-    return (
-        i
-        if isinstance(i, int)
-        else (
-            int(i)
-            if isinstance(i, sympy.Integer)
-            else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
-        )
-    )
-
-
 def convert_shape_to_symint(
     lst: Iterable[Union[int, sympy.Expr]],
 ) -> list[Union[int, torch.SymInt]]:
@@ -460,7 +443,20 @@ def convert_shape_to_symint(
     Takes a list of shapes from Inductor and converts them into symints (or just
     ints if all shapes are static).
     """
-    return [convert_to_symint(i) for i in lst]
+    from .virtualized import V
+
+    return [
+        (
+            i
+            if isinstance(i, int)
+            else (
+                int(i)
+                if isinstance(i, sympy.Integer)
+                else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
+            )
+        )
+        for i in lst
+    ]
 
 
 def is_view(op: torch._ops.OpOverload) -> bool:
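The body restored into convert_shape_to_symint is a three-way dispatch: plain ints pass through, concrete sympy.Integer values collapse to int, and genuinely symbolic expressions become SymInts via the shape environment. A toy sketch of the dispatch; the string result stands in for V.graph.sizevars.shape_env.create_symintnode(i, hint=None):

    import sympy

    def to_symint_like(i):
        # Mirrors the restored branch structure; the last arm is a placeholder
        # for the real shape-env SymInt creation.
        if isinstance(i, int):
            return i
        if isinstance(i, sympy.Integer):
            return int(i)
        return f"symint({i})"

    s = sympy.Symbol("s0")
    assert [to_symint_like(x) for x in [4, sympy.Integer(5), s]] == [4, 5, "symint(s0)"]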