Summary:
1. Adding `input` field to `_adapt_flat_args` function
2. In `process_forward_inputs`, `reorder_kwargs` will now do nothing if no kwargs are provided (previously would error)
3. Pass `args` as input to `_adapt_flat_args`
These changes are made to update the InputAdapter
see more context in D73811508
Test Plan: see D73811508
Differential Revision: D73945419
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152575
Approved by: https://github.com/angelayi
Summary:
Recall that we use "ivals" to track intermediate values of mutations during unflattening. Previously, for each such intermediate value, we would create a hidden shared attribute that would be updated / read by respective submodules.
Unfortunately this scheme doesn't work when some but not all of those submodules are swapped out. This is because the swapped in submodules have no knowledge of these hidden attributes. Thus the submodules that are not swapped out end up reading / updating dangling state.
This PR does away with these hidden attributes. Instead, we directly read the underlying buffer or placeholder that was updated, and update those underlying buffers and placeholders in place. This makes the graphs look much closer to their eager origins.
Test Plan: added some tests, ensured existing tests pass
Differential Revision: D71203469
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149206
Approved by: https://github.com/tugsbayasgalan
Summary: Add support for inputs that no longer exist in `input_fields`, but is not actually used by the original program. In this case, we just give it a dummy input based on the node's metadata.
Test Plan: Verified for S488841
Differential Revision: D69328093
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147238
Approved by: https://github.com/pianpwk
If a user passes in a namedtuple as an input, currently the input TreeSpec looks like: `TreeSpec(type=namedtuple, context=”class_fqn”, children_spec=[*, *])`
The user then saves the program containing this input TreeSpec. But what happens if they load it in a new environment where `class_fqn` now contains an additional field?
This means that the exported program is now expected to take in another input. But since those fields were not used in the original program, users should be able just drop those additional fields and the program will run successfully. This is needed/used in APS where they use unflattener's adapter to adapt the inputs based on the previously saved treespecs.
There are a couple of [solutions](https://docs.google.com/document/d/1V4ZSdy-8PUISWc8RqvGu3DU01BVegJhHHPWqa1Io7Eg/edit?tab=t.0) for how we can address this, but eventually we settled on saving a side table mapping namedtuple types to their list of field names, which can then be accessed by the adapter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145956
Approved by: https://github.com/zhxchen17
When we unflatten, the submodules we generate (`InterpreterModule` or `InterpreterModuleDispatcher`) are not related by type to the original submodules `N`. This makes `isinstance(mod, N)` checks fail. Since we do not have the original types after export, the best we can do is expose a `type_name()` method that carries the original type name, which we do carry in `nn_module_stack` entries.
Differential Revision: [D67526542](https://our.internmc.facebook.com/intern/diff/D67526542/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143664
Approved by: https://github.com/tugsbayasgalan
Combining several fixes to unflatten for bugs revealed by random graph testing.
The fixes target two categories of bugs:
1. Some bugs show up as exponential blowups for largish system of nn modules. These are fixes by converting lists to sets, using caching, or otherwise rewriting to reuse computation more effiicently.
2. Other bugs were due to missing intermediate modules created when attributes such as submodules and buffers are accessed through longish paths before calling the corresponding intermediate modules, or missing attributes such as buffers and constants in submodules corresponding to multiple calls.
Differential Revision: [D66659795](https://our.internmc.facebook.com/intern/diff/D66659795/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142141
Approved by: https://github.com/ydwu4
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
With largish systems of nn modules with buffers, sinking params suffered from some kind of exponential blowup that is easily fixed by using a set instead of a list to keep track of unlifted buffer placeholders.
Test Plan: added random dag test that failed previously
Differential Revision: D66457661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141494
Approved by: https://github.com/angelayi
Handling of nested modules in unflatten had several bugs, which were caught by trying to preserve module call signatures for nested modules.
* A module `k` encountered when calling `k.n()` before `k()` used to become an empty nn module. This caused some information to be dropped when `k()` was eventually called. Relatedly, we would also lose call counts for `k.n()` through different paths (say, when `k()` calls `n()`).
* Deleting call-indexed modules and patching up their call sites was broken for nested modules when creating dispatcher modules, because of silliness when handling their fqns.
An interesting aside is that we used random graph generation for testing some of these changes. A future PR will add the infra to create tests using these random graphs.
Differential Revision: D66192799
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141066
Approved by: https://github.com/angelayi
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Differential Revision: [D65307961](https://our.internmc.facebook.com/intern/diff/D65307961/)
This PR introduces the concept of a "dispatcher" module `n` that carries multiple interpreter modules `n`, `n@1`, `n@2`, etc., each corresponding to a particular call of `n` and thus might carry a different specialized graph. We only do this when we're preserving module call signatures for `n`. The carried modules have the same number and order of calls to `n` appearing in the original module / exported program. In the unflattened module, all those calls go to the "dispatcher" module which internally tracks how many calls have been made so far and invokes the corresponding interpreter module. We reset this tracking after a successful or unsuccessful run of the unflattened module.
Overall this makes swapping easier when module call signatures are preserved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139439
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #139438
# Why?
I want the following code to work.
minimal repro:
```
class M(torch.nn.Module):
def forward(self, dilate_flag):
return dilate_flag.item()
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
model = M().cuda()
ep = torch.export.export(model, input1, strict=True)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input1)
```
error: AssertionError: Encountered an unsupported object of type <class 'torch.SymBool'> while writing the metadata for exported program
second error will be handled by https://github.com/pytorch/pytorch/pull/138760
# Motivation
I could technically bypass it with a torch.int tensor. However, it doesn't work with torch.cond. I want the following to work. It would also require https://github.com/pytorch/pytorch/pull/138760 for aot compile to work.
```
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dilate_flag = 0
def forward(self, dilate_flag):
self.dilate_flag = dilate_flag.item()
def true_fn(dilate_flag):
return dilate_flag.clone()
def false_fn(dilate_flag):
return dilate_flag.clone()
torch.cond(
self.dilate_flag,
true_fn,
false_fn,
(dilate_flag,),
)
return self.dilate_flag
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
input2 = (torch.tensor([0], dtype=torch.bool, device="cuda"),)
inputs = (input1, input2)
model = M().cuda()
for input in inputs:
expected_output = model(*input)
ep = torch.export.export(model, input, strict=False)
path = torch._inductor.aot_compile(ep.module(), input)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input)
assert (
expected_output == actual_output
), f"henry they are not equal {expected_output} != {actual_output}"
```
Differential Revision: D64867504
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138765
Approved by: https://github.com/ydwu4
Summary:
Unflatten was broken for HOPs for a couple of reasons:
(1) we didn't expect `get_attr` nodes in the exported program, but they can occur to hold graph arguments to HOPs; such attributes must be moved from the exported program to the corresponding unflattened submodule containing the HOP call.
(2) we don't record metadata for graph arguments on serialization (there's nothing to hold it in our schema), and accordingly the `get_attr` nodes we create on deserialization don't have `nn_module_stack` metadata, which obviously wrecks unflatten.
Test Plan: added a couple of tests
Differential Revision: D65013647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138978
Approved by: https://github.com/zhxchen17
As called out in https://github.com/pytorch/pytorch/pull/137999, preserving signatures of multiple calls when buffer mutations are present was NYI. The main problem was that intermediate values of buffers were not tracked, so couldn't be propagated statefully between multiple calls (i.e., they would need to be explicitly passed around, defeating the unlifting needed for preserving signatures).
This PR fixes this situation, by introducing module attributes that carry the necessary intermediate values of buffer mutations. In general, a buffer mutation can have several intermediate values it depends on recursively, even other buffers. So rather than tying an intermediate value with a particular buffer, we tie it with the submodules that create and read it. We install an attribute on all modules that create or read a particular intermediate value, sharing the same initial storage (i.e., initialized with the same empty tensor). For the module that creates this intermediate value, we copy the value into the corresponding attribute; and for the modules that read it, we read the corresponding attribute instead.
Another complication that needed to be addressed was that a `run_decompositions` following an `export_for_training` was not preserving module call graphs, which is needed for unflattening and, in particular, used when remapping inputs. Fortunately some existing metadata already tracks provenance of nodes, which we could use to update a module call graph after functionalization / decomposition.
Differential Revision: D64806175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138669
Approved by: https://github.com/tugsbayasgalan
Previously we would error when trying to preserve the call signature for a module when it was called multiple times. This PR can now do this without erroring. The fix is to propagate call indices in a few more places.
Note that while this works in the presence of params, buffers, and tensor constants, preserving call signatures for multiple calls to a module when buffers are mutated is not supported yet. This is future work. The main problem is that we do not have enough metadata to `copy_` mutated buffers at the end of each call to a module, so the next call can read those buffers at the beginning. Making this work will likely need some explicit tracking of intermediate values of mutated buffers when collecting metadata during functionalization in export.
Note also that we stop short of creating a single graph out of multiple graphs: that is still future work. So the unflattened module will still have different targets `n`, `n@1`, `n@2`, etc. for each call when we ask the module call signature of `n` to be preserved. However it is way easier to swap all of these targets with a replacement that behaves similar to the original, because all of these calls will respect the original module call signature. (In particular, any constant inputs will be carried by the calls.)
Differential Revision: D64406945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137999
Approved by: https://github.com/tugsbayasgalan
We use nn_module_stack in unflatten to recognize when module calls begin and end. However the current format is not sufficient to detect module call boundaries when we have successive calls to the same module, because the successive instructions (end of one call, begin of next call) have the same nn_module_stack. This causes us to effectively "unroll" successive calls to a single call. This can cause problems when preserving module call signatures because the outputs of the successive calls might be concatenated in the single call.
Previously we introduced the concept of a "call index" to generate multiple graphs when unflattening, one per call. This PR pushes this concept into nn_module_stack itself. In particular, the keys of nn_module_stack now go from `key` to `key@call_index`. (In a previous attempt, https://github.com/pytorch/pytorch/pull/137457, instead values in nn_module_stack go from (fqn, type) to (fqn, type, call_index), which is BC-breaking.)
Note that we still do not have the ability to preserve module call signatures for multiple calls to the same module. But now instead of randomly crashing we give a proper error. OTOH when not preserving module call signatures we simply generate multiple calls, each with its own graph, possibly deduplicated, matching what we would do for non-successive calls.
Test Plan: Like D64014936
Differential Revision: D64136277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137646
Approved by: https://github.com/angelayi
Added an optimization pass to the swap function which removes extraneous pytrees. Currently it removes the pytree flatten/unflatten calls between modules in very specific scenarios (all the inputs of one module go into the other).
Future work can be to remove the input pytree.flatten if the inputs go directly into an unflatten, and output pytree unflatten if the outputs are directly from a pytree.flatten.
Differential Revision: [D62879820](https://our.internmc.facebook.com/intern/diff/D62879820)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136191
Approved by: https://github.com/avikchaudhuri
Summary: In unflatten, when we generate module calls when their signature has been preserved, we do not pass the original constant args. This can cause strange effects, e.g., if the module is swapped out with itself, we may suddenly go down a different path than the original, or even crash.
Test Plan: added a test
Reviewed By: angelayi
Differential Revision: D63913750
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137363
Approved by: https://github.com/angelayi
When we populate unlifted graph module, we actually only "unlift" constant tensor inputs which is problematic because export de-duplicates aliasing constants. As a result, we only register one constant instead of two constants. This PR fixes that by querying ep.constants table instead of ep.graph_signature.lifted_tensor_constants.
Differential Revision: [D63743111](https://our.internmc.facebook.com/intern/diff/D63743111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137162
Approved by: https://github.com/pianpwk
Previously we were making a fairly restrictive assumption when unflattening an exported program: for any submodule, we would assert that the graph of every call to that submodule must be the same. This assertion is load-bearing, i.e., if we simply remove the assertion then we can get incorrect results, as shown by the following example.
```
class N(torch.nn.Module):
def forward(self, x, b):
if b:
return x + 1
else:
return x + 2
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.n = N()
def forward(self, x):
x0 = x + 3
x1 = self.n(x0, True)
x2 = x1 + 4
x3 = self.n(x2, False)
return x3 + 5
m = M()
inp = (torch.ones(1),)
print(m(*inp)) # tensor([16.])
ep = torch.export.export(m, inp)
print(ep.module()(*inp)) # tensor([16.])
unflattened = torch.export.unflatten(ep)
print(unflattened(*inp)) # tensor([15.])
```
However, this goes against the spirit of specializing graphs when exporting: we should *expect* that for every call to a submodule we *might* generate a different graph. The goal of this PR is to fix unflattening to handle multiple specialized graphs corresponding to multiple calls to the same submodule.
The idea is simple: for every call to a child module `foo`, we will create potentially different child modules `foo`, `foo@1`, `foo@2`, etc. and use those names as targets in `callmodule` instructions in the parent graph. An immediate consequence of this is that the list of fqns in an unflattened module may not be the same as an exported module. Note that all these variants share the same parameters / buffers, so that multiple calls to the same submodule can share state as expected.
However, as described so far this scheme may end up with needlessly too many submodules. Thus, between calls to the same submodule, if graphs are equal then we optimize away the extra submodules and reuse call names as much as possible. Moreover, when submodules are shared across fqns, we also try to de-duplicate graphs corresponding to their calls as much as possible. Note that no matter what, information about which submodule was called is still preserved, so that if a submodule has to be swapped with another, one can still find all calls to the former submodule and replace them with calls to the latter.
A note on the choice of naming scheme for call names: instead of generating "sibling" modules `foo@1`, `foo@2`, etc. for `foo`, we had considered generating "children" modules `foo._1`, `foo._2`, etc. of `foo`. However this can cause spurious cycles when de-duplicating graphs. E.g., suppose that `foo` is an alias for `bar._1` and `foo._1` is an alias for `bar`, then we must either introduce a cycle or drop the opportunity to optimize. Another idea would be to make `foo` a dummy module that contains `foo._0` corresponding to the first call, but this necessitates too many changes to existing tests and hurts the common case.
Differential Revision: D63642479
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137013
Approved by: https://github.com/pianpwk