mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper bounds MUST be floats, and likewise for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide-out-addition-by-gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division.
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return only a float.
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing).

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float). This makes them inconsistent with builtins.min/max, but makes it possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this could be implemented with trunc and a dtype conversion, but a dedicated handler is more convenient for roundtripping in Sympy.
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`.

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2).
* Add support for the new Sympy functions in **torch/utils/_sympy/interp.py** and **torch/utils/_sympy/reference.py**.
  * In particular, in torch.utils._sympy.reference, we have a strong preference NOT to do nontrivial compute; instead, everything in the ops handler should map to a single sympy function.
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency to fix tests here.
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality are currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet.
* Update ValueRanges logic to use the new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c`, which is bad because x / c is now a rational number.
* `_assert_bound_is_rational` is no more; we no longer generate rational bounds.
* Don't intersect non-int value ranges with the `int_range`.
* Support more sympy Functions for guard SYMPY_INTERP.
* Assert the type of value range is consistent with the variable type.

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions.
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions.
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges; instead use sys.maxsize - 1.

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
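The IntTrueDiv precision point can be checked with plain Python. CPython's `int / int` is correctly rounded from the exact quotient, whereas converting each operand to float first rounds away any integer above 2**53 that float64 cannot represent exactly. The operands below are chosen only for illustration:

```python
a = 2**53 + 1  # 9007199254740993: not representable as a float64
b = 3          # a is divisible by 3, so the exact quotient is 3002399751580331

exact = a / b                # int / int: correctly rounded from the exact quotient
naive = float(a) / float(b)  # float(a) rounds to 2**53 before the division happens

print(exact)           # 3002399751580331.0
print(naive)           # 3002399751580330.5
print(exact == naive)  # False
```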
1761 lines
55 KiB
Python
from __future__ import annotations

import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import io
import itertools
import json
import logging
import math
import operator
import os
import platform
import shutil
import sys
import tempfile
import textwrap
import time
import unittest
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
    Union,
    ValuesView,
)

from typing_extensions import Concatenate, ParamSpec
from unittest import mock

import sympy

import torch
import torch._export
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
from torch.autograd.profiler_util import EventList
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

from . import config
from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv

log = logging.getLogger(__name__)

_T = TypeVar("_T")
VarRanges = Dict[sympy.Expr, sympy.Expr]

ALIGNMENT = 16


def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This could be more accurate as it doesn't count CPU side overhead.
    However, this also requires manually excluding irrelevant events, e.g.
    vectorized_elementwise_kernel which is used to fill L2 cache,
    various CUDA events, etc, so could also be fragile.
    """

    fn()
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark
        for i in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            # record time of `fn`
            fn()
        # Record clocks
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        # RuntimeError does not do %-style formatting, so build the message here
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    num_event_per_group = len(filtered_events) // n_repeat
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res


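After profiling, `do_bench_using_profiling` assumes every repeat emitted the same number of CUDA events, with the first one in each repeat apparently being the `cache.zero_()` kernel that flushes L2. It splits the events into `n_repeat` equal groups, discards the first event of each group, and converts the summed device time from microseconds to milliseconds per repeat. The sketch below applies only that bookkeeping to a hypothetical list of raw durations; `per_repeat_ms` is an illustrative name, and the profiler and the `key_averages()` aggregation are left out.

```python
from typing import List


def per_repeat_ms(durations_us: List[float], n_repeat: int) -> float:
    """Average per-repeat time in ms, skipping the first event of every group.

    Assumes events are in launch order and each repeat produced the same
    number of events, the first being the cache-flush kernel.
    """
    if len(durations_us) % n_repeat != 0:
        raise RuntimeError(
            f"cannot split {len(durations_us)} events into {n_repeat} groups"
        )
    per_group = len(durations_us) // n_repeat
    kept = [d for i, d in enumerate(durations_us) if i % per_group != 0]
    return sum(kept) / 1000.0 / n_repeat  # microseconds -> milliseconds


# Two repeats, each emitting [flush, kernel_a, kernel_b]
events = [900.0, 40.0, 60.0, 910.0, 42.0, 58.0]
print(per_repeat_ms(events, n_repeat=2))  # 0.1
```

The divisibility check is what makes the "first event of each group" rule safe. If a repeat emitted an extra or missing event, the groups would shift, and a real kernel could be discarded instead of the flush.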
@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    try:
        from torchvision.ops import roi_align  # noqa: F401

        torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False
    except RuntimeError as e:
        assert "torchvision::nms does not exist" in str(e)
        return False


def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    if device is None:
        return torch.tensor(0.0).device  # default device
    if isinstance(device, str):
        device = torch.device(device)
    if device.type not in ("cpu", "meta") and device.index is None:
        device_interface = get_interface_for_device(device.type)
        return torch.device(device.type, index=device_interface.Worker.current_device())
    return device


def sympy_product(it):
    return functools.reduce(operator.mul, it, sympy.Integer(1))


def sympy_dot(seq1, seq2):
    assert len(seq1) == len(seq2)
    return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))


def unique(it: Iterable[_T]) -> ValuesView[_T]:
    return {id(x): x for x in it}.values()


def ceildiv(
    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
        return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
    # TODO: There is a bug in a call to this function, to repro:
    # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
    # --amp --only YituTechConvBert --dynamic-shapes
    assert isinstance(numer, int) and isinstance(
        denom, int
    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
    return runtime_ceildiv(numer, denom)


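For plain ints, `ceildiv` hands off to `runtime_ceildiv` from `.runtime.runtime_utils`, whose body is not shown here. One common integer-only formulation negates a floor division. The sketch below assumes that formulation (`ceildiv_int` is a hypothetical name) and cross-checks it against `math.ceil` on exact `Fraction`s. It also shows why staying in integer arithmetic matters once values exceed 2**53.

```python
import math
from fractions import Fraction


def ceildiv_int(numer: int, denom: int) -> int:
    """Integer ceiling division: ceil(a / b) == -(a // -b) for any b != 0."""
    # Python's // floors toward negative infinity, so flooring a / -b and
    # negating the result gives the ceiling without going through floats.
    return -(numer // -denom)


# Cross-check against an exact rational reference for every sign combination.
for a in range(-7, 8):
    for b in (-3, -2, -1, 1, 2, 3):
        assert ceildiv_int(a, b) == math.ceil(Fraction(a, b))

big = 2**60 + 1
print(ceildiv_int(big, 2))  # 576460752303423489 (exact)
print(math.ceil(big / 2))   # 576460752303423488 (float quotient rounded to 2**59)
```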
def _type_of(key):
    # Use this function to avoid depending on Triton during codegen.
    # Refer to Triton implementation here:
    # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
    # `None` is nullptr. Implicitly convert to *i8.
    if key is None:
        return "*i8"
    dtype_str = str(key).split(".")[-1]
    tys = {
        "bool": "i1",
        "float8e4nv": "fp8e4nv",
        "float8e5": "fp8e5",
        "float8e4b15": "fp8e4b15",
        "float8e4b15x4": "fp8e4b15x4",
        "float8_e4m3fn": "fp8e4nv",
        "float8_e5m2": "fp8e5",
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
        "float64": "fp64",
        "int8": "i8",
        "int16": "i16",
        "int32": "i32",
        "int64": "i64",
        "uint8": "u8",
        "uint16": "u16",
        "uint32": "u32",
        "uint64": "u64",
    }
    # reinterpret can create triton type
    for v in list(tys.values()):
        tys[v] = v
    return key if isinstance(key, str) else f"*{tys[dtype_str]}"


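`_type_of` only relies on `str(key)`, so its dtype-to-Triton mapping can be exercised without importing torch or Triton. In the sketch below, `FakeDtype` is a hypothetical stand-in whose `str()` mimics a torch dtype such as `torch.float16`, and `TYS` is a subset of the table above. `pointer_type_of` is an illustrative name.

```python
from typing import Optional, Union

# Subset of the dtype-name -> Triton element type table
TYS = {"bool": "i1", "float16": "fp16", "bfloat16": "bf16", "float32": "fp32", "int64": "i64"}


class FakeDtype:
    """Hypothetical stand-in whose str() looks like a torch dtype, e.g. 'torch.float16'."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return f"torch.{self.name}"


def pointer_type_of(key: Optional[Union[str, FakeDtype]]) -> str:
    if key is None:
        return "*i8"  # nullptr is passed as an i8 pointer
    if isinstance(key, str):
        return key  # strings are assumed to already be Triton type names
    dtype_str = str(key).split(".")[-1]  # "torch.float16" -> "float16"
    return f"*{TYS[dtype_str]}"


print(pointer_type_of(FakeDtype("float16")))  # *fp16
print(pointer_type_of(None))                  # *i8
print(pointer_type_of("i32"))                 # i32
```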
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
    """
    Gets the shape and stride of a tensor. For non-symbolic tensors, this is
    trivial. But for symbolic tensors, we need to map from SymIntNode into
    sympy.Expr.
    """
    return [
        i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
    ]


def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
    """
    Takes a list of shapes from Inductor and converts them into symints (or just
    ints if all shapes are static).
    """
    from .virtualized import V

    return [
        i
        if isinstance(i, int)
        else int(i)
        if isinstance(i, sympy.Integer)
        else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
        for i in lst
    ]


def is_view(op: torch._ops.OpOverload):
    """
    Does this op overload have aliasing
    """
    assert isinstance(op, torch._ops.OpOverload)
    return any(a.alias_info is not None for a in op._schema.arguments)


def is_pointwise_use(use):
    if not use.op == "call_function":
        return False

    if not (
        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
    ):
        return False

    if use.target is operator.getitem or is_view(use.target):
        return all(is_pointwise_use(u) for u in use.users)

    return torch.Tag.pointwise in use.target.tags


def gen_gm_and_inputs(target, args, kwargs):
    g = torch.fx.Graph()
    g_args = []
    a_args = []
    for n, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            g_args.append(g.placeholder(f"arg{n}"))
            a_args.append(arg)
        else:
            g_args.append(arg)
    assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
    node = g.call_function(target, tuple(g_args), kwargs)
    if (
        len(target._schema.returns) == 1
        and str(target._schema.returns[0].type) == "Tensor"
    ):
        node = (node,)
    g.output(node)

    gm = torch.fx.GraphModule({}, g)
    return gm, a_args


def synchronize(device: str = "cuda"):
    if device == "cpu":
        return
    device_interface = get_interface_for_device(device)
    if device_interface.is_available():
        device_interface.synchronize()


def timed(
    model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
) -> float:
    synchronize(device)
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize(device)
    t1 = time.perf_counter()
    # GC the result after timing
    assert result is not None  # type: ignore[possibly-undefined]
    return t1 - t0


def print_performance(
    fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
):
    timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)])
    took = torch.median(timings) / times
    print(f"{took / baseline:.6f}")
    return took


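`print_performance` calls `timed` `repeat` times, takes the median of those batch timings, and divides by `times`. It prints the value relative to `baseline` but returns the unscaled per-call time. The stdlib-only sketch below reproduces that arithmetic for CPU code. `timed_cpu` and `median_per_call` are hypothetical names, and unlike `timed` the sketch neither synchronizes a device nor reseeds torch.

```python
import statistics
import time
from typing import Any, Callable, Sequence


def timed_cpu(fn: Callable[..., Any], args: Sequence[Any], times: int = 1) -> float:
    """Wall-clock seconds for `times` back-to-back calls; no device synchronization."""
    result = None
    t0 = time.perf_counter()
    for _ in range(times):
        result = fn(*args)
    t1 = time.perf_counter()
    assert result is not None  # keep the last result alive until timing has ended
    return t1 - t0


def median_per_call(fn, args=(), times=10, repeat=10, baseline=1.0) -> float:
    # The median over `repeat` batches is less sensitive than the mean to a few
    # slow outliers (e.g. a GC pause); dividing by `times` gives seconds per call.
    timings = [timed_cpu(fn, args, times) for _ in range(repeat)]
    took = statistics.median(timings) / times
    print(f"{took / baseline:.6f}")  # printed relative to baseline, like print_performance
    return took


per_call = median_per_call(sum, (range(10_000),), times=5, repeat=5)
```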
def precompute_method(obj: Any, method: str):
    """Replace obj.method() with a new method that returns a precomputed constant."""
    result = getattr(obj, method)()
    setattr(obj, method, lambda: result)


def precompute_methods(obj: Any, methods: List[str]):
    """Replace each method with a new method that returns a precomputed constant."""
    for method in methods:
        precompute_method(obj, method)


def cmp(a, b) -> int:
    return int(a > b) - int(a < b)


def pad_listlike(x, size):
    if len(x) == 1:
        return type(x)([x[0]]) * size
    else:
        return x


# Used to ensure that iterating over a set is deterministic
def tuple_sorted(x):
    if len(x) == 0:
        return []

    def sort_func(elem):
        if isinstance(elem, str):
            return elem
        else:
            # We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
            # but we are not able to do isinstance assert because of circular dependency
            return elem.get_name()

    return sorted(x, key=sort_func)


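`tuple_sorted` exists because iterating over a set of objects that use the default identity-based hash can yield a different order from run to run. Sorting by a stable string key (the string itself, or `get_name()` for scheduler nodes) makes the order reproducible. A sketch with a hypothetical `FakeNode` that only provides `get_name()`; `sorted_by_name` is likewise an illustrative name:

```python
from typing import Iterable, List, Union


class FakeNode:
    """Hypothetical stand-in for a scheduler node: only get_name() matters here."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


def sorted_by_name(items: Iterable[Union[str, FakeNode]]) -> List[Union[str, FakeNode]]:
    # Strings sort by themselves; node-like objects sort by get_name().
    # Every key is a str, so mixed inputs never compare across types.
    return sorted(items, key=lambda e: e if isinstance(e, str) else e.get_name())


nodes = {FakeNode("buf3"), FakeNode("buf1"), FakeNode("buf2")}
print([n.get_name() for n in sorted_by_name(nodes)])  # ['buf1', 'buf2', 'buf3']
```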
P = ParamSpec("P")
RV = TypeVar("RV", covariant=True)


class CachedMethod(Protocol, Generic[P, RV]):
    @staticmethod
    def clear_cache(self) -> None:
        ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
        ...


# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    key = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        if not hasattr(self, key):
            setattr(self, key, fn(self))
        return getattr(self, key)

    def clear_cache(self):
        if hasattr(self, key):
            delattr(self, key)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]


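`cache_on_self` memoizes a zero-argument method by storing the result on the instance under `__<name>_cache`, so the cached value lives and dies with the object. By contrast, a `functools.lru_cache` on a method would keep every `self` alive in one shared cache. The sketch below repeats the decorator without the typing scaffolding and exercises it on a hypothetical `Kernel` class:

```python
import functools


def cache_on_self(fn):
    """Memoize a zero-argument method by stashing its result on the instance."""
    key = f"__{fn.__name__}_cache"  # set via setattr, so no name mangling applies

    @functools.wraps(fn)
    def wrapper(self):
        if not hasattr(self, key):
            setattr(self, key, fn(self))
        return getattr(self, key)

    def clear_cache(self):
        if hasattr(self, key):
            delattr(self, key)

    wrapper.clear_cache = clear_cache
    return wrapper


class Kernel:  # hypothetical class used only for illustration
    def __init__(self) -> None:
        self.calls = 0

    @cache_on_self
    def expensive_name(self) -> str:
        self.calls += 1
        return f"kernel_{self.calls}"


k = Kernel()
assert k.expensive_name() == k.expensive_name() == "kernel_1"
Kernel.expensive_name.clear_cache(k)  # clear_cache takes the instance explicitly
assert k.expensive_name() == "kernel_2"
```

Because `clear_cache` is attached as a plain function rather than a bound method, callers must pass the instance themselves. The `@staticmethod def clear_cache(self)` declaration on `CachedMethod` appears to model exactly that for type checkers.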
def aggregate_origins(node_schedule):
    from . import ir

    if isinstance(node_schedule, list):
        return functools.reduce(
            operator.or_,
            [
                node.node.origins
                for node in node_schedule
                if hasattr(node, "node") and node.node
            ],
            set(),
        )
    elif isinstance(node_schedule, ir.ExternKernel):
        return node_schedule.origins
    else:
        return set()


def get_fused_kernel_name(node_schedule, descriptive_names):
    all_origins = aggregate_origins(node_schedule)
    if descriptive_names == "original_aten":
        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
        sources = [
            origin.meta["original_aten"]._overloadpacket.__name__
            for origin in all_origins
            if origin.op == "call_function"
            and "original_aten" in origin.meta
            and origin.meta["original_aten"] is not None
        ]
        sources = sorted(set(sources))
    elif descriptive_names == "torch":
        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
        sources = []
        for origin in all_origins:
            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
                source_fn = origin.meta["source_fn_stack"][-1]
                if isinstance(source_fn[1], str):
                    sources.append(source_fn[1])
                else:
                    sources.append(source_fn[1].__name__)
        sources = sorted(set(sources))
    elif descriptive_names == "inductor_node":
        sources = [
            origin.name for origin in all_origins if origin.op == "call_function"
        ]
    else:
        raise NotImplementedError
    sources = sources
    return "_".join(["fused"] + sources)


def get_kernel_metadata(node_schedule, wrapper):
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    from_node_dict = collections.defaultdict(list)
    original_aten_dict = collections.defaultdict(list)
    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            key = str(node.meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0][0]
            from_node_dict[key].append(node.name)
    metadata = (
        f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], "
        f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]"
    )
    # trace back to original node here
    detailed_metadata = []
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )
    return metadata, "\n".join(detailed_metadata)


def dominated_nodes(
    initial_queue: Iterable[torch.fx.Node], skip_filter=None
) -> Set[torch.fx.Node]:
    """Returns the set of nodes whose values depend on those within initial_queue"""
    initial_queue = list(initial_queue)
    dominated_set = set(initial_queue)

    while initial_queue:
        node = initial_queue.pop()
        for user in node.users:
            if skip_filter and skip_filter(user):
                continue
            if user not in dominated_set:
                dominated_set.add(user)
                initial_queue.append(user)

    return dominated_set


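`dominated_nodes` is a worklist traversal. Starting from the initial nodes, it keeps following `users` edges and collects everything reachable. Nodes rejected by `skip_filter` are neither collected nor traversed through. Despite the name, this is forward reachability rather than dominance in the compiler (dominator-tree) sense. The sketch below runs on a hypothetical string-keyed graph, where `users` maps each node to the nodes that consume it and `dominated` is an illustrative name:

```python
from typing import Callable, Dict, Iterable, List, Optional, Set


def dominated(
    initial: Iterable[str],
    users: Dict[str, List[str]],
    skip_filter: Optional[Callable[[str], bool]] = None,
) -> Set[str]:
    """All nodes reachable from `initial` via users edges, including `initial` itself."""
    queue = list(initial)
    seen = set(queue)
    while queue:
        node = queue.pop()  # LIFO, i.e. depth-first; the order does not change the result
        for user in users.get(node, []):
            if skip_filter and skip_filter(user):
                continue  # skipped nodes are not included and not traversed through
            if user not in seen:
                seen.add(user)
                queue.append(user)
    return seen


# a -> b -> d and a -> c -> d; e also feeds d but is not reachable from a
graph = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "e": ["d"]}
print(sorted(dominated(["a"], graph)))                                # ['a', 'b', 'c', 'd']
print(sorted(dominated(["a"], graph, skip_filter=lambda n: n == "b")))  # ['a', 'c', 'd']
```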
def gather_origins(args, kwargs):
    import itertools

    from . import ir

    def is_unrealized_node(n):
        if isinstance(n, ir.TensorBox):
            return is_unrealized_node(n.data)
        if isinstance(n, ir.StorageBox):
            return is_unrealized_node(n.data)
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))


def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow; this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification, so don't
    use this for final codegen.
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(sympy_str, expr.args))

    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
    return str(expr)


def get_bounds_index_expr(index):
    from .virtualized import V

    # If this expression does not come from an FX node, we compute its bounds
    if (
        config.compute_all_bounds
        and (fx_node := getattr(V.interpreter, "current_node", None))
        and fx_node.target != "index_expr"
    ):
        return bound_sympy(index)
    else:
        return ValueRanges.unknown()


def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert prefix != SymT.SIZE
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return make_symbol(prefix, idx, integer=True, nonnegative=True)


def generate_assert(check):
    return (check or config.debug_index_asserts) and config.assert_indirect_indexing


def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Used to generate an integer-nonnegative symbol.
    """
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert name[0] != "s"
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)


def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    When the passed replacement symbol v is a string, it is converted to a symbol with name v
    that has the same integer and nonnegative properties as the replaced expression.
    """

    def to_symbol(replaced, replacement):
        assert isinstance(replaced, sympy.Expr)
        if isinstance(replacement, str):
            return sympy.Symbol(
                replacement,
                integer=replaced.is_integer,  # type: ignore[attr-defined]
                nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
            )
        else:
            return replacement

    # xreplace is faster than subs, but is way more picky
    return sympy.sympify(expr).xreplace(
        {k: to_symbol(k, v) for k, v in replacements.items()}
    )


def is_symbolic(a: Any) -> bool:
    return isinstance(a, torch.SymInt) or (
        isinstance(a, torch.Tensor)
        and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
    )


def any_is_symbolic(*args: Any) -> bool:
    return any(is_symbolic(a) for a in args)


def get_first_incompatible_cudagraph_node(gm):
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    forbidden_set = {
        "aten._fused_moving_avg_obs_fq_helper.default",
        "aten._fused_moving_avg_obs_fq_helper_functional.default",
        "aten.multinomial.default",
        "fbgemm.dense_to_jagged.default",
        "fbgemm.jagged_to_padded_dense.default",
        "run_and_save_rng_state",
        "run_with_rng_state",
        "aten._local_scalar_dense",
        # Technically, it's not necessary to ban this, because an
        # assert_scalar with constant arguments can be validly run
        # with CUDA graphs, but the operator is also pointless with
        # constant arguments, so might as well ban
        "aten._assert_scalar",
    }
    if torch.are_deterministic_algorithms_enabled():
        forbidden_set.update(
            {
                "aten._unsafe_index_put.default",
                "aten.index_put.default",
                "aten.index_put_.default",
                "aten.scatter.src",
                "aten.scatter.reduce",
                "aten.scatter.value_reduce",
                "aten.scatter_add_",
                "aten.scatter_add.default",
                "aten.scatter_reduce.two",
                "aten.scatter_reduce_.two",
                "aten.scatter_reduce.two_out",
            }
        )
    for node in gm.graph.nodes:
        if str(node.target) in forbidden_set:
            return node
        if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
            return node
    return None


def has_incompatible_cudagraph_ops(gm):
    return get_first_incompatible_cudagraph_node(gm) is not None


def output_node(gm: torch.fx.GraphModule):
    """Get the output node from an FX graph"""
    last_node = next(iter(reversed(gm.graph.nodes)))
    assert last_node.op == "output"
    return last_node


_registered_caches: List[Any] = []


def clear_on_fresh_inductor_cache(obj: Any):
    """
    Use this decorator to register any caches that should be cache_clear'd
    with fresh_inductor_cache().
    """
    if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
        raise AttributeError(f"{obj} does not have a cache_clear method")

    _registered_caches.append(obj)
    return obj


def clear_inductor_caches():
    """
    Clear all registered caches.
    """
    for obj in _registered_caches:
        obj.cache_clear()


@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
    Context manager that provides a clean tmp cachedir for inductor.

    Optionally, pass a dict as 'cache_entries' to collect the filenames and sizes
    of the entries generated with this cache instance.
    """
    clear_inductor_caches()

    inductor_cache_dir = tempfile.mkdtemp()
    try:
        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        ):
            triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
                yield
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        shutil.rmtree(inductor_cache_dir)
    except Exception:
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_inductor_caches()


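The core of `fresh_inductor_cache` is a temporary directory exposed through an environment variable with `unittest.mock.patch.dict`, plus an optional size inventory taken before cleanup. A stdlib-only sketch of that pattern is below. `MY_CACHE_DIR` and `fresh_cache_dir` are hypothetical names standing in for `TORCHINDUCTOR_CACHE_DIR`/`TRITON_CACHE_DIR`, and the registered-cache clearing is left out.

```python
import contextlib
import os
import shutil
import tempfile
from unittest import mock


@contextlib.contextmanager
def fresh_cache_dir(env_var: str = "MY_CACHE_DIR", cache_entries=None):
    """Point `env_var` at a throwaway directory for the duration of the block."""
    cache_dir = tempfile.mkdtemp()
    try:
        # patch.dict restores os.environ on exit, removing keys it added.
        with mock.patch.dict(os.environ, {env_var: cache_dir}):
            yield cache_dir
            if isinstance(cache_entries, dict):
                cache_entries.update(
                    {
                        f: os.path.getsize(os.path.join(cache_dir, f))
                        for f in os.listdir(cache_dir)
                    }
                )
        shutil.rmtree(cache_dir)  # reached only on success
    except Exception:
        # On failure the directory is left behind for debugging.
        print(f"on error, temporary cache dir kept at {cache_dir}")
        raise


entries = {}
with fresh_cache_dir(cache_entries=entries) as d:
    path = os.path.join(os.environ["MY_CACHE_DIR"], "kernel.bin")
    with open(path, "wb") as fh:
        fh.write(b"\x00" * 16)
print(entries)  # {'kernel.bin': 16}
```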
def argsort(seq) -> List[int]:
    # preserve original order for equal strides
    getter = seq.__getitem__
    a_r = range(len(seq))
    return list(reversed(sorted(a_r, key=getter, reverse=True)))  # noqa: C413


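`argsort` sorts indices in descending order with `reverse=True` (which Python keeps stable) and then reverses the list. The result is ascending, but ties come out in reverse index order. Among equal strides, the later dimension is therefore listed first, as it would be for the inner dimensions of a contiguous layout; that appears to be what the "preserve original order" comment refers to. A standalone comparison with an ordinary stable argsort (`argsort_plain` is an illustrative name):

```python
from typing import List, Sequence


def argsort(seq: Sequence[int]) -> List[int]:
    # Same expression as above: a stable descending sort, then reversed.
    return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True)))


def argsort_plain(seq: Sequence[int]) -> List[int]:
    # Ordinary stable ascending argsort: ties keep their original index order.
    return sorted(range(len(seq)), key=seq.__getitem__)


strides = [3, 1, 1]
print(argsort(strides))        # [2, 1, 0]
print(argsort_plain(strides))  # [1, 2, 0]
```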
@functools.lru_cache(8)
|
|
def get_dtype_size(dtype):
|
|
return torch.empty((), dtype=dtype).element_size()
|
|
|
|
|
|
class LineContext(NamedTuple):
|
|
context: Any
|
|
|
|
|
|
class IndentedBuffer:
|
|
tabwidth = 4
|
|
|
|
def __init__(self, initial_indent=0):
|
|
self._lines = []
|
|
self._indent = initial_indent
|
|
|
|
def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
|
|
buf = StringIO()
|
|
p = 1
|
|
linemap = []
|
|
for line in self._lines:
|
|
if isinstance(line, DeferredLineBase):
|
|
line = line()
|
|
if line is None:
|
|
continue
|
|
elif isinstance(line, LineContext):
|
|
linemap.append((p, line.context))
|
|
continue
|
|
assert isinstance(line, str)
|
|
buf.write(line)
|
|
buf.write("\n")
|
|
p += 1 + line.count("\n")
|
|
return buf.getvalue(), linemap
|
|
|
|
def getvalue(self) -> str:
|
|
v, _ = self.getvaluewithlinemap()
|
|
return v
|
|
|
|
def getrawvalue(self) -> str:
|
|
buf = StringIO()
|
|
for line in self._lines:
|
|
if isinstance(line, DeferredLineBase):
|
|
line = line()
|
|
if line is None:
|
|
continue
|
|
elif isinstance(line, LineContext):
|
|
continue
|
|
assert isinstance(line, str)
|
|
# backslash implies line continuation
|
|
if line.endswith("\\"):
|
|
buf.write(line[:-1])
|
|
else:
|
|
buf.write(line)
|
|
buf.write("\n")
|
|
return buf.getvalue()
|
|
|
|
def clear(self):
|
|
self._lines.clear()
|
|
|
|
    def __bool__(self):
        return bool(self._lines)

    def prefix(self):
        return " " * (self._indent * self.tabwidth)

    def newline(self):
        self.writeline("\n")

    def writeline(self, line):
        if isinstance(line, LineContext):
            self._lines.append(line)
        elif isinstance(line, DeferredLineBase):
            self._lines.append(line.with_prefix(self.prefix()))
        elif line.strip():
            self._lines.append(f"{self.prefix()}{line}")
        else:
            self._lines.append("")

    def writelines(self, lines):
        for line in lines:
            self.writeline(line)

    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            self._indent += offset
            try:
                yield
            finally:
                self._indent -= offset

        return ctx()

    def do_indent(self, offset=1):
        self._indent += offset

    def do_unindent(self, offset=1):
        self._indent -= offset

    def splice(self, other_code, strip=False):
        if isinstance(other_code, IndentedBuffer):
            dedent = float("inf")
            for line in other_code._lines:
                if not isinstance(line, LineContext) and line:
                    dedent = min(dedent, len(line) - len(line.lstrip()))
            if math.isinf(dedent):
                dedent = 0
            for line in other_code._lines:
                if isinstance(line, LineContext):
                    self._lines.append(line)
                else:
                    IndentedBuffer.writeline(self, line[int(dedent) :])
        else:
            other_code = textwrap.dedent(other_code)
            if strip:
                other_code = other_code.lstrip()
            if not other_code:
                return
            other_code = other_code.rstrip()
            for line in other_code.split("\n"):
                self.writeline(line)

    def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
        res = IndentedBuffer(initial_indent=self._indent)
        res._lines = [func(line) for line in self._lines]
        return res

    def __repr__(self):
        return f"{type(self)}({self.getvalue()})"

    def __add__(self, other):
        assert self._indent == other._indent
        res = IndentedBuffer(initial_indent=self._indent)
        res.writelines(self._lines)
        res.writelines(other._lines)
        return res


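# Illustrative sketch only (not used by Inductor): the common-indent stripping
# that IndentedBuffer.splice applies when splicing another IndentedBuffer, shown
# on a plain list of strings. Empty lines are ignored when computing the shared
# indent, and if there are no non-empty lines the dedent falls back to 0.
# The helper name `_sketch_common_dedent` is hypothetical.
import math


def _sketch_common_dedent(lines: "list[str]") -> "list[str]":
    dedent = float("inf")
    for line in lines:
        if line:
            # Leading whitespace of this non-empty line.
            dedent = min(dedent, len(line) - len(line.lstrip()))
    if math.isinf(dedent):
        # No non-empty lines: nothing to strip.
        dedent = 0
    return [line[int(dedent) :] for line in lines]


# Usage: _sketch_common_dedent(["    a", "        b", ""]) keeps the relative
# indent of "b" while removing the four spaces shared by all non-empty lines.

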
class FakeIndentedBuffer(IndentedBuffer):
    def __init__(self):
        super().__init__()

    def __getattribute__(self, name):
        if name == "__class__":  # Allow access to the class attribute
            return object.__getattribute__(self, name)
        raise RuntimeError(
            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer "
            "is currently used on TritonTemplateKernel to prevent actual "
            "writes to the body without explicitly specifying the body with "
            "`TritonTemplateKernel.set_subgraph_body(name)`"
        )


@contextlib.contextmanager
def restore_stdout_stderr(initial_stdout, initial_stderr):
    try:
        yield
    finally:
        sys.stdout = initial_stdout
        sys.stderr = initial_stderr


class DeferredLineBase:
    """A line that can be 'unwritten' at a later time"""

    def __init__(self, line):
        if not line.strip():
            line = ""
        self.line = line

    def __call__(self) -> Optional[str]:
        """Returns either self.line or None to indicate the line has been 'unwritten'"""
        raise NotImplementedError

    def _new_line(self, line: str) -> DeferredLineBase:
        """Returns a new deferred line with the same condition"""
        raise NotImplementedError

    def with_prefix(self, prefix):
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)


@functools.lru_cache(None)
def is_big_gpu(index) -> bool:
    min_sms = 68  # 3080
    avail_sms = torch.cuda.get_device_properties(index).multi_processor_count
    if avail_sms < min_sms:
        log.warning(
            "Not enough SMs to use max_autotune_gemm mode",
            extra={"min_sms": min_sms, "avail_sms": avail_sms},
        )
        return False
    return True


def use_max_autotune() -> bool:
    return (
        config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache
    )


def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
    return (
        use_max_autotune()
        and layout.device.type == "cuda"
        and layout.dtype in allowed_layout_dtypes
        and is_big_gpu(layout.device.index or 0)
    )


def _use_autotune_backend(backend: str) -> bool:
    return backend.upper() in [
        x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
    ]


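# Illustrative sketch only (not used by Inductor): how _use_autotune_backend
# matches a backend name against a comma-separated list such as
# config.max_autotune_gemm_backends. Matching is case-insensitive and ignores
# whitespace around each entry. `_sketch_backend_enabled` is a hypothetical
# helper that takes the list string directly instead of reading config.


def _sketch_backend_enabled(backend: str, backends_csv: str) -> bool:
    enabled = [x.strip() for x in backends_csv.upper().split(",")]
    return backend.upper() in enabled

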
def use_triton_template(layout, *, enable_int32=False):
    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    if enable_int32:
        layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
    return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
        "TRITON"
    )


def use_cutlass_template(layout, m, n, k):
    from .virtualized import V

    gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
    if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
        return False
    from .codegen.cuda.cutlass_utils import try_import_cutlass

    # Do not use cutlass template on ROCm
    if torch.version.hip:
        return False

    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
    res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
        "CUTLASS"
    )

    if res:
        if not try_import_cutlass():
            log.warning(
                "Failed to import CUTLASS lib. Please check whether "
                "_inductor.config.cuda.cutlass_dir is set correctly. "
                "Skipping CUTLASS backend for now."
            )
            return False
    return res


def _use_template_for_cpu(layout):
    return use_max_autotune() and layout.device.type == "cpu"


def use_cpp_packed_gemm_template(layout, mat1, mat2):
    from . import ir
    from .codegen.cpp_micro_gemm import create_micro_gemm
    from .kernel.mm_common import mm_args

    if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
        return False

    if not config.cpp.weight_prepack:
        return False

    layout_dtypes = [torch.float32]
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2)
    # TODO(jgong5): support dynamic shapes for n or k
    if has_free_symbols((n, k)):
        return False
    if isinstance(mat2, ir.BaseView):
        mat2 = mat2.unwrap_view()
    micro_gemm = create_micro_gemm(
        "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads()
    )
    # TODO(jgong5): support n % n_block_size != 0
    return (
        layout.dtype in layout_dtypes
        and micro_gemm is not None
        and n % micro_gemm.register_blocking[1] == 0
        and mat1.get_stride()[-1] == 1  # TODO(jgong5): support transposed input
        and isinstance(mat2, ir.StorageBox)
        and mat2.is_module_buffer()
    )


def use_aten_gemm_kernels():
    return not use_max_autotune() or _use_autotune_backend("ATEN")


class DebugDirManager:
    counter = itertools.count(0)
    prev_debug_name: str

    def __init__(self):
        self.id = next(DebugDirManager.counter)

    def __enter__(self):
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name

    def __exit__(self, *args):
        shutil.rmtree(self.new_name)
        torch._dynamo.config.debug_dir_root = self.prev_debug_name


def run_and_get_code(fn, *args, **kwargs):
    from .graph import GraphLowering

    compile_to_module = GraphLowering.compile_to_module
    source_codes: List[str] = []

    def patched_compile_to_module(self):
        mod = compile_to_module(self)
        with open(mod.__file__) as f:
            source_codes.append(f.read())
        return mod

    # If FX code caching is enabled, a hit prevents getting the code.
    with config.patch({"fx_graph_cache": False}):
        with mock.patch.object(
            GraphLowering, "compile_to_module", patched_compile_to_module
        ):
            torch._dynamo.reset()
            result = fn(*args, **kwargs)
    return result, source_codes


def get_code(fn, *args, **kwargs):
    """Get the inductor-generated code, but skip any actual compilation or running."""
    from .graph import GraphLowering

    source_codes: List[str] = []

    def patched_compile_to_module(self: GraphLowering):
        class DummyModule:
            """This is empty to replace the generated triton module"""

            def __init__(self):
                pass

            def call(self, *args, **kwargs):
                # Don't do anything when called
                pass

        code, _ = (
            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
        )
        # Skip all the actual compiling.
        source_codes.append(code)
        return DummyModule()

    # If FX code caching is enabled, a hit prevents getting the code.
    with config.patch({"fx_graph_cache": False}):
        with mock.patch.object(
            GraphLowering, "compile_to_module", patched_compile_to_module
        ):
            torch._dynamo.reset()
            # Note the return here is None
            _ = fn(*args, **kwargs)

    return source_codes


def get_triton_code(fn, *args, **kwargs):
    source_codes = get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs got {len(source_codes)}"
    return source_codes[0]


def run_and_get_triton_code(fn, *args, **kwargs):
    _, source_codes = run_and_get_code(fn, *args, **kwargs)
    # Can have two outputs if backwards was eagerly compiled
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs got {len(source_codes)}"
    return source_codes[0]


@contextlib.contextmanager
def override_lowering(aten_op, override_fn):
    """
    Override the lowering of aten_op with override_fn.
    The first argument of override_fn is the original lowering fn.
    """
    from torch._inductor import lowering

    orig_fn = lowering.lowerings[aten_op]
    try:
        lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
        yield
    finally:
        lowering.lowerings[aten_op] = orig_fn


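# Illustrative sketch only (not used by Inductor): the swap-and-restore pattern
# behind override_lowering, applied to a plain dict standing in for
# lowering.lowerings. The override receives the original function as its first
# argument via functools.partial, and the original entry is restored even if the
# body of the `with` block raises. `_sketch_override_entry` and the dict registry
# are hypothetical.
import contextlib
import functools


@contextlib.contextmanager
def _sketch_override_entry(registry, key, override_fn):
    orig_fn = registry[key]
    try:
        registry[key] = functools.partial(override_fn, orig_fn)
        yield
    finally:
        registry[key] = orig_fn

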
def add_scheduler_init_hook(pre_fn, post_fn=None):
    """
    Add hook functions to be called at the beginning and end of Scheduler.__init__.
    Used for unit tests.
    """
    from torch._inductor.scheduler import Scheduler

    orig_fn = Scheduler.__init__

    def wrapper(scheduler, nodes):
        pre_fn(scheduler, nodes)
        out = orig_fn(scheduler, nodes)
        if post_fn:
            post_fn(scheduler, nodes)
        return out

    return unittest.mock.patch.object(Scheduler, "__init__", wrapper)


def developer_warning(msg):
    """
    Warnings that will be actionable for PyTorch developers, but not
    end users. Allows us to easily disable them in stable releases but
    keep them on for nightly builds.
    """
    if config.developer_warnings:
        log.warning(msg)
    else:
        log.info(msg)


def get_benchmark_name():
    """
    An experimental API used only when config.benchmark_kernel is true.

    The benchmark name is only available at codegen time, so we cannot
    call this directly in benchmark_all_kernels, which runs after codegen.

    The function assumes the argument after --only is the benchmark name.
    It works for torchbench.py/huggingface.py/timm_models.py, but for ad-hoc
    scripts this function may return None.

    There are 2 flavors of the --only argument we need to handle:
    1. --only model_name
    2. --only=model_name
    """
    try:
        idx = sys.argv.index("--only")
        if (
            idx + 1 < len(sys.argv)
            and len(sys.argv[idx + 1]) > 0
            and sys.argv[idx + 1][0] != "-"
        ):
            return sys.argv[idx + 1]
    except ValueError:
        pass

    for arg in sys.argv:
        if arg.startswith("--only="):
            return arg[len("--only=") :]
    return None


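# Illustrative sketch only (not used by Inductor): the --only parsing done by
# get_benchmark_name, rewritten to take an explicit argv list so it can be
# exercised without touching sys.argv. `_sketch_parse_only_arg` is hypothetical.


def _sketch_parse_only_arg(argv):
    if "--only" in argv:
        idx = argv.index("--only")
        # Flavor 1: "--only model_name"; the value must not look like a flag.
        if idx + 1 < len(argv) and argv[idx + 1] and argv[idx + 1][0] != "-":
            return argv[idx + 1]
    # Flavor 2: "--only=model_name".
    for arg in argv:
        if arg.startswith("--only="):
            return arg[len("--only=") :]
    return None

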
def is_ones(items):
    return all(x == 1 for x in items)


def is_zeros(items):
    return all(x == 0 for x in items)


def is_cpu_device(inputs):
    return all(
        item.device == torch.device("cpu")
        for item in inputs
        if isinstance(item, torch.Tensor)
    )


def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
    assert isinstance(
        val, sympy.Expr
    ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
    if val.is_integer:  # type: ignore[attr-defined]
        return torch.int64
    else:
        return torch.float64


@contextlib.contextmanager
def maybe_profile(should_profile, *args, **kwargs):
    if should_profile:
        with torch.profiler.profile(*args, **kwargs) as p:
            yield p
    else:
        yield


def parallel_num_threads():
    threads = config.cpp.threads
    if threads < 1:
        threads = torch.get_num_threads()
    return threads


@functools.lru_cache(None)
def get_device_tflops(dtype):
    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

    assert dtype in (torch.float16, torch.bfloat16, torch.float32)

    if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
        # Triton API change in https://github.com/openai/triton/pull/2293
        from torch._utils_internal import max_clock_rate

        sm_clock = max_clock_rate()
        if dtype in (torch.float16, torch.bfloat16):
            return get_max_tensorcore_tflops(dtype, sm_clock)

        if torch.backends.cuda.matmul.allow_tf32:
            return get_max_tensorcore_tflops(torch.float32, sm_clock)
        else:
            return get_max_simd_tflops(torch.float32, sm_clock)
    else:
        if dtype in (torch.float16, torch.bfloat16):
            return get_max_tensorcore_tflops(dtype)

        if torch.backends.cuda.matmul.allow_tf32:
            return get_max_tensorcore_tflops(torch.float32)
        else:
            return get_max_simd_tflops(torch.float32)


@functools.lru_cache(None)
def get_gpu_dram_gbps():
    from triton.testing import get_dram_gbps

    return get_dram_gbps()


def get_gpu_shared_memory():
    from triton.runtime import driver

    return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)


def is_welford_reduction(reduction_type):
    return reduction_type.startswith("welford")


def reduction_num_outputs(reduction_type):
    return 3 if is_welford_reduction(reduction_type) else 1


def is_linux() -> bool:
    return platform.system() == "Linux"


def has_free_symbols(itr: Iterable[Any]):
    return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)


def is_dynamic(*args):
    from . import ir

    for t in args:
        if isinstance(t, ir.TensorBox):
            if has_free_symbols(t.data.get_size()) or (
                hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
            ):
                return True
        elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
            assert hasattr(t, "get_size") and hasattr(t, "get_stride")
            if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
                return True
        elif not isinstance(t, ir.IRNode):
            continue
        else:
            raise TypeError(f"unexpected type for is_dynamic {type(t)}")

    return False


# Placeholder strings used in triton codegen.
class Placeholder(enum.Enum):
    # The placeholder for the actual name of a triton kernel.
    # e.g. for "def triton_" it would be "triton_"
    KERNEL_NAME = "KERNEL_NAME"

    # The descriptive name of the triton kernel; when unique_kernel_names = False, this
    # placeholder will be replaced with a string with more information.
    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"


def pass_execution_and_save(func, gm, inp, msg):
    from .pattern_matcher import stable_topological_sort

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        delete=False,
    ) as f:
        before_io = io.StringIO()
        after_io = io.StringIO()
        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
        print(f"Before:\n{gm.graph}", file=f)
        print(gm.graph, file=before_io)
        start_time = datetime.now()
        func(gm.graph)
        time_elapsed = datetime.now() - start_time
        # recompile graph
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        print(f"After:\n{gm.graph}", file=f)
        print(gm.graph, file=after_io)
        t = before_io.getvalue() == after_io.getvalue()
        log.info(
            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
            msg,
            f.name,
            t,
            time_elapsed,
        )


def is_collective(node):
    from . import ir

    return type(node) == ir._CollectiveKernel


def is_wait(node):
    from . import ir

    return type(node) == ir._WaitKernel


def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
    "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
    num_rng_seed_offset_inputs = (
        2 if torch._functorch.config.functionalize_rng_ops else 0
    )
    return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs


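# Illustrative sketch only (not used by Inductor): the arithmetic in
# num_fw_fixed_arguments. As the subtraction implies, the AOT forward graph's
# inputs are the Dynamo graph inputs, plus the fixed-address params/buffers,
# plus 2 RNG seed/offset inputs when torch._functorch.config.functionalize_rng_ops
# is enabled. The boolean parameter stands in for that config flag; the helper
# name `_sketch_num_fixed_args` is hypothetical.


def _sketch_num_fixed_args(
    dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int, functionalize_rng_ops: bool
) -> int:
    num_rng_seed_offset_inputs = 2 if functionalize_rng_ops else 0
    return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs

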
def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_saved_tensor(x):
        return (
            "tangents" not in x.name
            and "bwd_seed" not in x.name
            and "bwd_base_offset" not in x.name
        )

    arg_count = 0
    static_arg_idxs = []
    for n in fx_g.graph.nodes:
        if n.op == "placeholder":
            if is_saved_tensor(n):
                static_arg_idxs.append(arg_count)
            arg_count += 1

    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)


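# Illustrative sketch only (not used by Inductor): the check in count_tangents,
# applied to a list of placeholder names instead of an fx.GraphModule. Saved
# tensors (anything that is not a tangent or a backward RNG seed/offset) must
# form a contiguous prefix of the inputs, and their count is returned.
# `_sketch_count_static_args` and the example names are hypothetical.


def _sketch_count_static_args(placeholder_names) -> int:
    def is_saved_tensor(name):
        return (
            "tangents" not in name
            and "bwd_seed" not in name
            and "bwd_base_offset" not in name
        )

    static_arg_idxs = [
        i for i, name in enumerate(placeholder_names) if is_saved_tensor(name)
    ]
    # Static inputs must occupy indices 0..k-1, i.e. come before any tangent.
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)

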
@dataclasses.dataclass
class BoxedBool:
    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        if isinstance(obj, BoxedBool):
            obj.value = False
            return obj
        return False


@contextlib.contextmanager
def collect_defined_kernels(kernel_list):
    from .codegen.wrapper import WrapperCodeGen

    orig_define_kernel = WrapperCodeGen.define_kernel

    def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
        nonlocal kernel_list
        kernel_list.append(kernel_code)
        return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)

    with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
        yield


def get_cloned_parameter_buffer_name(name: str):
    return name + "__original__"


def is_gpu(device: str):
    return device in ["cuda", "xpu"]


def device_need_guard(device: str):
    assert isinstance(device, str)
    return is_gpu(device)


def needs_fallback_due_to_atomic_add_limitations(dtype):
    # tl.atomic_add does NOT support the following types
    return dtype in {torch.int64, torch.bool, torch.bfloat16}


def use_scatter_fallback(
    op_overload: torch._ops.OpOverload,
    reduction_type,
    self_dtype,
    src_dtype,
    src_device_type,
    src_is_tensor,
):
    reduce_ty = (
        "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
    )

    return (
        reduction_type not in {None, reduce_ty}
        or (
            src_is_tensor
            and is_gpu(src_device_type)
            and needs_fallback_due_to_atomic_add_limitations(src_dtype)
        )
        or (
            op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
            and reduction_type == "sum"
            and src_is_tensor
            and src_device_type == "cpu"
            and config.cpp.fallback_scatter_reduce_sum
            and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
        )
        or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
        or torch.are_deterministic_algorithms_enabled()
    )


def dump_node_schedule(node_schedule):
    """
    An API that can be used in pdb to dump a node_schedule.
    Right now it mainly dumps the read/write dependencies, but more can be added as needed.
    """
    from torch._inductor.codegen.simd import DisableReduction, EnableReduction
    from torch._inductor.scheduler import SchedulerNode

    print(f"Node schedule with {len(node_schedule)} nodes")
    for idx, node in enumerate(node_schedule):
        print(f" {idx:3}:")
        if node is EnableReduction:
            print("enable reduction")
        elif node is DisableReduction:
            print("disable reduction")
        elif isinstance(node, SchedulerNode):
            is_red = node.is_reduction()
            print(f"{'red' if is_red else 'pw'} scheduler node")
            if is_red:
                assert node.node is not None
                print(f"original reduction hint {node.node.data.reduction_hint}")  # type: ignore[attr-defined]
            print("ReadDep:")
            for dep in node.read_writes.reads:
                print(dep)
            print("WriteDep:")
            for dep in node.read_writes.writes:
                print(dep)
        else:
            raise RuntimeError(f"Unrecognized node type: {type(node)}")


def tensor_is_aligned(tensor: torch.Tensor):
    # See Note: [Input Alignment handling in Inductor]
    # Right now, we don't try to guard on the alignment of the storage offset.
    # When this comment was written, non-symbolic storage_offsets are not guarded on
    # but symbolic storage_offsets are. For consistency, we suppress guard creation
    # upon performing this check: that ensures that we don't add recompiles when we
    # add this logic.
    return (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % ALIGNMENT == 0


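# Illustrative sketch only (not used by Inductor): the arithmetic in
# tensor_is_aligned. Assuming the storage base itself is aligned, a tensor's data
# starts storage_offset * itemsize bytes past it, so the tensor is aligned when
# that byte offset is a multiple of ALIGNMENT. The default of 16 bytes below is an
# assumption standing in for Inductor's ALIGNMENT constant; the helper name
# `_sketch_offset_is_aligned` is hypothetical.


def _sketch_offset_is_aligned(
    storage_offset: int, itemsize: int, alignment: int = 16
) -> bool:
    return (storage_offset * itemsize) % alignment == 0


# Usage: a float32 view starting 4 elements in (16 bytes) is aligned; one starting
# 2 elements in (8 bytes) is not.

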
def should_assume_input_aligned(example_input: torch.Tensor):
    # See Note: [Input Alignment handling in Inductor]

    # Right now, we only care about alignment for GPU (CUDA/XPU) tensors.
    if not is_gpu(example_input.device.type):
        return False
    return config.assume_aligned_inputs or tensor_is_aligned(example_input)


def maybe_get_suppress_shape_guards_ctx():
    # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
    # If it's not available, return a nullcontext.

    # If we're dealing with cudagraphs, we might not have a tracing_context
    tracing_context = torch._guards.TracingContext.try_get()
    if not tracing_context:
        return contextlib.nullcontext()

    # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
    shape_env = tracing_context.fake_mode.shape_env
    if not shape_env:
        return contextlib.nullcontext()

    return shape_env.suppress_guards()


def aoti_eager_cache_dir(namespace: str, device: str):
    return Path(cache_dir()) / "aoti_eager" / namespace / device


def aoti_eager_op_conf_lock(op_func_name_with_overload: str):
    from filelock import FileLock

    # Avoid circular import
    from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT

    op_conf_lock_file = f"{op_func_name_with_overload}.lock"
    lock_dir = get_lock_dir()
    return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)


def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str):
    device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
    op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
    if not op_conf.exists():
        return []

    with aoti_eager_op_conf_lock(op_func_name_with_overload):
        with open(op_conf) as f:
            json_data = json.load(f)
            for item in json_data:
                # Get the absolute path of the kernel library
                kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
                item["kernel_path"] = kernel_lib_abs_path.as_posix()

                # Check if the kernel library exists
                if not kernel_lib_abs_path.exists():
                    return []

                for metadata in item["meta_info"]:
                    assert not metadata[
                        "is_dynamic"
                    ], "Only support static shape for now"
                    if metadata["device_type"] == "cpu":
                        metadata["device_index"] = -1
                    metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1])

            return json_data


def aoti_compile_with_persistent_cache(
    ns: str,
    op_func_name_with_overload: str,
    device_type: str,
    dynamic: bool,
    f: Callable[..., Any],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
):
    """
    Compile the given function with persistent cache for AOTI eager mode.
    """
    assert not dynamic, "Only support static shape for now"
    type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
    supported_scalar_types = tuple(type_to_torch_dtype.keys())
    flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
    if not all(
        isinstance(input, (supported_scalar_types, torch.Tensor))
        for input in flattened_inputs
    ):
        raise NotImplementedError("Only support tensor, int, float, bool for now")

    persistent_cache = aoti_eager_cache_dir(ns, device_type)
    if not persistent_cache.exists():
        persistent_cache.mkdir(parents=True)

    persistent_cache_lib = persistent_cache / "lib"
    if not persistent_cache_lib.exists():
        persistent_cache_lib.mkdir()

    with mock.patch.dict(
        os.environ,
        {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
    ):
        try:
            kernel_lib_path = torch._export.aot_compile(
                f,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                options=options,
                remove_runtime_assertions=remove_runtime_assertions,
                disable_constraint_solver=disable_constraint_solver,
                # Some operations may have non-Tensor parameters like int, float, bool. These
                # non-Tensor parameters will not be inputs of the graph. Therefore, we do
                # not need to keep the same signature.
                same_signature=False,
            )

            kernel_metadata_items = []
            for input in flattened_inputs:
                # TODO(Eikan): To add dynamic support
                metadata: Dict[str, Any] = {}
                metadata["is_dynamic"] = dynamic

                if isinstance(input, torch.Tensor):
                    metadata["device_type"] = f"{input.device.type}"
                    if is_cpu_device([input]):
                        metadata["device_index"] = -1
                    else:
                        metadata["device_index"] = input.device.index
                    metadata["dtype"] = f"{input.dtype}"
                    metadata["sizes"] = list(input.size())
                    metadata["strides"] = list(input.stride())
                else:
                    assert isinstance(input, supported_scalar_types)
                    # Scalar tensor
                    metadata["device_type"] = device_type
                    metadata["device_index"] = -1 if device_type == "cpu" else 0
                    metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
                    metadata["sizes"] = []
                    metadata["strides"] = []
                    metadata["scalar_value"] = input

                kernel_metadata_items.append(metadata)

            kernel_meta_info: Dict[str, Any] = {}
            kernel_meta_info["meta_info"] = kernel_metadata_items
            kernel_meta_info["kernel_path"] = (
                Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
            )

            json_data = []
            update_json = True
            op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
            mode = "r" if op_conf.exists() else "w"
            with aoti_eager_op_conf_lock(op_func_name_with_overload):
                with open(op_conf, mode) as op_conf_file:
                    try:
                        json_data = json.load(op_conf_file)
                    except Exception:
                        json_data = []

                    assert isinstance(json_data, list)
                    for item in json_data:
                        assert isinstance(item, dict)
                        # Same kernel meta info already exists in the json file
                        if item["meta_info"] == kernel_metadata_items:
                            update_json = False
                            break

                if update_json:
                    json_data.append(kernel_meta_info)
                    with open(op_conf, "w") as op_conf_file:
                        json.dump(json_data, op_conf_file, indent=4)

            return kernel_lib_path
        except Exception:
            return ""


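# Illustrative sketch only (not used by Inductor): the de-duplicating JSON update
# that aoti_compile_with_persistent_cache performs on its per-op config file,
# without the file lock or the AOTInductor compilation. A new entry is appended
# only if no existing entry has identical "meta_info"; a missing or unparsable
# file is treated as an empty list. `_sketch_append_unique_entry` is hypothetical.
import json
import os


def _sketch_append_unique_entry(conf_path: str, entry: dict) -> bool:
    json_data = []
    if os.path.exists(conf_path):
        with open(conf_path) as conf_file:
            try:
                json_data = json.load(conf_file)
            except json.JSONDecodeError:
                json_data = []
    if any(item["meta_info"] == entry["meta_info"] for item in json_data):
        # Same kernel meta info is already recorded; leave the file untouched.
        return False
    json_data.append(entry)
    with open(conf_path, "w") as conf_file:
        json.dump(json_data, conf_file, indent=4)
    return True

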
def run_and_get_cpp_code(fn, *args, **kwargs):
    # We use the patch context manager instead of using it as a decorator.
    # In this way, we can ensure that the attribute is patched and unpatched correctly
    # even if this run_and_get_cpp_code function is called multiple times.
    with unittest.mock.patch.object(config, "debug", True):
        torch._dynamo.reset()
        import io
        import logging

        log_capture_string = io.StringIO()
        ch = logging.StreamHandler(log_capture_string)
        from torch._inductor.graph import output_code_log

        output_code_log.addHandler(ch)
        prev_level = output_code_log.level
        output_code_log.setLevel(logging.DEBUG)
        try:
            result = fn(*args, **kwargs)
            s = log_capture_string.getvalue()
        finally:
            # Restore the logger even if fn raises.
            output_code_log.setLevel(prev_level)
            output_code_log.removeHandler(ch)
    return result, s
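

# Illustrative sketch only (not used by Inductor): capturing everything a logger
# emits while a function runs, the technique run_and_get_cpp_code applies to
# output_code_log. The handler and previous level are restored in a `finally`
# block so a failing call does not leave the logger modified.
# `_sketch_capture_logger` is hypothetical.
import io
import logging


def _sketch_capture_logger(logger: logging.Logger, fn, *args, **kwargs):
    buffer = io.StringIO()
    handler = logging.StreamHandler(buffer)
    prev_level = logger.level
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    try:
        result = fn(*args, **kwargs)
    finally:
        logger.setLevel(prev_level)
        logger.removeHandler(handler)
    return result, buffer.getvalue()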