mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This reverts commit db0ce4acf3.
Reverted https://github.com/pytorch/pytorch/pull/88384 on behalf of https://github.com/malfet due to Broke test_public_bindings across the board
141 lines
3.2 KiB
Python
141 lines
3.2 KiB
Python
from contextlib import contextmanager
|
|
from itertools import chain
|
|
from threading import local
|
|
|
|
import sympy
|
|
|
|
from torch.fx.graph import inplace_methods, magic_methods
|
|
|
|
from .utils import sympy_str, sympy_symbol
|
|
|
|
threadlocal = local()
|
|
|
|
|
|
class Virtualized:
|
|
"""
|
|
A global variable that redirects via thread local variable
|
|
|
|
This allows us to swap in different op implementations in codegen.
|
|
"""
|
|
|
|
def __init__(self, vname, default):
|
|
self._key = f"__torchinductor_{vname}"
|
|
self._default = default
|
|
|
|
def _set_handler(self, value):
|
|
prior = self._get_handler()
|
|
setattr(threadlocal, self._key, value)
|
|
|
|
@contextmanager
|
|
def ctx():
|
|
try:
|
|
yield
|
|
finally:
|
|
self._set_handler(prior)
|
|
|
|
return ctx()
|
|
|
|
def _get_handler(self):
|
|
try:
|
|
return getattr(threadlocal, self._key)
|
|
except AttributeError:
|
|
return self._default()
|
|
|
|
def __getattr__(self, name):
|
|
return getattr(self._get_handler(), name)
|
|
|
|
|
|
class NullHandler:
|
|
pass
|
|
|
|
|
|
def _arg_str(a):
|
|
if isinstance(a, sympy.Expr):
|
|
return sympy_str(a)
|
|
return str(a)
|
|
|
|
|
|
class MockHandler:
|
|
def __getattr__(self, name):
|
|
def inner(*args, **kwargs):
|
|
fargs = [_arg_str(a) for a in args]
|
|
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
|
|
return self.truncate_expr(f"{name}({', '.join(fargs)})")
|
|
|
|
return inner
|
|
|
|
@staticmethod
|
|
def truncate_expr(expr):
|
|
return expr
|
|
|
|
@classmethod
|
|
def masked(cls, mask, body, other):
|
|
return cls.truncate_expr(f"masked({mask}, {body()}, {other})")
|
|
|
|
@staticmethod
|
|
def indirect_indexing(index_var):
|
|
return sympy_symbol(f"({str(index_var)})")
|
|
|
|
@classmethod
|
|
def _init_cls(cls):
|
|
def make_handler(format_string):
|
|
@staticmethod
|
|
def inner(*args):
|
|
return format_string.format(*args)
|
|
|
|
return inner
|
|
|
|
for name, format_string in chain(
|
|
magic_methods.items(), inplace_methods.items()
|
|
):
|
|
setattr(cls, name, make_handler(format_string))
|
|
|
|
|
|
class WrapperHandler:
|
|
def __init__(self, inner):
|
|
self._inner = inner
|
|
|
|
def __getattr__(self, item):
|
|
return getattr(self._inner, item)
|
|
|
|
|
|
MockHandler._init_cls()
|
|
|
|
ops = Virtualized("ops", MockHandler)
|
|
_graph = Virtualized("graph", NullHandler)
|
|
_kernel = Virtualized("kernel", NullHandler)
|
|
_debug = Virtualized("debug", NullHandler)
|
|
|
|
|
|
class _V:
|
|
MockHandler = MockHandler
|
|
WrapperHandler = WrapperHandler
|
|
|
|
set_ops_handler = ops._set_handler
|
|
get_ops_handler = ops._get_handler
|
|
set_graph_handler = _graph._set_handler
|
|
set_kernel_handler = _kernel._set_handler
|
|
set_debug_handler = _debug._set_handler
|
|
|
|
@property
|
|
def ops(self) -> MockHandler:
|
|
"""The operator handler specific to the current codegen task"""
|
|
return ops._get_handler()
|
|
|
|
@property
|
|
def graph(self):
|
|
"""The graph currently being generated"""
|
|
return _graph._get_handler()
|
|
|
|
@property
|
|
def kernel(self):
|
|
"""The kernel currently being generated"""
|
|
return _kernel._get_handler()
|
|
|
|
@property
|
|
def debug(self):
|
|
return _debug._get_handler()
|
|
|
|
|
|
V = _V()
|