pytorch/torch/_dispatch/python.py
Brian Hirsh b0392de2c3 change pre_autograd to pre_dispatch tracing (#101818)
We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")?

~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem~~

Update: I added a dedicated dispatch key, `PreDispatch`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
2023-06-09 17:30:15 +00:00

150 lines
5.8 KiB
Python

import torch._C
from contextlib import contextmanager
import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools
from typing import Iterator
import torch._ops
# Names exported by a star-import of this module.
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher', 'enable_pre_dispatch']
# Public aliases for objects defined in torch._C; their behavior is implemented
# outside this file.
no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch
# Module-level flag, False by default. enable_crossref_functionalize() patches
# it to True for the duration of its `with` body. No reader of the flag is
# visible in this module -- presumably it is consulted by the dispatch code
# elsewhere; confirm against the caller.
CROSSREF_FUNCTIONALIZE = False
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Lazily yield every OpOverload reachable by iterating torch.ops.

    Warning: the set of overloads reported here is very subtle. It is exactly
    the set of torch.ops functions that have actually been accessed from
    Python (e.g., somebody evaluated torch.ops.aten.blah at some point). This
    is DIFFERENT from the set of registered operators, which is in general
    larger, since that set also covers every operator registered by a C++
    static initializer or by Python operator registration. The listing on
    torch.ops.aten is not eagerly populated; it is lazy!

    Put differently: this is suitable for walking everything that already has
    an OpOverload object allocated in Python. It is used for cache
    invalidation, but do not rely on the result being complete.

    Even reporting every C++ registered overload would not guarantee
    completeness, because a later lazy load of a library that triggers more
    registrations could add further entries.
    """
    for namespace_name in torch.ops:
        namespace = getattr(torch.ops, namespace_name)
        for packet_name in namespace:
            packet = getattr(namespace, packet_name)
            # Iterating a packet gives overload names; resolve each to its object.
            yield from (getattr(packet, overload_name) for overload_name in packet)
@contextmanager
def suspend_functionalization():
    """
    Context manager that turns functionalization off for the enclosed body.

    If the Functionalize dispatch key is included in TLS on entry, it is
    disabled for the body and re-enabled on exit (even if the body raises)
    with the same reapply_views setting that was recorded on entry. If the key
    is not included on entry, the body simply runs and nothing is changed.
    """
    was_enabled = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    # Recorded up front so the exact same setting can be restored afterwards.
    reapply_views = torch._C._functionalization_reapply_views_tls()

    if not was_enabled:
        # Nothing to suspend, hence nothing to restore.
        yield
        return

    torch._disable_functionalization()
    try:
        yield
    finally:
        torch._enable_functionalization(reapply_views=reapply_views)
def check_tensor_metadata_matches(nv, rv, desc):
    """
    Assert that two tensors have the same size, dtype and significant strides.

    Args:
        nv: first tensor of the pair being compared.
        rv: second tensor of the pair being compared.
        desc: zero-argument callable returning a string that prefixes the
            failure message; it is only invoked when a check fails.

    Raises:
        AssertionError: if `desc` is not callable or any metadata differs.
    """
    assert callable(desc)

    # Checked in order: sizes, then dtype, and strides only once both agree.
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"

    stride_result = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    strides_ok, idx = stride_result
    assert strides_ok, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
def check_metadata_matches(n, r, desc):
    """
    Compare tensor metadata leaf-by-leaf between two pytrees.

    Both `n` and `r` are flattened with pytree; the leaf counts must agree.
    Every position whose leaf in `r` is a tensor is then passed, together with
    the corresponding leaf of `n`, to check_tensor_metadata_matches. Positions
    whose `r` leaf is not a tensor are skipped.

    Args:
        n: first pytree of values.
        r: second pytree of values; decides which positions are checked.
        desc: zero-argument callable returning a string used as the prefix of
            any failure message.

    Raises:
        AssertionError: if `desc` is not callable, the leaf counts differ, or
            a per-tensor metadata check fails.
    """
    assert callable(desc)

    n_vals, _ = pytree.tree_flatten(n)
    r_vals, _ = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"

    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if isinstance(rv, torch.Tensor):
            check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
class Lit:
    """Wrapper whose repr() is the wrapped string itself, without quotes."""

    def __init__(self, s: str) -> None:
        # Text that __repr__ hands back unchanged.
        self.s = s

    def __repr__(self) -> str:
        return self.s
def _fmt(a: object) -> object:
if isinstance(a, torch.Tensor):
return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
else:
return a
def make_crossref_functionalize(op, final_key):
    """
    Build a handler that runs `op` and cross-checks its output metadata
    against a run of the same op on fake tensors.

    Args:
        op: the operator overload to wrap; it is called directly on fake
            inputs and through `op._op_dk(final_key, ...)` on the real inputs.
        final_key: value forwarded as the first argument of `op._op_dk`.

    Returns:
        `final_key` itself when `op` is `aten.lift_fresh.default` (that case
        is not cross-checked); otherwise the `handler` closure defined below.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        # A fresh fake mode is created for every invocation.
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            # Tensors are (unwrapped from their functional wrapper if needed
            # and) converted to fake tensors; non-tensors pass through.
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            # Detach tensors; leave every other value untouched.
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # The fake run happens with functionalization suspended.
        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            # Detached copies of the fake inputs, taken before the op runs;
            # they are only used by desc() below. Presumably this keeps the
            # description faithful if the op mutates its inputs -- TODO confirm.
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # The real run uses the original (non-fake) arguments.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            # Built lazily: renders the call as "op(arg, ..., k=v, ...)" with
            # every tensor replaced by its _fmt() representation.
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"
        # Raises AssertionError if fake and real output metadata disagree.
        check_metadata_matches(f_r, r, desc)
        return r
    return handler
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    """
    Context manager that switches CROSSREF_FUNCTIONALIZE on for its body.

    On entry and again on exit (even if the body raises), every overload
    reported by all_py_loaded_overloads() has `_uncache_dispatch` called for
    the Functionalize dispatch key. While the body runs, the Python dispatcher
    is enabled and `torch._dispatch.python.CROSSREF_FUNCTIONALIZE` is patched
    to True.

    This is slow and intended for debugging only; do not use it in a hot loop.
    """
    functionalize_key = torch._C.DispatchKey.Functionalize

    def _uncache_all():
        # Invoked once before and once after the body.
        for overload in all_py_loaded_overloads():
            overload._uncache_dispatch(functionalize_key)

    _uncache_all()
    try:
        with enable_python_dispatcher():
            with unittest.mock.patch('torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
                yield
    finally:
        _uncache_all()