mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This reverts commit 62b775583f.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96014
Approved by: https://github.com/Chillee
180 lines
4.4 KiB
Python
180 lines
4.4 KiB
Python
import itertools
|
|
from contextlib import contextmanager
|
|
from itertools import chain
|
|
from threading import local
|
|
|
|
import sympy
|
|
|
|
from torch._inductor.utils import IndentedBuffer
|
|
|
|
from torch.fx.graph import inplace_methods, magic_methods
|
|
|
|
from .utils import sympy_str, sympy_symbol
|
|
|
|
# Storage for every Virtualized handler. Because this is a threading.local,
# a handler installed from one thread is not visible from another.
threadlocal = local()


class Virtualized:
    """
    A global variable that redirects via thread local variable

    This allows us to swap in different op implementations in codegen.

    Args:
        vname: short name of the variable; it is embedded in the attribute
            name used on ``threadlocal``.
        default: zero-argument callable that produces the handler to use
            while nothing has been installed on the current thread.
    """

    def __init__(self, vname, default):
        self._key = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value):
        """Install ``value`` as the handler for the current thread.

        The handler is installed immediately, whether or not the return
        value is used.

        Returns:
            A context manager. When it is used in a ``with`` statement, the
            handler that was active before this call is written back on
            exit (also when the body raises).
        """
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                # Write the previous handler back directly instead of going
                # through _set_handler again, which would only build another
                # context manager that nobody uses.
                setattr(threadlocal, self._key, prior)

        return ctx()

    def _get_handler(self):
        """Return the handler installed on this thread, or a fresh default.

        While nothing is installed, ``self._default()`` is called on every
        lookup, so each call produces a new default object.
        """
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            return self._default()

    def __getattr__(self, name):
        """Forward lookups of unknown attributes to the current handler."""
        if name in ("_key", "_default"):
            # Only reachable on an instance whose __init__ never ran (for
            # example one created through __new__). Without this guard
            # _get_handler would re-enter __getattr__ for self._key forever
            # and fail with RecursionError instead of AttributeError.
            raise AttributeError(name)
        return getattr(self._get_handler(), name)
|
|
|
|
|
|
class NullHandler:
    """Empty placeholder handler that defines no attributes of its own.

    It is passed as the default factory for the Virtualized variables in
    this module other than ``ops``.
    """
|
|
|
|
|
|
def _arg_str(a):
    """Return the text form of a single operator argument.

    ``sympy.Expr`` values are rendered with ``sympy_str``; every other value
    is rendered with ``str``.
    """
    if isinstance(a, sympy.Expr):
        return sympy_str(a)
    return str(a)
|
|
|
|
|
|
class MockHandler:
    """Ops handler that renders every operation as a string.

    Any attribute that is not defined explicitly resolves to a function that
    returns ``"<op>(<args>)"`` text. ``_init_cls`` additionally installs
    format-string based static methods for the names found in
    ``magic_methods`` and ``inplace_methods``.
    """

    def __getattr__(self, name):
        """Return a string-producing stand-in for the operation ``name``.

        The attribute ``name`` itself is special-cased and yields the
        literal ``"MockHandler"``.
        """
        if name == "name":
            return "MockHandler"

        def inner(*args, **kwargs):
            # Positional arguments go through _arg_str; keyword arguments
            # are rendered as key=value.
            fargs = [_arg_str(a) for a in args]
            fargs.extend(f"{k}={v}" for k, v in kwargs.items())
            return f"{name}({', '.join(fargs)})"

        return inner

    @staticmethod
    def masked(mask, body, other):
        """Render a masked operation; ``body`` is called to obtain its text."""
        return f"masked({mask}, {body()}, {other})"

    @staticmethod
    def indirect_indexing(index_var):
        """Return ``sympy_symbol`` of the parenthesised text of ``index_var``."""
        return sympy_symbol(f"({str(index_var)})")

    @classmethod
    def _init_cls(cls):
        """Attach one static method per magic/inplace method name.

        Each installed method fills its format string with the positional
        arguments it receives.
        """

        def make_handler(format_string):
            # A factory is needed so that each handler closes over its own
            # format_string rather than the loop variable.
            @staticmethod
            def inner(*args):
                return format_string.format(*args)

            return inner

        for name, format_string in chain(
            magic_methods.items(), inplace_methods.items()
        ):
            setattr(cls, name, make_handler(format_string))
|
|
|
|
|
|
class KernelFormatterHandler:
    """Ops handler that records each operation as a ``tmpN = ...`` line.

    Every call is delegated to ``parent_handler``; the value it returns is
    written to ``output`` as an assignment to a freshly numbered temporary,
    and the name of that temporary is returned in its place.

    Args:
        parent_handler: handler whose methods produce the right-hand side of
            each recorded line.
    """

    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        self.output = IndentedBuffer()
        self.var_counter = itertools.count()

    def __getattr__(self, name):
        """Return a wrapper around ``parent_handler.<name>``."""

        def inner(*args, **kwargs):
            line = getattr(self.parent_handler, name)(*args, **kwargs)
            if name == "indirect_indexing":
                # The parent's result is handed back as-is; no temporary is
                # created and nothing is written to the output.
                return line
            # replace line with a new variable name
            varname = f"tmp{next(self.var_counter)}"
            self.output.writeline(f"{varname} = {line}")
            return varname

        return inner

    def getvalue(self, result):
        """Append ``return <result>`` and return the accumulated text.

        Note that every call appends another ``return`` line to the buffer.
        """
        self.output.writeline(f"return {result}")
        return self.output.getvalue()
|
|
|
|
|
|
class WrapperHandler:
    """Handler that forwards every unknown attribute to a wrapped handler.

    Args:
        inner: the object that attribute lookups are delegated to.
    """

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, item):
        """Look ``item`` up on the wrapped handler.

        Raises:
            AttributeError: if the wrapped handler lacks ``item``, or if
                ``_inner`` itself has not been set on this instance.
        """
        if item == "_inner":
            # Only reachable when __init__ never ran (for example an
            # instance created through __new__). Reading self._inner below
            # would otherwise call __getattr__ again without end and fail
            # with RecursionError instead of AttributeError.
            raise AttributeError(item)
        return getattr(self._inner, item)
|
|
|
|
|
|
# Install the format-string based operator methods on MockHandler before it
# is used as the default ops handler below.
MockHandler._init_cls()

# Swappable per-thread variables. `ops` defaults to a MockHandler; the others
# default to an empty NullHandler until a real handler is installed.
ops = Virtualized("ops", MockHandler)
_graph = Virtualized("graph", NullHandler)
_fake_mode = Virtualized("fake_mode", NullHandler)
_kernel = Virtualized("kernel", NullHandler)
_debug = Virtualized("debug", NullHandler)
_interpreter = Virtualized("interpreter", NullHandler)
|
|
|
|
|
|
class _V:
    """Namespace that exposes the virtualized variables of this module.

    The ``set_*`` attributes install a handler for the current thread and
    return a context manager that restores the previous one; the properties
    return whichever handler is currently installed.
    """

    # Handler classes re-exported for convenience.
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    # Bound methods of the module-level Virtualized objects. These must be
    # assigned before the `ops` property below is defined, because inside
    # this class body the name `ops` is rebound by that property.
    set_ops_handler = ops._set_handler
    get_ops_handler = ops._get_handler
    set_graph_handler = _graph._set_handler
    set_fake_mode = _fake_mode._set_handler
    set_kernel_handler = _kernel._set_handler
    set_debug_handler = _debug._set_handler
    set_interpreter_handler = _interpreter._set_handler

    @property
    def ops(self) -> MockHandler:
        """The operator handler specific to the current codegen task"""
        return ops._get_handler()

    @property
    def graph(self):
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def fake_mode(self):
        """The fake mode currently installed via ``set_fake_mode``"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        """The handler currently installed via ``set_debug_handler``"""
        return _debug._get_handler()

    @property
    def interpreter(self):
        """The handler currently installed via ``set_interpreter_handler``"""
        return _interpreter._get_handler()
|
|
|
|
|
|
# Module-level instance of _V; its properties and set_* attributes are the
# access point for the virtualized handlers defined in this module.
V = _V()
|