diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index d4663c6dc71..e0a9e10d11b 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -391,24 +391,25 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
     disable_amp = torch._C._is_any_autocast_enabled()

     if config.use_functionalize:
-        # Trace once without decompositions, into a graph of ATen ops.
-        # NB: tracing_mode is real, as it's assumed the calling context setup
-        # fake tensor mode / symbolic shapes if that is needed
-        fx_g = make_fx(joint_forward_backward)(*joint_inputs)
+        with enable_python_dispatcher():
+            # Trace once without decompositions, into a graph of ATen ops.
+            # NB: tracing_mode is real, as it's assumed the calling context setup
+            # fake tensor mode / symbolic shapes if that is needed
+            fx_g = make_fx(joint_forward_backward)(*joint_inputs)

-        context = disable_autocast_manager if disable_amp else nullcontext
+            context = disable_autocast_manager if disable_amp else nullcontext

-        def fake_fn(primals, tangents):
-            with torch.fx.traceback.override_stack_trace():
-                return torch.fx.Interpreter(fx_g).run(primals, tangents)
+            def fake_fn(primals, tangents):
+                with torch.fx.traceback.override_stack_trace():
+                    return torch.fx.Interpreter(fx_g).run(primals, tangents)

-        # Trace a second time, running functionalization, and THEN running decompositions.
-        # functionalization only acts on ATen today, and doesn't currently handle
-        # view and inplace ops that come from primtorch.
-        # Eventually, functionalization should support primtorch view/inplace ops,
-        # which will make it ok to run decompositions before functionalization.
-        with context():
-            fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
+            # Trace a second time, running functionalization, and THEN running decompositions.
+            # functionalization only acts on ATen today, and doesn't currently handle
+            # view and inplace ops that come from primtorch.
+            # Eventually, functionalization should support primtorch view/inplace ops,
+            # which will make it ok to run decompositions before functionalization.
+            with context():
+                fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
         fx_g.graph.eliminate_dead_code()
         fx_g.recompile()
     else:
diff --git a/torch/_ops.py b/torch/_ops.py
index ed0276d0ada..a4119d75852 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -103,7 +103,7 @@ def resolve_key(op: PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
         # The dispatch key itself will implicitly route to backend fallback.
         # This is probably not great for the pure Python implementation.
         return k
-    raise RuntimeError("could not find kernel")
+    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")


 pyop_namespace = {}
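
Note (illustration, not part of the patch): the first hunk wraps both make_fx traces in enable_python_dispatcher(), so Python-level dispatcher registrations (e.g. decompositions and Python op implementations) are visible while the joint forward/backward graph is traced. Below is a minimal standalone sketch of that pattern, assuming the torch._dispatch.python.enable_python_dispatcher and torch.fx.experimental.proxy_tensor.make_fx APIs of this era of PyTorch; the joint function and inputs are made up for illustration and are not the joint_forward_backward used by AOTAutograd.

import torch
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.proxy_tensor import make_fx

# Toy stand-in for a joint forward/backward function: forward is sin,
# "backward" scales the tangent by cos (hand-written, purely illustrative).
def joint(primals, tangents):
    out = primals.sin()
    grad = tangents * primals.cos()
    return out, grad

primals = torch.randn(4)
tangents = torch.ones(4)

# Same shape as the patched code: trace to an ATen-level FX graph while the
# Python dispatcher is active, so Python-registered kernels participate.
with enable_python_dispatcher():
    fx_g = make_fx(joint)(primals, tangents)

print(fx_g.graph)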