Revert "[Inductor] FX backend via Wrapper IR (#146942)"

This reverts commit a7691140a0.

Reverted https://github.com/pytorch/pytorch/pull/146942 on behalf of https://github.com/malfet because it appears to break lint; see a7691140a0/1 ([comment](https://github.com/pytorch/pytorch/pull/146942#issuecomment-2852192778))
Committed by PyTorch MergeBot on 2025-05-05 20:01:29 +00:00
parent a7691140a0
commit 99dac7005f
9 changed files with 49 additions and 1251 deletions
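
For context, the reverted change taught Inductor to emit its wrapper code as an FX graph rather than Python source. A minimal sketch of how the deleted test below exercised it (illustrative only: it assumes the reverted modules are still importable, uses "cuda" where the test uses GPU_TYPE, and otherwise reuses names that appear verbatim in the diff):

import torch
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen  # removed by this revert

# Route the "cuda" device through the FX wrapper backend, then compile as usual.
register_backend_for_device("cuda", TritonScheduling, WrapperFxCodegen)
opt = torch.compile(torch.add)
out = opt(torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))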

View File

@ -1,417 +0,0 @@
# Owner(s): ["module: inductor"]
"""
Test the FX IR backend.
"""
import itertools
import operator
import unittest
from typing import Callable, Optional
import sympy
import torch
import torch._inductor.codegen.common as common
import torch.utils._pytree as pytree
from torch._dynamo.exc import BackendCompilerFailed
from torch._dynamo.utils import same
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._inductor import config
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.cpp import CppScheduling
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
requires_gpu,
TRITON_HAS_CPU,
)
@requires_gpu()
@config.patch(
compile_threads=1,
alignment_asserts=False,
size_asserts=False,
scalar_asserts=False,
nan_asserts=False,
)
class FxirTestCase(InductorTestCase):
device = GPU_TYPE
def _count_ops(self, gm: torch.fx.GraphModule, target: Callable) -> int:
return len(gm.graph.find_nodes(op="call_function", target=target))
def _run_and_capture_graphs(self, opt, args) -> torch.fx.GraphModule:
gms = []
orig_generate = FxConverter.generate
def generate(self) -> torch.fx.GraphModule:
nonlocal gms
gm = orig_generate(self)
gms.append(gm)
return gm
with unittest.mock.patch.object(
torch._inductor.codegen.wrapper_fxir.FxConverter, "generate", generate
):
opt(*args)
return gms
def _compile_and_check(
self,
func,
args,
expected_num_triton_kernels: int = 1,
metadata_only: bool = False,
compile_kwargs: Optional[dict] = None,
):
if compile_kwargs is None:
compile_kwargs = {}
opt = torch.compile(func, **compile_kwargs)
# Get the FX graph from the backend.
gms = self._run_and_capture_graphs(opt, args)
# Check the code for triton kernels.
num_kernels = sum(
self._count_ops(gm, triton_kernel_wrapper_mutation) for gm in gms
)
self.assertEqual(num_kernels, expected_num_triton_kernels)
# Check accuracy.
result = opt(*args)
ref = func(*args)
if metadata_only:
# When we only want to check metadata, fill in zeros for tensor data.
ref, result = tuple(
pytree.tree_map(torch.zeros_like, x) for x in (ref, result)
)
self.assertTrue(same(ref, result))
return gms
@classmethod
def setUpClass(cls):
super().setUpClass()
# Register the FX backend.
register_backend_for_device(cls.device, TritonScheduling, WrapperFxCodegen)
def test_basic(self):
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(torch.add, args)
def test_multiple_kernels(self):
def foo(x, y):
return x.sum() + y.sum()
args = [torch.randn(length, device=self.device) for length in [517, 1029]]
self._compile_and_check(foo, args, expected_num_triton_kernels=2)
def test_free(self):
"""
Test a program that frees a buffer which is no longer in use.
"""
def foo(x, y, z):
w = x.sum() + y
return z.sum() + w.sum()
args = [torch.randn(length, device=self.device) for length in [517, 1029, 123]]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=3)
# Check the generated code for frees.
num_frees = gm.code.count("= None")
self.assertGreater(num_frees, 0)
def test_extern(self):
"""
Test a program that calls an extern kernel.
"""
def foo(x, y):
return x @ y + y.sum()
args = [
torch.randn(size, device=self.device) for size in [(129, 129), (129, 1)]
]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Check for the extern kernel
num_extern = self._count_ops(gm, extern_kernels.addmm)
self.assertEqual(num_extern, 1)
def test_fallback(self):
"""
Test a program that calls an aten fallback.
"""
length = 8
def foo(x):
return x + torch.randn(1, device=self.device)
args = (torch.randn(length, device=self.device),)
# Since the program has a random output, just check metadata.
# Don't check for an exact value.
(gm,) = self._compile_and_check(
foo, args, expected_num_triton_kernels=2, metadata_only=True
)
# Check for the fallback kernel.
num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out)
self.assertEqual(num_fallback, 1)
def test_cat_inputs(self):
"""
Test concatenation of graph inputs.
"""
def foo(x, y):
return torch.cat((x, y)) + 1
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(foo, args, expected_num_triton_kernels=1)
def test_cat_to_alloc(self):
"""
Test concatenation that's optimized out to an allocation.
"""
length = 8
def foo(x):
y, z = tuple(
torch.arange(length // 2, device=self.device) for _ in range(2)
)
return x + torch.cat((y, z))
args = [torch.randn(length, device=self.device)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Expect a single allocation, even though eager mode would use 2.
num_allocs = self._count_ops(gm, torch.empty_strided)
self.assertEqual(num_allocs, 1)
def test_cat_reinterpret_view(self):
"""
Test torch.cat using ReinterpretView.
"""
length = 8
def foo(x):
y, z = tuple(torch.randn(length // 2, device=self.device) for _ in range(2))
return x + torch.cat((y, z))
args = [torch.randn(length, device=self.device)]
# Since this test generates random numbers, check metadata only.
(gm,) = self._compile_and_check(
foo, args, expected_num_triton_kernels=3, metadata_only=True
)
# Check for as_strided. We map ReinterpretView to this.
num_as_strided = self._count_ops(gm, torch.as_strided)
self.assertEqual(num_as_strided, 2)
def test_reshape_output(self):
"""
Test reshaping the output, which maps to a ReinterpretView.
"""
def foo(x, y):
return torch.reshape(x + y, (8,))
args = [torch.randn((2, 4), device=self.device) for _ in range(2)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1)
# Check for as_strided. We map ReinterpretView to this.
num_as_strided = self._count_ops(gm, torch.as_strided)
self.assertEqual(num_as_strided, 1)
def test_extern_multi_output(self):
"""
Test an extern kernel with multiple outputs.
Also test a graph with multiple outputs.
"""
def foo(x):
top, idx = torch.topk(x, 2)
return top + 1, idx * 2
args = [torch.randn(8, device=self.device)]
(gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=2)
# Check for multiple kernel outputs via getitems.
num_getitems = self._count_ops(gm, operator.getitem)
self.assertEqual(num_getitems, 2)
# Check for multiple graph outputs.
output_node = gm.graph.find_nodes(op="output")[0]
self.assertEqual(len(output_node.args[0]), 2)
def test_duplicate_input(self):
"""
Test duplicated inputs. This will collapse into a single input in the GM.
"""
args = [torch.randn(4, device=self.device)] * 2
(gm,) = self._compile_and_check(torch.add, args, expected_num_triton_kernels=1)
num_placeholders = len(gm.graph.find_nodes(op="placeholder"))
self.assertEqual(num_placeholders, 1)
def test_backward(self):
"""
Test a program with a backward pass.
"""
x = torch.ones(5, device=self.device) # input tensor
y = torch.zeros(3, device=self.device) # expected output
w = torch.randn(5, 3, requires_grad=True, device=self.device)
b = torch.randn(3, requires_grad=True, device=self.device)
def foo(x, y):
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()
return w.grad, b.grad
# Expect separate forward and backward graphs.
(forward_gm, backward_gm) = self._compile_and_check(
foo, (x, y), expected_num_triton_kernels=3
)
def test_custom_compiler(self):
"""
Test a derived backend with a custom compiler.
"""
offset = 1
class CustomWrapperCodegen(WrapperFxCodegen):
def compile_graph(self, gm):
def compiled_fn(*args):
# Adds an offset to the program's outputs.
outputs = gm(*args)
return pytree.tree_map(lambda x: x + offset, outputs)
return compiled_fn
args = [torch.randn(8, device=self.device) for _ in range(2)]
custom_backend = common.DeviceCodegen(
TritonScheduling, CustomWrapperCodegen, None
)
with unittest.mock.patch.dict(
common.device_codegens, {self.device: custom_backend}
):
func = torch.add
opt = torch.compile(func)
result = opt(*args)
# Check the output is offset from eager mode.
ref = func(*args)
self.assertFalse(same(result, ref))
self.assertNotEqual(offset, 0)
self.assertTrue(same(result - offset, ref))
def test_dynamic_shapes_and_strides(self):
"""
Test a graph with dynamic shapes and strides.
"""
static_dims = (8, 8)
def get_input():
full_size = (16, 8)
full = torch.randn(full_size, device=self.device)
view = torch.as_strided(full, static_dims, full.stride())
return view
func = torch.add
args = [get_input() for _ in range(2)]
(gm,) = self._compile_and_check(func, args, compile_kwargs={"dynamic": True})
# Check for a symbolic output shape.
(empty_strided,) = gm.graph.find_nodes(
op="call_function", target=torch.empty_strided
)
example_tensor = empty_strided.meta["val"]
symbolic_dims = example_tensor.shape
self.assertEqual(len(symbolic_dims), len(static_dims))
# Check for symbolic output strides.
(stride, one) = example_tensor.stride()
self.assertEqual(one, sympy.S.One)
# Find the size symbols, and check for corresponding placeholders defining them.
for symbol in itertools.chain(symbolic_dims, [stride]):
self.assertTrue(isinstance(symbol, torch.SymInt))
(placeholder,) = [
node
for node in gm.graph.find_nodes(op="placeholder")
if node.name == str(symbol)
]
self.assertEqual(placeholder.meta["val"], symbol)
@config.patch({"trace.enabled": True})
@unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code")
def test_debug(self, mock_output_code):
# Compile in debug mode.
args = [torch.randn(11, device=self.device) for _ in range(2)]
self._compile_and_check(torch.sub, args)
# Check the output code for a Triton kernel call.
mock_output_code.assert_called_once()
(output_filename,) = mock_output_code.call_args.args
with open(output_filename) as f:
output_code = f.read()
self.assertIn("triton_kernel_wrapper_mutation", output_code)
@torch._inductor.config.patch("graph_partition", True)
def test_subgraph_raises(self):
"""
Test a model with subgraphs. This is not yet supported, so check that we get the
expected exception.
"""
def foo(cond, x):
return torch.cond(cond, torch.cos, torch.sin, [x])
cond = torch.tensor([True], device=self.device)
x = torch.ones([2, 3], device=self.device)
with self.assertRaisesRegex(BackendCompilerFailed, "Subgraph"):
self._compile_and_check(foo, [cond, x])
def test_cpp_raises(self):
"""
Test the C++ CPU backend. C++ kernels are not yet supported, so for now check
that we get the expected exception.
"""
def foo(x, y):
return x + y * 5
device = torch.device("cpu")
args = [torch.randn(5, device=device) for _ in range(2)]
cpp_backend = common.DeviceCodegen(CppScheduling, WrapperFxCodegen, None)
with unittest.mock.patch.dict(
common.device_codegens, {device.type: cpp_backend}
), self.assertRaisesRegex(BackendCompilerFailed, "Triton"):
self._compile_and_check(foo, args)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU or TRITON_HAS_CPU:
run_tests(needs="filelock")
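
As a companion to the helpers above, here is a small sketch of inspecting one of the captured wrapper graphs. The helper name is made up; it only relies on the torch.fx APIs the test itself uses (gm.code, graph.find_nodes):

import torch
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation

def summarize_wrapper_graph(gm: torch.fx.GraphModule) -> None:
    # Print the Python rendering of the wrapper FX graph.
    print(gm.code)
    # Count Triton kernel launches, mirroring _count_ops above.
    kernel_calls = gm.graph.find_nodes(
        op="call_function", target=triton_kernel_wrapper_mutation
    )
    print(f"{len(kernel_calls)} Triton kernel call(s)")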

View File

@ -1750,27 +1750,6 @@ class TracingTritonHOPifier(TritonHOPifier):
# normalize to tuple
return tuple(grid)
def store_non_graphable_args(
self,
combined_args: dict[str, Any],
) -> tuple[dict, int]:
"""
Some args cannot be stored in the FX graph.
Put them in the side table.
"""
def is_graphable(val: Any) -> bool:
return isinstance(val, (fx.node.base_types, fx.Node))
non_graphable_args = {
k: v for k, v in combined_args.items() if not is_graphable(v)
}
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
return graphable_args, constant_args_idx
def call_HOP(
self,
variable: "TraceableTritonKernelWrapper",
@ -1781,8 +1760,15 @@ class TracingTritonHOPifier(TritonHOPifier):
assert tx is None
assert isinstance(variable, TraceableTritonKernelWrapper)
graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)
def is_graphable(val: Any) -> bool:
return isinstance(val, fx.node.base_types)
non_graphable_args = {
k: v for k, v in combined_args.items() if not is_graphable(v)
}
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
assert isinstance(variable.kernel_idx, int)
return triton_kernel_wrapper_mutation(
kernel_idx=variable.kernel_idx,
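
To make the graphable/non-graphable split concrete, a quick illustration of the predicate used above (the sample values are made up; the check mirrors the inlined helper):

import torch
from torch import fx

def is_graphable(val) -> bool:
    # Scalars, tensors, dtypes, devices, etc. can live directly in the FX graph.
    return isinstance(val, fx.node.base_types)

print(is_graphable(64))             # True  -> stays in the graphable kwargs
print(is_graphable(torch.ones(4)))  # True
print(is_graphable({"BLOCK": 64}))  # False -> stored in the kernel side table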

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import atexit
import contextlib
import dataclasses
import enum
@ -9,11 +8,8 @@ import itertools
import logging
import math
import operator
import os
import re
import tempfile
import typing
from abc import ABC, abstractmethod
from enum import auto, Enum
from itertools import chain
from typing import (
@ -64,8 +60,6 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from collections.abc import Iterator, MutableMapping, Sequence
from torch.fx import GraphModule
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
from ..loop_body import LoopBody
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@ -89,38 +83,6 @@ def data_type_logger(msg: str) -> None:
schedule_log.debug("Data type propagation: %s", msg)
@dataclasses.dataclass
class FileBackedGraphModule:
"""
Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
map back to a GraphModule instead of Python source.
"""
gm: GraphModule
compiled_fn: Callable[..., Any]
def __post_init__(self) -> None:
# Write the code to a file for compatibility with debugging utilities.
# The file is deleted upon program termination.
self.tempfile = tempfile.NamedTemporaryFile(
mode="w+", suffix=".py", delete=False
)
atexit.register(os.remove, self.tempfile.name)
with self.tempfile as f:
f.write(self.value)
@property
def __file__(self) -> str:
return self.tempfile.name
def call(self, args: list[Any]) -> Any:
return self.compiled_fn(*args)
@property
def value(self) -> str:
return self.gm.code
class WorkspaceZeroMode(enum.Enum):
UNINITIALIZED = 0
ZERO_ON_CALL = 1 # kernel may leave workspace dirty
@ -141,22 +103,8 @@ class WorkspaceZeroMode(enum.Enum):
return WorkspaceZeroMode.UNINITIALIZED
class CodegenSymbol(ABC):
"""
An IR object possibly corresponding to a variable in the wrapper code.
"""
@abstractmethod
def get_name(self) -> str:
pass
@abstractmethod
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
pass
@ir_dataclass(frozen=True)
class WorkspaceArg(CodegenSymbol):
class WorkspaceArg:
"""A temporary buffer used for a single kernel, then discarded.
Not registered as a traditional buffer since there are no users,
@ -219,9 +167,6 @@ class WorkspaceArg(CodegenSymbol):
def get_dtype(self) -> torch.dtype:
return self.dtype
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
return self.get_layout().get_example()
def get_layout(self) -> FixedLayout:
from ..ir import FixedLayout
@ -240,9 +185,6 @@ class WorkspaceArg(CodegenSymbol):
maybe_get_output_spec = get_layout
maybe_get_layout = get_layout
def get_offset(self) -> sympy.Expr:
return sympy.S.Zero
def get_size(self) -> list[sympy.Expr]:
return [self.count]
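
For reference, a small sketch of how the removed FileBackedGraphModule was meant to be used (it assumes the class is still present; the toy identity graph is illustrative):

import torch
from torch._inductor.codegen.common import FileBackedGraphModule  # removed by this revert

graph = torch.fx.Graph()
x = graph.placeholder("x")
graph.output(x)
gm = torch.fx.GraphModule({}, graph)

mod = FileBackedGraphModule(gm, compiled_fn=gm.forward)
print(mod.__file__)                # temp .py file holding gm.code, for debug tooling
print(mod.call([torch.ones(2)]))   # dispatches to compiled_fn(*args)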

View File

@ -74,7 +74,6 @@ if TYPE_CHECKING:
import triton
from ..graph import GraphLowering
from .wrapper_fxir import FxConverter
log = logging.getLogger(__name__)
@ -84,7 +83,6 @@ pexpr = PythonPrinter().doprint
ReuseKey = tuple[torch.device, torch.dtype, str, bool]
BufferLike = Union[ir.Buffer, WorkspaceArg]
FxConversionFunc = Callable[["WrapperLine"], None]
def buffer_reuse_key(node: BufferLike) -> ReuseKey:
@ -351,8 +349,7 @@ class MemoryPlanningState:
class WrapperLine:
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
raise NotImplementedError(f"FX codegen not yet supported for type {type(self)}")
pass
@dataclasses.dataclass
@ -367,9 +364,6 @@ class EnterSubgraphLine(WrapperLine):
self.wrapper.push_codegened_graph(self.graph)
code.do_indent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_enter_subgraph
@dataclasses.dataclass
class CommentLine(WrapperLine):
@ -378,10 +372,6 @@ class CommentLine(WrapperLine):
def codegen(self, code: IndentedBuffer) -> None:
code.writeline(self.line)
@staticmethod
def codegen_fx(converter: FxConverter) -> FxConversionFunc:
return converter._generate_comment
@dataclasses.dataclass
class ExitSubgraphLine(WrapperLine):
@ -394,9 +384,6 @@ class ExitSubgraphLine(WrapperLine):
self.wrapper.pop_codegened_graph()
code.do_unindent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_exit_subgraph
@dataclasses.dataclass
class EnterDeviceContextManagerLine(WrapperLine):
@ -432,18 +419,12 @@ class EnterDeviceContextManagerLine(WrapperLine):
code.do_indent()
code.writeline(V.graph.device_ops.set_device(self.device_idx))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_enter_device_context_manager
class ExitDeviceContextManagerLine(WrapperLine):
def codegen(self, code: IndentedBuffer) -> None:
if not V.graph.cpp_wrapper:
code.do_unindent()
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_exit_device_context_manager
@dataclasses.dataclass
class ExternKernelAllocLine(WrapperLine):
@ -455,9 +436,6 @@ class ExternKernelAllocLine(WrapperLine):
args = [*node.codegen_args(), *node.codegen_kwargs()]
self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_extern_kernel_alloc
@dataclasses.dataclass
class ExternKernelOutLine(WrapperLine):
@ -488,9 +466,6 @@ class ExternKernelOutLine(WrapperLine):
device,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_extern_kernel_out
@dataclasses.dataclass
class FreeLine(WrapperLine):
@ -501,9 +476,6 @@ class FreeLine(WrapperLine):
assert self.node.get_name() not in V.graph.removed_buffers
code.writeline(self.wrapper.make_buffer_free(self.node))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_free
@dataclasses.dataclass
class KernelCallLine(WrapperLine):
@ -533,9 +505,6 @@ class KernelCallLine(WrapperLine):
original_fxnode_name=self.original_fxnode_name,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_kernel_call
@dataclasses.dataclass
class KernelDefinitionLine(WrapperLine):
@ -555,9 +524,6 @@ class KernelDefinitionLine(WrapperLine):
cpp_definition=self.cpp_definition,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_kernel_definition
@dataclasses.dataclass
class MemoryPlanningLine(WrapperLine):
@ -614,9 +580,6 @@ class AllocateLine(MemoryPlanningLine):
line = self.wrapper.make_buffer_allocation(self.node)
code.writeline(line)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_allocate
@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
@ -640,9 +603,6 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
if not self.is_reused:
code.writeline(self.wrapper.make_buffer_free(self.node))
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_free_if_not_reused
@dataclasses.dataclass
class ReinterpretLine(MemoryPlanningLine):
@ -660,9 +620,6 @@ class ReinterpretLine(MemoryPlanningLine):
self.reused_as.get_name(), self.layout.view
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_reinterpret
@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
@ -684,13 +641,9 @@ class ReuseLine(MemoryPlanningLine):
self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_reuse
class NullLine(MemoryPlanningLine):
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_null
pass
@dataclasses.dataclass
@ -764,9 +717,6 @@ class CommBufferAllocateLine(CommBufferLine):
f"Unsupported comm buffer type: {comm_buffer_type}"
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_comm_buffer_allocate
@dataclasses.dataclass
class CommBufferFreeLine(CommBufferLine):
@ -774,9 +724,6 @@ class CommBufferFreeLine(CommBufferLine):
line = self.wrapper.make_buffer_free(self.node)
code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_comm_buffer_free
@dataclasses.dataclass
class MultiOutputLine(WrapperLine):
@ -813,22 +760,6 @@ class MultiOutputLine(WrapperLine):
f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_multi_output
@dataclasses.dataclass
class SymbolicCallArgLine(WrapperLine):
wrapper: PythonWrapperCodegen
arg: SymbolicCallArg
graph: GraphLowering
def codegen(self, code: IndentedBuffer) -> None:
self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
return converter._generate_symbolic_call_arg
@dataclasses.dataclass
class SymbolicCallArgLine(WrapperLine):
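
The methods removed throughout this file all follow one double-dispatch pattern: each WrapperLine returns the FxConverter method that knows how to lower it to FX. A sketch of what a new line type would have looked like under that scheme (the class and the converter method are hypothetical):

@dataclasses.dataclass
class MyCustomLine(WrapperLine):
    wrapper: PythonWrapperCodegen

    def codegen(self, code: IndentedBuffer) -> None:
        code.writeline("# custom line")

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
        return converter._generate_my_custom_line  # hypothetical converter method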

View File

@ -1,596 +0,0 @@
import dataclasses
import operator
import textwrap
from collections import Counter
from typing import Any, Callable, Optional, Union
import sympy
import torch
from torch._higher_order_ops.triton_kernel_wrap import (
TraceableTritonKernelWrapper,
tracing_triton_hopifier_singleton,
triton_kernel_wrapper_mutation,
)
from torch._inductor.codecache import PyCodeCache
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.select_algorithm import extern_kernels # noqa: F401
from torch._inductor.virtualized import V
from torch._library.triton import wrap_triton
from torch.fx import GraphModule
from .. import ir
from ..utils import convert_shape_to_symint, convert_to_symint, LineContext
from .common import (
CodegenSymbol,
FileBackedGraphModule,
WorkspaceArg,
WorkspaceZeroMode,
)
from .wrapper import (
AllocateLine,
BufferLike,
CommBufferAllocateLine,
CommBufferFreeLine,
CommentLine,
EnterDeviceContextManagerLine,
EnterSubgraphLine,
ExitDeviceContextManagerLine,
ExitSubgraphLine,
ExternKernelAllocLine,
ExternKernelOutLine,
FreeIfNotReusedLine,
FreeLine,
KernelCallLine,
KernelDefinitionLine,
Line,
MultiOutputLine,
NullLine,
PythonWrapperCodegen,
ReinterpretLine,
ReuseLine,
SymbolicCallArg,
SymbolicCallArgLine,
WrapperLine,
)
aten = torch.ops.aten
@dataclasses.dataclass
class SymbolBuffer(CodegenSymbol):
"""
Represents a sympy.Symbol graph input.
"""
symbol: sympy.Symbol
def get_name(self) -> str:
return str(self.symbol)
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
return self.symbol
CodegenBuffer = Union[BufferLike, SymbolBuffer]
@dataclasses.dataclass
class TritonKernel:
"""
Stores metadata about Triton kernels for use in FX.
"""
tuner: CachingAutotuner
wrapped: TraceableTritonKernelWrapper
class WrapperFxCodegen(PythonWrapperCodegen):
"""
Backend to generate wrapper code as an FX IR graph.
"""
supports_caching = False
def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]:
self.run_wrapper_ir_passes(is_inference)
prologue = "\n".join(
[
self.imports.getvalue(),
self.header.getvalue(),
]
)
gm = FxConverter(lines=self.lines, prologue=prologue).generate()
compiled_fn = self.compile_graph(gm)
return FileBackedGraphModule(gm, compiled_fn), None
def compile_graph(self, gm: GraphModule) -> Callable[..., Any]:
"""
Converts the graph module into a runnable function. The default implementation
is simply an interpreter calling kernels in eager mode. Derived backends can
override this to do further compilation.
"""
return gm.forward
@classmethod
def create(
cls,
is_subgraph: bool,
subgraph_name: Optional[str],
parent_wrapper: Optional[PythonWrapperCodegen],
partition_signatures: Optional[ir.GraphPartitionSignature] = None,
) -> "WrapperFxCodegen":
if is_subgraph:
raise NotImplementedError(
"Subgraphs are not yet supported by FX conversion"
)
# For derived backends, this could be a subclass.
return cls()
@dataclasses.dataclass
class FxConverter:
"""
Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the
input and output code are stored as attributes.
"""
lines: list[Line]
prologue: str = ""
def __post_init__(self) -> None:
graph = torch.fx.Graph()
self.gm = GraphModule({}, graph) # Wrapper FX IR.
self.buffer_to_node: dict[
Optional[str], torch.fx.Node
] = {} # Symbol table for codegen.
self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels.
self._unique_symbol_ids: Counter[str] = Counter()
def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner:
"""
Imports a kernel from source, possibly autotuning block parameters.
"""
module_code = "\n".join([self.prologue, code])
mod = PyCodeCache.load(module_code)
kernel = getattr(mod, kernel_name)
if not isinstance(kernel, CachingAutotuner):
raise NotImplementedError(
textwrap.dedent(f"""
Unsupported type for kernel {kernel_name}: {type(kernel)}.
FX conversion only supports Triton kernels.
""")
)
return kernel
def _fake_tensor(
self,
size: tuple[Any, ...],
stride: tuple[Any, ...],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
with V.fake_mode:
return torch.empty_strided(
convert_shape_to_symint(size),
convert_shape_to_symint(stride),
dtype=dtype,
device=device,
)
def _create_meta_from_buffer(
self, node: torch.fx.Node, buffer: CodegenBuffer
) -> None:
name = buffer.get_name()
assert name
node.name = name
node.meta["val"] = buffer.get_example()
def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None:
"""
Updates the symbol table to record that an Inductor buffer maps to the result of
an FX node.
"""
assert buffer.get_name() not in self.buffer_to_node
self.buffer_to_node[buffer.get_name()] = node
def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None:
"""
Removes the buffer from the symbol table.
"""
name = buffer.get_name()
del self.buffer_to_node[name]
def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
"""
Maps call args back to FX nodes.
"""
return tuple(
self.buffer_to_node[arg]
if isinstance(arg, str)
else arg.inner_expr
if isinstance(arg, SymbolicCallArg)
else arg
for arg in args
)
def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer:
"""
Extract buffer data from an IR node.
"""
if isinstance(node, (ir.Buffer, WorkspaceArg)):
return node
elif isinstance(node, (ir.BaseView, ir.MutableBox)):
return self._get_buffer(node.data)
elif isinstance(node, sympy.Symbol):
return SymbolBuffer(node)
else:
raise NotImplementedError(f"Unable to extract buffer from node: {node}")
def _generate_graph_inputs(self) -> None:
"""
Converts graph inputs to FX placeholders.
"""
for ir_node in V.graph.graph_inputs.values():
buffer = self._get_buffer(ir_node)
node = self.gm.graph.placeholder(buffer.get_name())
self._create_meta_from_buffer(node, buffer)
self._record_allocation(buffer, node)
def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
"""
Generates FX IR for transformations on a buffer, such as ReinterpretView.
Does nothing if no such transformations are present.
"""
def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]:
if isinstance(node, (ir.Buffer, WorkspaceArg)):
return node
elif isinstance(node, ir.NoneAsConstantBuffer):
return None
elif isinstance(node, ir.StorageBox):
return generate_to_buffer(node.data)
elif isinstance(node, ir.ReinterpretView):
# We need to introduce a new symbol if the output is a ReinterpretView.
# Use a WorkspaceArg for this.
buffer = self._get_buffer(node.data)
assert isinstance(buffer, (ir.Buffer, WorkspaceArg))
unique_name = self.gm.graph._graph_namespace.create_name(
f"{buffer.get_name()}_view", None
)
device = buffer.get_device()
assert device
reused_as = WorkspaceArg(
count=buffer.get_size(),
zero_mode=WorkspaceZeroMode.UNINITIALIZED,
device=device,
outer_name=unique_name,
dtype=buffer.get_dtype(),
)
# Generate FX IR for the view.
self._generate_reinterpret_helper(buffer, reused_as, node.layout)
return reused_as
else:
raise NotImplementedError(f"Unrecognized buffer/view node: {node}")
buffer = generate_to_buffer(node)
return self.buffer_to_node[buffer.get_name()] if buffer is not None else None
def _generate_output(self) -> None:
"""
Generate FX IR for graph outputs.
"""
output_nodes = [
self._generate_buffer(node)
for idx, node in enumerate(V.graph.graph_outputs)
]
# Single return elements don't use a tuple.
output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes
self.gm.graph.output(output_value)
def generate(self) -> torch.fx.GraphModule:
"""
Main entrypoint for FX codegen.
"""
self._generate_graph_inputs()
# Generate FX IR from Wrapper IR lines.
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen_fx(self)(line)
elif isinstance(line, LineContext):
# Ignore line context in FX IR.
pass
else:
raise NotImplementedError(
textwrap.dedent(
f"""
Found line of unrecognized type '{type(line)}':
'{line}'
FX conversion only supports Wrapper IR lines.
"""
)
)
self._generate_output()
self.gm.recompile()
return self.gm
def _generate_allocate(self, line: WrapperLine) -> None:
assert isinstance(line, AllocateLine)
buffer = line.node
name = buffer.get_name()
assert name not in V.graph.removed_buffers
device = buffer.get_device()
dtype = buffer.get_dtype()
shape = convert_shape_to_symint(buffer.get_size())
stride = convert_shape_to_symint(buffer.get_stride())
node = self.gm.graph.call_function(
torch.empty_strided,
args=(shape, stride),
kwargs={"dtype": dtype, "device": device},
)
assert name
node.name = name
self._create_meta_from_buffer(node, buffer)
self._record_allocation(buffer, node)
def _generate_comment(self, line: WrapperLine) -> None:
assert isinstance(line, CommentLine)
# We ignore comments in FX IR.
def _generate_enter_device_context_manager(self, line: WrapperLine) -> None:
assert isinstance(line, EnterDeviceContextManagerLine)
# We ignore the device context in FX IR.
def _generate_exit_device_context_manager(self, line: WrapperLine) -> None:
assert isinstance(line, ExitDeviceContextManagerLine)
# We ignore the device context in FX IR.
def _generate_enter_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, EnterSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
def _generate_exit_subgraph(self, line: WrapperLine) -> None:
assert isinstance(line, ExitSubgraphLine)
raise NotImplementedError("Subgraphs are not yet supported by FX conversion")
def _generate_free(self, line: WrapperLine) -> None:
assert isinstance(line, FreeLine)
buf = line.node
# No need to free placeholders.
if self.buffer_to_node[buf.get_name()].op == "placeholder":
return
self._free(buf)
def _generate_free_if_not_reused(self, line: WrapperLine) -> None:
assert isinstance(line, FreeIfNotReusedLine)
buf = line.node
assert buf.get_name() not in V.graph.removed_buffers
if not line.is_reused:
self._free(buf)
def _generate_line_context(self, line: WrapperLine) -> None:
assert isinstance(line, LineContext)
# We ignore line context in FX IR.
def _generate_reinterpret(self, line: WrapperLine) -> None:
assert isinstance(line, ReinterpretLine)
self._generate_reinterpret_helper(line.node, line.reused_as, line.layout)
def _generate_reinterpret_helper(
self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout
) -> None:
input_node = self.buffer_to_node[input_buffer.get_name()]
# Look up output metadata.
name = result_buffer.get_name()
assert name
size = tuple(layout.size)
stride = tuple(layout.stride)
offset = input_buffer.get_offset() + layout.offset
# Map ReinterpretView to as_strided.
result_node = self.gm.graph.call_function(
torch.as_strided, args=(input_node, size, stride, offset)
)
result_node.name = name
result_node.meta["val"] = layout.get_example()
self._record_allocation(result_buffer, result_node)
def _generate_reuse(self, line: WrapperLine) -> None:
assert isinstance(line, ReuseLine)
old = line.node
new = line.reused_as
assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new))
assert old.get_dtype() == new.get_dtype()
old_node = self.buffer_to_node[old.get_name()]
result_node = old_node
# Change shape and stride.
size = new.get_size()
stride = new.get_stride()
offset = new.get_offset()
if (
old.get_size() != size
or old.get_stride() != stride
or old.get_offset() != offset
):
result_node = self.gm.graph.call_function(
torch.as_strided, args=(old_node, size, stride, offset)
)
self._create_meta_from_buffer(result_node, new)
self._record_allocation(new, result_node)
# Free the old buffer, if we allocated a new tensor.
if (
old.get_name() not in V.graph.get_output_names()
and line.delete_old
and result_node is not old_node
):
self._free(old)
def _generate_multi_output(self, line: WrapperLine) -> None:
assert isinstance(line, MultiOutputLine)
# Extract the index for tuple access.
inds = line.indices[0][1:]
assert len(inds) == 1, f"Cannot convert {inds} to an index."
idx = inds[0]
arg_node = self.buffer_to_node[line.arg_name]
node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
node.meta["val"] = arg_node.meta["val"][idx]
node.name = line.result_name
self.buffer_to_node[line.result_name] = node
def _generate_null(self, line: WrapperLine) -> None:
assert isinstance(line, NullLine)
# Does nothing.
def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None:
assert isinstance(line, CommBufferAllocateLine)
raise NotImplementedError("Comm buffer allocation is not yet supported")
def _generate_comm_buffer_free(self, line: WrapperLine) -> None:
assert isinstance(line, CommBufferFreeLine)
self._free(line.node)
def _generate_triton_call(self, line: WrapperLine) -> None:
assert isinstance(line, KernelCallLine)
# Collect all kwargs, including autotuned block sizes.
call_args = self._lookup_args(line.call_args)
kernel = self.kernels[line.kernel_name]
tuner = kernel.tuner
config = tuner.compile_results[0].config
call_args, grid = tuner._interpret_args_grid(call_args, config)
call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
call_kwargs.update(config.kwargs)
# Convert sympy expressions to symints.
for name, val in call_kwargs.items():
if isinstance(val, sympy.Expr):
call_kwargs[name] = convert_to_symint(val)
# Store non-graphable kwargs in the side table.
(
call_kwargs,
constant_args_idx,
) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs)
self.gm.graph.call_function(
triton_kernel_wrapper_mutation,
kwargs={
"kernel_idx": kernel.wrapped.kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": [convert_shape_to_symint(grid)],
"tma_descriptor_metadata": {},
"kwargs": call_kwargs,
},
)
def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None:
assert isinstance(line, ExternKernelAllocLine)
node = line.node
self._generate_extern_kernel_common(node, node)
def _generate_extern_kernel_out(
self,
line: WrapperLine,
) -> None:
assert isinstance(line, ExternKernelOutLine)
node = line.node
out_node = node.output_view if node.output_view else node
self._generate_extern_kernel_common(node, out_node)
def _generate_extern_kernel_common(
self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode
) -> None:
"""
Generates FX IR from either ExternKernelAlloc or ExternKernelOut.
"""
# Get FX nodes corresponding to the call args.
tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs)
args = tensor_nodes + tuple(kernel.constant_args)
# Get the result buffer.
# Some kernels write to a pre-existing output tensor via the "out" kwarg.
kwargs = kernel.kwargs.copy()
result_buffer: Optional[str] = None
if isinstance(kernel, ir.ExternKernelOut):
kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()]
elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)):
result_buffer = kernel.get_name()
elif isinstance(kernel.layout, ir.NoneLayout):
pass
else:
raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}")
# Look up the kernel function from its name.
kernel_name = kernel.get_kernel_name()
module_name, kernel_name = kernel_name.split(".", 1)
op = globals()[module_name] # E.g. extern_kernels, aten, etc.
for subname in kernel_name.split("."):
op = getattr(op, subname) # E.g. extern_kernels.addmm
fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs)
# Assign the result to the given name.
if result_buffer:
assert "out" not in kwargs, (
f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one."
)
fx_node.name = result_buffer
self.buffer_to_node[result_buffer] = fx_node
arg_tensors = [
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in args
]
# Run the operation to propagate metadata.
fx_node.meta["val"] = op(*arg_tensors, **kwargs)
def _generate_kernel_call(self, line: WrapperLine) -> None:
assert isinstance(line, KernelCallLine)
if not line.triton:
raise NotImplementedError("FX conversion only supports Triton kernels.")
self._generate_triton_call(line)
def _generate_kernel_definition(self, line: WrapperLine) -> None:
assert isinstance(line, KernelDefinitionLine)
# Generate code for the kernel.
kernel_code = PythonWrapperCodegen._format_kernel_definition(
line.kernel_name, line.kernel_body, metadata=line.metadata
)
# Import the module and store the JIT kernel.
tuner = self._import_kernel(kernel_code, line.kernel_name)
wrapped = wrap_triton(tuner.fn)
self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped)
def _generate_symbolic_call_arg(self, line: WrapperLine) -> None:
assert isinstance(line, SymbolicCallArgLine)
# No need for an FX node, as we will pass the arg to kernels via a SymInt.
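
The compile_graph hook above is the extension point for derived backends. A minimal sketch of overriding it (the subclass name is made up, and it assumes the reverted wrapper_fxir module is importable):

from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen  # removed by this revert

class PrintingWrapperFxCodegen(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Inspect the wrapper FX graph, then fall back to the default eager call.
        print(gm.graph)
        return gm.forward

Registered via register_backend_for_device, as in the deleted test, such a subclass would receive every wrapper graph Inductor produces for that device.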

View File

@ -50,7 +50,6 @@ from . import config, ir, metrics
from .codegen.common import (
BackendFeature,
DeviceOpOverrides,
FileBackedGraphModule,
get_backend_features,
get_device_op_overrides,
get_wrapper_codegen_for_device,
@ -116,12 +115,9 @@ if TYPE_CHECKING:
from torch._higher_order_ops.effects import _EffectType
from torch.fx import GraphModule
from torch.fx.graph import Graph
from .codegen.wrapper import PythonWrapperCodegen
from .scheduler import BaseSchedulerNode
CompiledModule = Union[ModuleType, FileBackedGraphModule]
from torch._inductor.codecache import output_code_log
@ -2228,7 +2224,7 @@ class GraphLowering(torch.fx.Interpreter):
# No-op to be patched for unit tests
save_output_code: Optional[Callable[[str], None]] = None
def compile_to_module(self) -> CompiledModule:
def compile_to_module(self) -> ModuleType:
with dynamo_timed(
"GraphLowering.compile_to_module",
phase_name="code_gen",
@ -2237,41 +2233,14 @@ class GraphLowering(torch.fx.Interpreter):
):
return self._compile_to_module()
def _compile_to_module(self) -> CompiledModule:
def _compile_to_module(self) -> ModuleType:
from .codecache import PyCodeCache
# Currently, if we're here, we don't have to worry about the kernel code, which
# is only available in AOTInductor mode.
wrapper_code, _ = (
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
)
if isinstance(wrapper_code, ValueWithLineMap):
mod = self._compile_to_module_lines(wrapper_code)
elif isinstance(wrapper_code, FileBackedGraphModule):
mod = wrapper_code
else:
raise NotImplementedError(
f"Unrecognized wrapper code type: {type(wrapper_code)}"
)
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
assert mod.__file__ is not None
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.info("Output code written to: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def _compile_to_module_lines(
self, wrapper_code: ValueWithLineMap
) -> CompiledModule:
from .codecache import PyCodeCache
if config.triton.autotune_at_compile_time:
tuning_code = (
'"""\n'
@ -2322,7 +2291,17 @@ class GraphLowering(torch.fx.Interpreter):
if config.benchmark_harness and config.profile_bandwidth_output:
# run the inputs code gen to get the bandwidth info
mod.benchmark_compiled_module(times=1, repeat=1)
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
# TODO. Revisit this once the logging API is more mature
assert mod.__file__ is not None
log_module_code(mod.__file__)
log.debug("Output code written to: %s", mod.__file__)
output_code_log.info("Output code written to: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def get_output_names(self) -> list[str]:

View File

@ -65,7 +65,6 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import (
BackendFeature,
CodegenSymbol,
get_scheduling_for_device,
index_prevent_reordering,
)
@ -3424,15 +3423,6 @@ class Layout(OutputSpec):
def get_device(self) -> torch.device:
return self.device
def get_example(self) -> torch.Tensor:
with V.fake_mode:
return torch.empty_strided(
convert_shape_to_symint(self.size),
convert_shape_to_symint(self.stride),
dtype=self.dtype,
device=self.device,
)
def is_contiguous(self) -> bool:
return is_contiguous_strides_for_shape(self.stride, self.size)
@ -3936,7 +3926,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
@ir_dataclass(frozen=False)
class Buffer(IRNode, CodegenSymbol):
class Buffer(IRNode):
# Name is sometimes None; e.g., ForceInPlace, where there isn't
# a meaningful name
name: Optional[str]
@ -3956,11 +3946,6 @@ class Buffer(IRNode, CodegenSymbol):
assert self.name, self
return self.name
def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
if isinstance(self.layout, Layout):
return self.layout.get_example()
raise NotImplementedError(type(self.layout).__name__)
def get_device(self) -> Optional[torch.device]:
return self.get_output_spec().get_device()

View File

@ -85,7 +85,7 @@ class NoTritonConfigsError(RuntimeError):
if TYPE_CHECKING:
from collections.abc import Container, Hashable
from collections.abc import Container, Hashable, Sequence
from torch._guards import CompileId
@ -2564,15 +2564,13 @@ class GridExpr:
inductor_meta: dict[str, Any]
mode: Literal["python", "cpp"] = "python"
prefix: list[str] = dataclasses.field(default_factory=list)
prefix: Sequence[str] = ()
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1
def __post_init__(self) -> None:
assert self.mode in ("python", "cpp")
if self.mode == "python":
self.prefix.append("from torch.utils._sympy.functions import FloorDiv")
def generate(self, meta: dict[str, int]) -> None:
raise NotImplementedError
@ -2585,9 +2583,7 @@ class GridExpr:
if isinstance(numel, int) and isinstance(block, int):
return ceildiv(numel, block) # constant fold
if self.mode == "python":
# Use FloorDiv instead of // so we can get better sympy expressions for
# dynamic shapes.
return f"-FloorDiv(({numel}), -({block}))"
return f"-(({numel}) // -({block}))"
# trick above doesn't work in C++ due to rounding differences
return f"(({numel} + ({block} - 1)) / ({block}))"
@ -2670,16 +2666,12 @@ class Grid3D(GridExpr):
class Grid2DWithYZOverflow(GridExpr):
def generate(self, meta: dict[str, int]) -> None:
self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK"))
self.prefix.extend(
[
self.assign_tmp(
"y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))
),
self.assign_tmp(
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
),
]
)
self.prefix = [
self.assign_tmp("y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK"))),
self.assign_tmp(
"y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid())
),
]
self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_")
self.z_grid = "y_grid_div_"

View File

@ -436,23 +436,6 @@ def convert_shape_to_inductor(
return [sympy.sympify(i) for i in lst]
def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
"""
Like convert_shape_to_symint, but operates on a single expression.
"""
from .virtualized import V
return (
i
if isinstance(i, int)
else (
int(i)
if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
)
)
def convert_shape_to_symint(
lst: Iterable[Union[int, sympy.Expr]],
) -> list[Union[int, torch.SymInt]]:
@ -460,7 +443,20 @@ def convert_shape_to_symint(
Takes a list of shapes from Inductor and converts them into symints (or just
ints if all shapes are static).
"""
return [convert_to_symint(i) for i in lst]
from .virtualized import V
return [
(
i
if isinstance(i, int)
else (
int(i)
if isinstance(i, sympy.Integer)
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
)
)
for i in lst
]
def is_view(op: torch._ops.OpOverload) -> bool: