As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
As title. Without this patch we get the following error:
Tweaking the `allow_non_fake_inputs` flag on tensor mode doesn't quite
work for AOTAutograd, which also needs to fake-tensor-propagate the
`nonstrict_trace`-ed function, but that's _after_ Dynamo has handled the
`nonstrict_trace` processing and put the `flat_apply(...)` node into the graph.
So we can't easily to temporarily enable the `allow_non_fake_inputs`
flag on current fake mode, when AOTAutograd processes a `flat_apply`
node from Dynamo's `nonstrict_trace` handling. And after discussing
with zou3519, I decided to add a global `FakeTensorTLS` that contains a
`allow_non_fake_inputs_override` flag, and patch the `nonstrict_trace`-ed
function to temporarily tweak this flag during its execution.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147572
Approved by: https://github.com/zou3519
ghstack dependencies: #146714, #146367, #146950, #147571
## Context
> **Note:** `mark_traceable` got renamed to `nonstrict_trace` after
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> `nonstrict` notion, and (2) it's more definitive in behavior suggestion.
1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)
## Summary
This patch adds a `torch._dynamo.nonstrict_trace` decorator, which
currently is an enhanced version of `torch._dynamo.allow_in_graph` (see
docstring for their differences). Specifically, this patch focuses on
the UI and functionality prototyping/plumbing.
The main enhancement is supporting more input types, and the
implementation challenge lies in reconstructing the input objects from
Dynamo `VariableTracker` (while accounting for buffered side-effects and
guards). This patch takes a middle-ground (simple implementation with a
bit of user labor), by
1. asking the user to provide pytree registration for non-proxy-able
input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for
buffered side-effects and guards automatically),
3. and passing in the TreeSpec as a graph attribute constant into
`torch._higher_order_ops.flat_apply` (which unflattens the inputs and
invokes the underlying function).
## Next Steps
In subsequent patches, we will try to support the following:
- annotating on class method
- reads to global tensors
- inputs that contains `pytree.register_constant`-ed instances.
- function as input
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
This patch enables `flat_apply` to support certain non-Tensor output
types like containers and graphable types. This will in turn enable the
upcoming `mark_traceable` to support more output types.
The patch also exposes a `func_to_graphable` rather than having the
users calling the lower level `pytree.flatten(ConstantFunction(...))`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146714
Approved by: https://github.com/zou3519
This PR:
- adds pytree.register_constant for registering a class to be treated as
a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
mark_traceable working. A lot more work is necessary to get the custom
operator case working (when make_fx sees a custom operator with PyTree
arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
the current state to unblock the mark_traceable and custom ops
work.
Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
I added a really simple test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059