Unconditionally enable python dispatcher in AOTAutograd (#88365)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88365
Approved by: https://github.com/Chillee
This commit is contained in:
Edward Z. Yang 2022-11-02 18:55:33 -07:00 committed by PyTorch MergeBot
parent a689502275
commit 97d3b200ca
2 changed files with 17 additions and 16 deletions

View File

@ -391,6 +391,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
disable_amp = torch._C._is_any_autocast_enabled()
if config.use_functionalize:
with enable_python_dispatcher():
# Trace once without decompositions, into a graph of ATen ops.
# NB: tracing_mode is real, as it's assumed the calling context setup
# fake tensor mode / symbolic shapes if that is needed

View File

@ -103,7 +103,7 @@ def resolve_key(op: PyOperatorABC, k: DispatchKey): # type: ignore[valid-type]
# The dispatch key itself will implicitly route to backend fallback.
# This is probably not great for the pure Python implementation.
return k
raise RuntimeError("could not find kernel")
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
pyop_namespace = {}