mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix pre-dispatch AC HOP calling convention (#165145)
For AC HOP, dynamo traces it without kwargs. (kwargs are only inputs to the HOP, not to the body)
55f01a48af/torch/_dynamo/variables/higher_order_ops.py (L2594-L2609)
When we add non-strict support, we should match this calling convention too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165145
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433, #164437
This commit is contained in:
parent
058814794b
commit
992857e286
|
|
@ -325,7 +325,8 @@ def proxy_mode_key(
|
|||
qualname = proxy_mode.tracer.get_fresh_qualname("wrap_body") # type: ignore[union-attr]
|
||||
|
||||
# TODO (tmanlaibaatar) don't we need flat_apply here??
|
||||
flat_args, _ = pytree.tree_flatten((args, kwargs))
|
||||
# Dynamo already traced the gmod body without kwargs
|
||||
flat_args, _ = pytree.tree_flatten(args)
|
||||
with fx_traceback.preserve_node_meta():
|
||||
gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args)
|
||||
gmod_aten.meta["_checkpoint_context_fn"] = gmod.meta["_checkpoint_context_fn"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user