Previously, `node.meta["nn_module_stack"]` had type `Dict[str, Tuple[str, class]]` when exported, and later `Dict[str, Tuple[str, str]]` after de/serialization. This PR changes it to consistently be `Dict[str, Tuple[str, str]]` for round-trippability, i.e.
```
{..., 'L__self___conv': ('conv', 'torch.nn.modules.conv.Conv2d')}
```
`source_fn_stack` is left untouched in this PR.
Note: the `Union[type, str]` type annotations in ONNX exist because ONNX goes through both `export.export()` and `_dynamo.export()` (which still has the original `Dict[str, Tuple[str, class]]` format). `nn_module_stack` from `export.export()` should consistently have the new format, and we verify/test for that in `_trace.py`.
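For reference, a minimal sketch of how the new string-valued format can be inspected on an exported graph (the module and input shapes here are illustrative, not from this PR):
```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

ep = export(M(), (torch.randn(1, 3, 8, 8),))
for node in ep.graph.nodes:
    # Values are now (path, fully qualified class name string), e.g.
    # ('conv', 'torch.nn.modules.conv.Conv2d'), both before and after serialization.
    print(node.name, node.meta.get("nn_module_stack"))
```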
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123308
Approved by: https://github.com/zhxchen17, https://github.com/thiagocrepaldi
Fixes #118795
This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.
**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
import torch
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization

class MyParametrization(nn.Module):
    def forward(self, X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight
# and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```
**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.
**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce it into its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom (see the sketch after the fix bullets below). To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
* This cache was originally introduced in #100128.
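To illustrate the caching problem described above, a minimal sketch (not Dynamo internals; the module and parametrization are made up) showing that every access to a parametrized attribute produces a fresh tensor, so any cache keyed on the tensor object itself will miss on the second lookup:
```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization

class Negate(nn.Module):
    def forward(self, X):
        return -X

m = nn.Linear(4, 4)
register_parametrization(m, "weight", Negate())

w1 = m.weight  # runs Negate.forward on the original weight
w2 = m.weight  # runs it again, producing a different tensor object
print(w1 is w2)  # False -- looking up the second tensor under the first
                 # tensor's key is what produced the KeyError symptom
```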
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121041
Approved by: https://github.com/anijain2305
`unimplemented` is a function that raises an error, so
`raise unimplemented(...)` never reaches the `raise`.
Another related issue is that `raise unimplemented(...) from e`
doesn't attach the exception cause correctly. I fix this by adding
a `from_exc` argument to `unimplemented`.
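A simplified sketch of the pattern (the exception type and helper below are stand-ins, not the actual Dynamo code):
```python
class Unsupported(RuntimeError):
    pass

def unimplemented(msg: str, *, from_exc=None):
    # Raises unconditionally, so wrapping the call in `raise ...` is dead code.
    if from_exc is not None:
        raise Unsupported(msg) from from_exc
    raise Unsupported(msg)

try:
    try:
        int("not a number")
    except ValueError as e:
        # Before: `raise unimplemented(msg) from e` never attached `e` as the
        # cause, because unimplemented() raised before the outer raise ran.
        unimplemented("int() conversion not supported here", from_exc=e)
except Unsupported as err:
    print(type(err.__cause__).__name__)  # ValueError -- the cause is attached
```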
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122136
Approved by: https://github.com/lezcano
Summary:
It seems `kwargs` is already supported in `_infer_argument`, so we don't need the extra assertion `len(kwargs) == 0`.
This change ensures compatibility with `torch.compile()` for LazyModules called with kwargs inputs, preventing graph breaks.
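An illustrative sketch of the case this unblocks (exact graph-break behavior depends on the PyTorch version; shapes and backend are arbitrary):
```python
import torch
import torch.nn as nn

m = nn.LazyLinear(out_features=8)

@torch.compile(backend="eager")
def run(x):
    return m(input=x)  # keyword call reaches the lazy-initialization path

print(run(torch.randn(4, 16)).shape)  # torch.Size([4, 8])
```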
Test Plan: Unit tests and CI
Differential Revision: D53558778
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119445
Approved by: https://github.com/yanboliang
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
- The only remaining one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on custom functions rather than torch functions. I'll defer this to a separate decorator refactor PR.
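For context, the decorator in question is typically used like this (a usage sketch based on the documented public API; not part of this refactor):
```python
import torch
import torch._dynamo

# Force Dynamo to exclude a function from the graph and graph-break on it.
torch._dynamo.disallow_in_graph(torch.sub)

@torch.compile(backend="eager")
def fn(a, b):
    return torch.sub(a, b)  # Dynamo now graph-breaks around this call

print(fn(torch.ones(3), torch.ones(3)))  # tensor([0., 0., 0.])
```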
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added a test checking iadd side effects; this was missing from our unit test coverage (a small sketch of the case follows this list).
4. Fixed two incorrect sources: DelayGraphBreakVariable and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the call site of `var_getattr`.
5. Fixed a bug in in-place adding for lists: it would set the resulting VariableTracker's source to `None`, which would utilize a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when allow_cache=`False`, so that during side effect replay, the mutated var is correctly updated.
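A small sketch of the iadd side-effect case mentioned in item 3 (backend and shapes are illustrative):
```python
import torch

def fn(x, lst):
    lst += [x + 1]  # in-place add on an input list
    return x * 2

compiled = torch.compile(fn, backend="eager")
data = []
y = compiled(torch.ones(2), data)
print(len(data))  # 1 -- the mutation inside the compiled frame is replayed
                  # onto the caller's list
```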
In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Remove setting the source in builder (it's set at the top level after a var is returned)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/113041
In the case where we have an object represented as an UnspecializedNNModuleVariable, the source of an attribute on that object is `AttrSource(base=NotNNModuleSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq'))), member='b')`. This causes dynamo to add an extra attribute as it doesn't go to this [`register_attr` step](eddce3c054/torch/_dynamo/variables/builder.py (L955-L962)).
However if we have an object represented as a UserDefinedObjectVariable, the source of an attribute on that object is `AttrSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq')), member='b')`.
It seems that UnspecializedNNModuleVariables should behave in the same way as UserDefinedObjectVariables, but the source in these two cases is different. So, I removed the part that changes the source for UnspecializedNNModuleVariables, and it seems to work! And CI is green (+ reduced graph breaks).
```
    def test_unspecialized_nnmodule(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.tensor(1.0)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.a

        def forward_hook(
            module: torch.nn.Module, inputs, output
        ) -> torch.Tensor:
            return 2 * output

        seq = torch.nn.Sequential(TestModule()).eval()
        seq.b = torch.tensor(2)
        handle = seq.register_forward_hook(forward_hook)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = seq

            def forward(self, x):
                # self.seq.b has source: AttrSource(base=NotNNModuleSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq'))), member='b')
                return self.seq(x) + self.seq.b

        inp = (torch.randn(2, 8),)
        ep = export(M(), inp)
```
```
    def test_user_defined_var(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.tensor(1.0)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.a

        class UserDefined:
            def __init__(self):
                self.test_module = TestModule()
                self.b = torch.tensor(2)

            def __call__(self, x):
                return self.test_module(x)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = UserDefined()

            def forward(self, x):
                # self.seq.b has source: AttrSource(base=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='seq')), member='b')
                return self.seq(x) + self.seq.b

        inp = (torch.randn(2, 8),)
        ep = export(M(), inp)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113852
Approved by: https://github.com/yanboliang
In my work on making guards installed eagerly (look up the stack), I found that our checkpoint/restore mechanism is very broken. There is lots of state (especially in shape_env) which we don't checkpoint and restore properly. We also have lots of mutable state on variable trackers already which is not checkpointed/restored. (See other PRs in this stack for some spot fixes.)
Since we wanted to get rid of this anyway for making VariableTracker mutable, I figured I would just switch to restarting analysis.
For other usages of copy_graphstate/restore_graphstate:
1) Many usages were pointless and not needed; these are removed in PRs below this one.
2) Some other usage (similar to this one) is removed in PRs above this.
3) The tricky one I am not handling is higher_order_ops, which uses checkpoint/restore a lot. There might be some cases there where this speculate/restart trick won't work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112902
Approved by: https://github.com/voznesenskym
The strategy for supporting functools partials is relatively straightforward.
There are 2 cases we need to support:
**1) Functools partials as input**
In this case, we are first seeing the functools partial and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. As this is the first time we are seeing these objects (as it is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:
- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]
A major benefit of this structure is that it very elegantly matches the args to `call_function`.
We then compose a FunctoolsPartialVariable from the VariableTrackers made above.
**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).
For both (1) and (2) - we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable
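An illustrative example of case (1), a functools partial passed in as an input (the function name, values, and backend are made up for the sketch):
```python
import functools
import torch

def scale_and_add(x, scale, bias):
    return x * scale + bias

partial_fn = functools.partial(scale_and_add, scale=2.0, bias=1.0)

@torch.compile(backend="eager")
def fn(x, f):
    return f(x)

print(fn(torch.ones(3), partial_fn))  # tensor([3., 3., 3.])
```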
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
Summary:
Original commit changeset: 33650f7cb0fb
Original Phabricator Diff: D48833682
Test Plan: See T162942232 for how we figured out that this diff caused significant numeric difference.
Reviewed By: voznesenskym
Differential Revision: D49082219
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108823
Approved by: https://github.com/xw285cornell
Summary: Currently the node metadata "nn_module_stack" is only being used by export. For some exported models, we still want to retain nn_module_stack for unspecialized modules for various purposes. This diff adds a path to also record nn_module_stack when an unspecialized module has a source available.
Test Plan: test_export_nn_module_stack_patched_module
Differential Revision: D48841193
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108281
Approved by: https://github.com/yanboliang, https://github.com/tugsbayasgalan
Before the PR, when running `super(MyConv1d, self).forward` or `super(MyConvTranspose, self).forward`, dynamo would create a graph break when executing NNModuleVariable.call_method and raise an unimplemented error for name=_conv_forward / _output_padding. See the issue for full detail: https://github.com/pytorch/pytorch/issues/101155
After the PR, for torch.nn.conv modules with function name _conv_forward / _output_padding, we inline the function with tx.inline_user_function_return.
Code refactor: added NNModuleVariable._inline_user_function_return_helper to consolidate tx.inline_user_function_return into one place to keep the code DRY. After the refactor, there are 2 unconsolidated inline_user_function_return calls with different ```fn``` and ```source``` logic; the code is still DRY. For local testing, they are covered by test_modulelist, test_moduledict, test_conv_call_super_forward_directly and test_conv_transpose_call_super_forward_directly in test_modules.py.
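An illustrative repro of the pattern being fixed (shapes and backend are arbitrary; `fullgraph=True` assumes a build that includes this change):
```python
import torch
import torch.nn as nn

class MyConv1d(nn.Conv1d):
    def forward(self, x):
        # Previously graph-broke when Dynamo hit Conv1d.forward -> self._conv_forward.
        return super().forward(x)

compiled = torch.compile(MyConv1d(3, 8, 3), backend="eager", fullgraph=True)
print(compiled(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 8, 14])
```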
Differential Revision: [D46494460](https://our.internmc.facebook.com/intern/diff/D46494460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102509
Approved by: https://github.com/yanboliang
Opening this so I can discuss with @albanD
I built a proof of concept of an in place API for an nn.Module that allows us to save and load a torch.compiled model with no issues https://github.com/msaroufim/mlsys-experiments/blob/main/save-compiled-model.py
So users can run `model.compile()` and then run `torch.save(model, "model.pt")` and `torch.load("model.pt")` with no issues, unlike the rather strange current suggestion we give to users, which is `opt_mod = torch.compile(mod); torch.save(mod, "model.pt")`
Right now I'm trying to extend this to work for nn.modules more generally
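A minimal sketch of the intended save/load workflow, assuming the in-place `nn.Module.compile()` API described above is available (`weights_only=False` is shown because newer `torch.load` defaults refuse to unpickle a full module):
```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
model.compile()                    # in place, unlike opt_mod = torch.compile(model)
torch.save(model, "model.pt")
loaded = torch.load("model.pt", weights_only=False)
print(loaded(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```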
TODO: Failing tests
* [x] torch.jit.load -> the issue was because of aliasing `__call__` to `_call_impl`; `_call_impl` used to be skipped but now it no longer is, so I expanded the skip check. I added an explicit `torch.jit.load()` test now, which @davidberard98 suggested
* [x] functorch seems to be a flake - ran locally and it worked `pytest functorch/test_eager_transforms.py`
* [x] a test infra flake - `test_testing.py::TestImports::test_no_mutate_global_logging_on_import_path_functorch`
* [x] It seems like I broke inlining in dynamo, though: `python -m pytest test/dynamo/test_dynamic_shapes.py -k test_issue175`. Chatted with Voz about it but was still not entirely sure how to fix it - found a workaround after chatting with @yanboliang
* [x] `pytest test/dynamo/test_modules.py`, `test/dynamo/test_dynamic_shapes` and `test/dynamo/test_misc.py` seem to be failing in CI, but trying them out locally they all pass with 0 failures
* [x] `pytest test/profiler/test_profiler_tree.py` - these tests have ProfilerTrees explicitly printed and will now break if `__call__` is not in the tree - ran with `EXPECT_ACCEPT=1`
* [x] `pytest test/test_torch.py::TestTorch::test_typed_storage_deprecation_warning` a flake, ran this locally and it works fine
* [x] I reverted my changes to `_dynamo/nn_module.py` since it looks like @wconstab is now directly handling `_call_impl` there, but this was triggering infinite inlining, which crashed
* [x] Tried overriding `__call__` instead; Python doesn't like this, though: https://github.com/pytorch/pytorch/pull/97565#issuecomment-1524570439
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97565
Approved by: https://github.com/aaronenyeshi, https://github.com/albanD, https://github.com/voznesenskym
**TL;DR**: This PR fixes handling for lazy modules where `cls_to_become is None`. In those cases, we should leave the type of the lazy module as the old value.
**Details**:
Lazy modules are intended to be initialized at execution; some of them are also supposed to switch to a different type after they have been initialized. However, not all are supposed to switch; see this logic from `nn/modules/lazy.py`
```python
def _infer_parameters(self, ...):
    ...
    if module.cls_to_become is not None:
        module.__class__ = module.cls_to_become
```
i.e., we should leave the module type as the old value if `module.cls_to_become is None`. This PR updates dynamo's handling to match this behavior.
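A hypothetical lazy module of the kind this PR targets, keeping `cls_to_become = None` so its type is left unchanged after initialization (this is a sketch, not the added unit test):
```python
import torch
import torch.nn as nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter

class LazyScale(LazyModuleMixin, nn.Module):
    cls_to_become = None  # stay as LazyScale after initialization

    def __init__(self):
        super().__init__()
        self.scale = UninitializedParameter()

    def initialize_parameters(self, x):
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.scale.materialize((x.shape[-1],))
                self.scale.fill_(1.0)

    def forward(self, x):
        return x * self.scale

m = LazyScale()
out = torch.compile(m, backend="eager")(torch.randn(2, 4))
print(type(m).__name__)  # LazyScale -- the module keeps its original type
```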
Test `test_lazy_module_no_cls_to_become` added to `test/dynamo/test_module.py`.
Differential Revision: [D45253698](https://our.internmc.facebook.com/intern/diff/D45253698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99943
Approved by: https://github.com/jansel
Before this PR, if users call ```Conv2d(x)```, dynamo handles it well (no graph break) and puts a ```call_module``` op in the FX graph. However, if users explicitly call ```Conv2d.forward(x)``` in another ```forward``` function, the inlining would fail (causing a graph break). This PR fixes that issue by translating the explicit ```Conv2d.forward(x)``` to ```Conv2d(x)```.
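An illustrative sketch of the now-supported pattern (module sizes and backend are arbitrary; `fullgraph=True` assumes a build with this fix):
```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv.forward(x)  # explicit .forward() call instead of self.conv(x)

compiled = torch.compile(M(), backend="eager", fullgraph=True)
print(compiled(torch.randn(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 14, 14])
```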
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99015
Approved by: https://github.com/jansel, https://github.com/wconstab
Allowed modules are stuck into dynamo's fx graph as call_module
nodes, without dynamo doing any tracing of the module. This means
during AOT trace time, hooks will fire during tracing when the
call_module is executed, but the hooks themselves will disappear
after that and not be present in the compiled program.
(worse, if they performed any tensor operations, those would get
traced so you could end up with part of the hook's functionality).
To circumvent this, there are two options for 'allowed modules' with hooks.
1) don't treat them as 'allowed' - trace into them
2) graph-break, so the module is no longer part of the dynamo trace at all
(1) will fail for users that opted into allowed modules because they know
their module has problems being traced by dynamo.
(2) causes graph breaks on common modules such as nn.Linear, just because they
are marked as 'allowed'.
It would help matters if we could differentiate between types of allowed modules
(A) allowed to avoid overheads - used for common ops like nn.Linear
(B) allowed to avoid dynamo graphbreaks caused by unsupported code
Ideally, we'd use method (1) for group (A) and (2) for (B).
For now, graph-break on all cases of allowed modules.
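An illustrative sketch of the scenario (backend is arbitrary): a forward hook on an 'allowed' module such as nn.Linear. With this PR, Dynamo graph-breaks at the module, so the hook still runs instead of silently disappearing from the compiled program:
```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
lin.register_forward_hook(lambda mod, inp, out: out * 2)

@torch.compile(backend="eager")
def fn(x):
    return lin(x)

print(fn(torch.randn(2, 4)).shape)  # torch.Size([2, 4]); the hook's doubling is applied
```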
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97184
Approved by: https://github.com/jansel
This fixes a regression added in the following PR to graph-break on allowed modules with hooks, but has its own problems.
- Following #97184, 'allowed modules' with hooks graph-break, and lazy modules
are allowed. (Should we just make lazy modules not allowed?)
- graph-breaks at lazy modules fail the lazy module unit tests which assert no graphbreaks
- this PR attempts to always 'initialize' lazy modules before tracing/calling into their __call__,
and initializing a lazy module should delete all its hooks after firing them once, making
the above issue go away
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98516
Approved by: https://github.com/yanboliang, https://github.com/jansel