Fix pre-dispatch AC HOP calling convention (#165145)

For AC HOP, dynamo traces it without kwargs. (kwargs are only inputs to the HOP, not to the body)
55f01a48af/torch/_dynamo/variables/higher_order_ops.py (L2594-L2609)

When we add non-strict support, we should match this calling convention too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165145
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433, #164437
This commit is contained in:
Simon Fan 2025-10-10 07:21:32 -07:00 committed by PyTorch MergeBot
parent 058814794b
commit 992857e286

View File

@ -325,7 +325,8 @@ def proxy_mode_key(
qualname = proxy_mode.tracer.get_fresh_qualname("wrap_body") # type: ignore[union-attr]
# TODO (tmanlaibaatar) don't we need flat_apply here??
flat_args, _ = pytree.tree_flatten((args, kwargs))
# Dynamo already traced the gmod body without kwargs
flat_args, _ = pytree.tree_flatten(args)
with fx_traceback.preserve_node_meta():
gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args)
gmod_aten.meta["_checkpoint_context_fn"] = gmod.meta["_checkpoint_context_fn"]