Our experience using `constraints` / `dynamic_dim` with the existing export API has found it to be (subjectively) clunky and (objectively) verbose in common cases.
This PR implements a new design for the export API that replaces the use of `constraints` / `dynamic_dim` with a new way of specifying dynamic shapes, involving the following concepts:
* a constructor `Dim` for first-class named dynamic dimensions with ranges (similar to `functorch.dim`, and analogous to internal symbolic sizes)
* a mechanism that uses the above in `export` calls to associate inputs to their dynamic shape specifications (`dynamic_shapes`)
Design doc: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit#slide=id.p (Meta-only). Note that we only implement Option 1 in that doc. An older version of this PR also implemented Option 3, which is an alternative way of specifying dynamic shapes using tensor type annotations on the exported callable; but we have moved that to future work for now.
See docs for these new features in `torch.export`. The existing `torch.export.export` is modified to use the new API, `torch._export.export__RC__`, whenever `constraints=None`. We have not deprecated the existing API yet, but will do in a follow-up.
Constraint violation errors arising through use of the new API will now contain suggested fixes using the new API. No longer do we need to report all specializations for static dimensions and suggest all constraints over dynamic dimensions to fix such errors. Instead, due to the redesign, the suggested fixes are much more concise, only involving modifying the definitions of relevant `Dim`s.
Differential Revision: [D48919204](https://our.internmc.facebook.com/intern/diff/D48919204/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108448
Approved by: https://github.com/suo, https://github.com/gmagogsfm
Fix: #107315
This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.
In summary, this PR:
- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables the call to `id` builtin function for every variable that implements
`as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
- Make `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
- Removed `functools.wraps` function, since it can't be inlined
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures.
Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()
# Inside torch.cond, we'd like to do something like
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else.
# Dynamo will use the cached code of foo for "eager" backend
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```
This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).
**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.
Note: More lines are printed for debug log due to newly added context manager and guard adds .
**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107337
Approved by: https://github.com/jansel
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures.
Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()
# Inside torch.cond, we'd like to do something like
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else.
# Dynamo will use the cached code of foo for "eager" backend
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```
This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).
**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.
2. Then newly added context manager and guard adds more lines for debug log so we change the uppper limit from 50 to 55.
**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107337
Approved by: https://github.com/jansel
In https://github.com/pytorch/pytorch/pull/106673 , I created a private API `_debug_get_cache_entry_list` to help pull out cache entries from compiled functions.
Recently, I find that @anijain2305 commented in the code that this API should be revisited, and so I created this PR.
First, this API cannot be removed even if cache entry becomes a first-class python class`torch._C._dynamo.eval_frame._CacheEntry`. The facts that `extra_index` is static, and `get_extra_state` is inline static, make them not accessible elsewhere. This API `_debug_get_cache_entry_list` is the only way for users to get all the cache entries from code.
Second, since the`torch._C._dynamo.eval_frame._CacheEntry` class is a python class, I simplified the C-part code, and remove the necessity of creating a namedtuple for this in the python code.
Third, I also add a small improvement, that if the argument is a function, we can automatically pass its `__code__` to the API.
The above change will slightly change the output, from list of named tuple to list of `torch._C._dynamo.eval_frame._CacheEntry`. I will update the corresponding docs that use this API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108335
Approved by: https://github.com/jansel, https://github.com/anijain2305
Instead of hardcoding a new callback creation using 'convert_frame',
add an attribute to both callbacks that implement 'self cloning with new
backend', so DDPOptimizer can invoke this in a consistent way.
Fixes#107686
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107834
Approved by: https://github.com/ezyang
Previously when we found some input or output mismatch between original args / traced result vs. graph-captured input / output, we would have a pretty sparse error message. (This might be partly due to the urge to reuse the same code for matching both inputs and outputs.)
With this PR we now point out which input or output is problematic, what its type is, and also present the expected types along with descriptions of what they mean. We don't suggest any fixes, but the idea is that it should be evident what went wrong looking at the error message.
Differential Revision: [D48668059](https://our.internmc.facebook.com/intern/diff/D48668059/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107907
Approved by: https://github.com/gmagogsfm
Moves the logic to casting state to match parameters into a hook so that users can choose to enable their hooks before or after the casting has happened.
With this, there is a little bit of redundancy of the id_map building and the check that the param groups are still aligned in length.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106725
Approved by: https://github.com/albanD
Dynamo currently runs the real graph module with real inputs as a way to match the return result of graph module with the eager return type. This is unsafe when graph module is side effectful. In the long term, we will get rid of this step. But in the short term, we just fakify the graph module again and run it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107271
Approved by: https://github.com/ezyang
This PR makes CacheEntry a PyObject. This is prep PR for cache size changes. As CacheEntry is a py object, we can now traverse the linked list in Python and write cache size policies. It was possible to do in C, but Python is just easier to iterate upon. We call convert_frame only when we (re)compile, so a small bump in latency going from C to Python is acceptable here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107405
Approved by: https://github.com/ezyang
ghstack dependencies: #106917, #107117
This PR adds a **same_signature** flag to dynamo.export.
**Motivation:**
In https://github.com/pytorch/pytorch/pull/105679, we experimented on **using dynamo to inspect the UDFs** for cond in eager mode (without torch.compile). This helps us to normalize the inputs (e.g. lifting closure to inputs) and makes higher order operator more robust (e.g. forbid python side effects) and less error-prone in general.
We decided to use dynamo.export (instead of torch.compile) to do the inspection (pointed out by @voznesenskym @zou3519):
- We'd like a **whole-graph capture** for the UDF.
- We'd like the dynamo inspection to be **stateless**. Using torch.compile would require resetting dynamo context before and after the inspection because the compile flags may be different from users' torch.compile. This will clear all dynamo cache.
- We can still implement some **caching** based on the guards.
However, this requires export to be able to handle the case where it cannot always rewrite signature: e.g. closure lifted as input.
This PR makes the rewrite optional.
**Implementation:**
We just put all the code that are related to signature rewriting into a function called rewrite_signature and use a same_signature flag to optionally to the transformation.
**Test Plan:**
existing tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106569
Approved by: https://github.com/ezyang
Previously, you would get an error like
```
Dynamo input and output is a strict subset of traced input/output
```
now you get
```
Cannot export model which references tensors that are neither
buffers/parameters/constants nor are direct inputs. For each tensor, if you'd
like this tensor to be an explicit input, add it as a dummy argument
to the top-level model definition you are exporting; if you would
like its value to be embedded as an exported constant, wrap its access
in a function marked with @assume_constant_result.
G['bulbous_bouffant'], accessed at:
File "test_export.py", line N, in f
return bulbous_bouffant + y
```
This doesn't handle outputs, I'm going to hit that next.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106403
Approved by: https://github.com/tugsbayasgalan
Currently, exporting a model to ONNX with fake tensor mode requires the
user to load data and model within `torch.onnx.enable_fake_mode` context,
but the actual call to `torch.onnx.dynamo_export` is done outside such
context.
With this PR, we enable `torch.onnx.dynamo_export` to be called either
within `torch.onnx.enable_fake_mode` or outside of it. This feature
required changes to the core PyTorch Dynamo, which were greatly
supported by @ezyang
In future steps we will determine which scenario we are going to
support, but for now we can use either to explore different options and
scenarios and asses their pros and cons.
This PR also creates a separate suite of tests for fake mode specific
scenarios (`TestFxToOnnxFakeTensorWithOnnxRuntime`).
It was done separately to decrease the test time, but we
could merge it with the default `TestFxToOnnxWithOnnxRuntime`. The
additional parameters are `load_checkpoint_during_init` and
`export_within_fake_mode`
With the newly added supported of nested export within fake mode, the
following scenarios are now supported:
```python
import torch
with torch.onnx.enable_fake_mode() as fake_context:
fake_args = create_args()
fake_kwargs = create_kwargs()
fake_model = create_model()
fake_model.load_state_dict(torch.load(tmp_checkpoint_file.name))
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
# `torch.onnx.dynamo_export` called WITHIN `torch.onnx.enable_fake_mode`
export_output = torch.onnx.dynamo_export(
fake_model,
*fake_args,
**fake_kwargs,
export_options=export_options,
)
export_output.save("/path/to/model.onnx", model_state_dict=create_model())
```
If we decide to only support scenarios in which `torch._dynamo.export` is called within `FakeTensorMode`, then we can remove `fake_mode` argument from `torch._dynamo.export` as a follow-up task
ps: This PR is mostly Edward's https://github.com/pytorch/pytorch/pull/105468 + unit tests after an offline discussion
ps: https://github.com/pytorch/pytorch/issues/105464 tracks pending tasks/limitations from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105477
Approved by: https://github.com/ezyang, https://github.com/BowenBao
Summary:
We are working toward full model compilation, where when compilation error happens, we just fall back to eager mode rather than error out.
But at the same time, we should fix these issues if they are bugs. We will:
* 1/ log warnings in OSS;
* 2/ log warnings and write them into Scuba in fbcode;
to prevent us from ignoring these issues.
Test Plan: Manual test
Differential Revision: D47506314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105307
Approved by: https://github.com/jansel
issues resolved: https://github.com/pytorch/pytorch/issues/101832
**context**: get torch.compile config for further usage. E.g, the training platform wants to get if model is compiled with cudagraph enabled and trigger further action
**how it is implemented**
* the core logic is backend.get_compiler_config() in torch/_dynamo/eval_frame.py
* for backend='inductor' / _TorchCompileInductorWrapper, we have inductor-specific implementation in get_compiler_config in torch/_inductor/compile_fx.py and torch/__init__.py
**how to use it**: Below is an example.
```
model = DummyModule()
optimized_module = torch.compile(
model, options={"triton.cudagraphs": True}
)
compiler_config = optimized_module.get_compiler_config()
if compiler_config["triton.cudagraphs"]:
pass
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105026
Approved by: https://github.com/yanboliang, https://github.com/jansel