mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89203 Approved by: https://github.com/jansel, https://github.com/ngimel
289 lines
9.8 KiB
Python
289 lines
9.8 KiB
Python
import collections
|
|
import dataclasses
|
|
import itertools
|
|
import logging
|
|
import typing
|
|
from typing import Callable, cast, Dict, List, Optional, Set, Tuple, Union
|
|
|
|
import sympy
|
|
|
|
from . import config
|
|
from .codegen.common import index_prevent_reordering
|
|
from .utils import (
|
|
get_dtype_size,
|
|
sympy_product,
|
|
sympy_str,
|
|
sympy_subs,
|
|
sympy_symbol,
|
|
VarRanges,
|
|
)
|
|
from .virtualized import V
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
Dep = Union["MemoryDep", "StarDep"]
|
|
|
|
|
|
class MemoryDep(typing.NamedTuple):
|
|
name: str
|
|
index: sympy.Expr # type: ignore[assignment]
|
|
size: Tuple[sympy.Expr, ...]
|
|
|
|
def broadcast_extend_sizes(self, extra_sizes: List[sympy.Expr]) -> "MemoryDep":
|
|
size = (*self.size, *[x for x in extra_sizes if x != 1])
|
|
return MemoryDep(self.name, self.index, size)
|
|
|
|
def maybe_swap_sizes(self) -> "MemoryDep":
|
|
# swap only in simple cases where index is trivial and
|
|
# there are just 2 sizes
|
|
if (
|
|
len(self.size) == 2
|
|
and len(self.index.args) == 0
|
|
and cast(sympy.Symbol, self.index).name == canonicalization_prefix() + "0"
|
|
):
|
|
c = canonicalization_prefix()
|
|
size = (self.size[1], self.size[0])
|
|
s0 = sympy_symbol(c + "0")
|
|
s1 = sympy_symbol(c + "1")
|
|
index = sympy_subs(self.index, {s0: s1})
|
|
return MemoryDep(self.name, index, size)
|
|
else:
|
|
return self
|
|
|
|
def strip_last_size(self) -> "MemoryDep":
|
|
nsizes = len(self.size)
|
|
if not (nsizes >= 1 and len(self.index.args) <= nsizes - 1):
|
|
return self
|
|
# make sure last dim index is not used
|
|
prefix = canonicalization_prefix()
|
|
len_prefix = len(prefix)
|
|
prefixes = [
|
|
fs.name[:len_prefix]
|
|
for fs in cast(Set[sympy.Symbol], self.index.free_symbols)
|
|
]
|
|
assert (
|
|
len(prefixes) == 0 or prefix in prefixes
|
|
), "index expression should contain canonicalized symbols"
|
|
last_index = f"{prefix}{len(self.size)-1}"
|
|
if last_index not in self.index.free_symbols:
|
|
size = self.size[:-1]
|
|
return MemoryDep(self.name, self.index, size)
|
|
else:
|
|
return self
|
|
|
|
def rename(self, renames: Dict[str, str]) -> "MemoryDep":
|
|
if self.name in renames:
|
|
return MemoryDep(renames[self.name], self.index, self.size)
|
|
return self
|
|
|
|
def numbytes_hint(self):
|
|
vars = set(self.index.free_symbols)
|
|
size_vars_used = []
|
|
for var in vars:
|
|
if var.name.startswith(canonicalization_prefix()):
|
|
# Sometimes with indirect indexing we have very weird symbol names
|
|
assert " " not in var.name
|
|
size_vars_used.append(int(var.name[len(canonicalization_prefix()) :]))
|
|
|
|
return V.graph.sizevars.size_hint(
|
|
sympy_product([self.size[i] for i in size_vars_used])
|
|
) * get_dtype_size(V.graph.get_dtype(self.name))
|
|
|
|
def is_contiguous(self) -> bool:
|
|
return isinstance(self.index, (sympy.Symbol, sympy.Integer))
|
|
|
|
|
|
class StarDep(typing.NamedTuple):
|
|
# depends on the entire buffer
|
|
name: str
|
|
|
|
def rename(self, renames: Dict[str, str]) -> "StarDep":
|
|
if self.name in renames:
|
|
return StarDep(renames[self.name])
|
|
return self
|
|
|
|
def numbytes_hint(self):
|
|
from .ir import MultiOutputLayout
|
|
|
|
if self.name in V.graph.name_to_buffer:
|
|
buf = V.graph.name_to_buffer[self.name]
|
|
elif self.name in V.graph.graph_inputs:
|
|
buf = V.graph.graph_inputs[self.name]
|
|
else:
|
|
return 1
|
|
if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout):
|
|
# NB: Too annoying to acquire, should only be used for instrumentation
|
|
return 1
|
|
return V.graph.sizevars.size_hint(
|
|
sympy_product(buf.get_size())
|
|
) * get_dtype_size(buf.get_dtype())
|
|
|
|
def is_contiguous(self) -> bool:
|
|
return False
|
|
|
|
|
|
class IndexExprDep(typing.NamedTuple):
|
|
index: sympy.Expr # type: ignore[assignment]
|
|
size: Tuple[sympy.Expr, ...]
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ReadWrites:
|
|
reads: Set[Dep]
|
|
writes: Set[Dep]
|
|
index_exprs: Set[IndexExprDep]
|
|
range_vars: Optional[List[sympy.Expr]] = None
|
|
var_ranges: Optional[VarRanges] = None
|
|
|
|
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
|
|
return ReadWrites(
|
|
{dep.rename(renames) for dep in self.reads},
|
|
{dep.rename(renames) for dep in self.writes},
|
|
self.index_exprs,
|
|
self.range_vars,
|
|
self.var_ranges,
|
|
)
|
|
|
|
def with_read(self, name: str) -> "ReadWrites":
|
|
assert isinstance(name, str)
|
|
return ReadWrites(
|
|
set.union(self.reads, {StarDep(name)}),
|
|
self.writes,
|
|
self.index_exprs,
|
|
self.range_vars,
|
|
self.var_ranges,
|
|
)
|
|
|
|
def merge(self, other):
|
|
reads = set.union(self.reads, other.reads)
|
|
writes = set.union(self.writes, other.writes)
|
|
index_exprs = set.union(self.index_exprs, other.index_exprs)
|
|
return ReadWrites(
|
|
reads - writes,
|
|
writes,
|
|
index_exprs,
|
|
)
|
|
|
|
|
|
class RecordLoadStore(V.MockHandler): # type: ignore[name-defined]
|
|
def __init__(self, var_ranges: VarRanges, normalize: bool):
|
|
super(RecordLoadStore, self).__init__()
|
|
self._reads: Set[MemoryDep] = set()
|
|
self._writes: Set[MemoryDep] = set()
|
|
self._index_exprs: Set[IndexExprDep] = set()
|
|
self._var_ranges: VarRanges = var_ranges
|
|
self._normalize: bool = normalize
|
|
|
|
# Truncate the expr str by a threshold to prevent it's too long
|
|
# and cause process hanging. The result is not used.
|
|
# https://github.com/pytorch/torchdynamo/issues/1352
|
|
@staticmethod
|
|
def truncate_expr(expr):
|
|
if len(expr) > config.realize_bytes_threshold:
|
|
expr = f"{expr[:config.realize_bytes_threshold]}..."
|
|
return expr
|
|
|
|
def canonicalize(
|
|
self, index: sympy.Expr
|
|
) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
|
|
sizes = list(self._var_ranges.values())
|
|
sizes = [V.graph.sizevars.simplify(x) for x in sizes]
|
|
if not self._normalize:
|
|
return index, tuple([x for x in sizes if x != 1])
|
|
|
|
# Try to further simplify the indexes even if simplify_loops didn't
|
|
# convert it to the simpliest form because of the interference from
|
|
# different indexing formulas.
|
|
index_vars = list(self._var_ranges.keys())
|
|
new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
|
|
index_vars,
|
|
sizes,
|
|
index_prevent_reordering([index], index_vars, sizes),
|
|
)
|
|
|
|
# assign new variables each dimension to deal with numbering mismatches
|
|
# d0, d1, d2 could become d0, d2 -- which won't match d0, d1
|
|
_, add_var = var_builder(canonicalization_prefix())
|
|
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
|
|
|
|
index = sympy_subs(sympy.expand(index), replacement)
|
|
return index, tuple(new_sizes)
|
|
|
|
def load(self, name: str, index: sympy.Expr) -> str:
|
|
canonicalized_index, canonicalized_size = self.canonicalize(index)
|
|
self._reads.add(MemoryDep(name, canonicalized_index, canonicalized_size))
|
|
return f"load({name}, {sympy_str(index)})"
|
|
|
|
def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
|
|
canonicalized_index, canonicalized_size = self.canonicalize(index)
|
|
self._writes.add(MemoryDep(name, canonicalized_index, canonicalized_size))
|
|
return f"store({name}, {sympy_str(index)}, {value}, {mode})"
|
|
|
|
def reduction(
|
|
self, name: str, dtype, src_dtype, reduction_type, index, value
|
|
) -> str:
|
|
return self.store(name, index, f"reduce_{reduction_type})({value})")
|
|
|
|
def index_expr(self, index: sympy.Expr, dtype) -> str:
|
|
canonicalized_index, canonicalized_size = self.canonicalize(index)
|
|
self._index_exprs.add(IndexExprDep(canonicalized_index, canonicalized_size))
|
|
return f"index_expr({sympy_str(index)}, {dtype})"
|
|
|
|
|
|
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
|
|
cnt = itertools.count()
|
|
var_ranges: VarRanges = collections.OrderedDict()
|
|
|
|
def add_var(length: sympy.Expr) -> sympy.Symbol:
|
|
v = sympy_symbol(f"{prefix}{next(cnt)}")
|
|
var_ranges[v] = length
|
|
return v
|
|
|
|
return var_ranges, add_var
|
|
|
|
|
|
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
|
|
var_ranges, add_var = var_builder(prefix)
|
|
args: List[List[sympy.Symbol]] = []
|
|
for size in argsizes:
|
|
args.append(list(map(add_var, size)))
|
|
return args, var_ranges
|
|
|
|
|
|
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
|
|
from .ir import SqueezeView
|
|
|
|
var_ranges, add_var = var_builder(prefix)
|
|
args: List[List[sympy.Expr]] = []
|
|
new_sizes: List[List[sympy.Expr]] = []
|
|
for size in argsizes:
|
|
new_size, reindex = SqueezeView.squeezer(size)
|
|
new_sizes.append(new_size)
|
|
args.append(reindex(list(map(add_var, new_size))))
|
|
return new_sizes, args, var_ranges
|
|
|
|
|
|
def extract_read_writes(
|
|
fn: Callable,
|
|
*argsizes: Tuple[sympy.Expr, ...],
|
|
normalize: bool = False,
|
|
prefix: str = "d",
|
|
):
|
|
_, args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
|
|
rw = RecordLoadStore(var_ranges, normalize=normalize)
|
|
with V.set_ops_handler(rw): # type: ignore[call-arg]
|
|
fn(*args)
|
|
|
|
if normalize:
|
|
range_vars = [] # Number of vars could differ due to normalization
|
|
else:
|
|
range_vars = [*itertools.chain(*args)]
|
|
|
|
return ReadWrites(
|
|
set(rw._reads), set(rw._writes), rw._index_exprs, range_vars, var_ranges
|
|
)
|
|
|
|
|
|
def canonicalization_prefix():
|
|
return "c"
|