As titled. Before the PR, after we split then inline_, there will be getitem calls in the graph while the original graph module doesn't have them. This PR removes the additional get_item calls by inlining.
Test Plan:
Added new test cases for graphs that return multiple outputs and takes multiple inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119913
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732, #119736, #119810
This pr is the 1/N pr of transforming the global state mutating ops such as torch._C.set_grad_enabled calls in pre-dispatch graph into a higher order op so that the graph becomes more functional. We make use of split_module to help us do the transformation.
This pr preserves the node.name in original module by adding a new kwarg `keep_original_node_name` to split_module.
For a graph looks like this:
```python
def forward(self, arg_0):
arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
_set_grad_enabled = torch._C._set_grad_enabled(False)
add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
sub = torch.ops.aten.sub.Tensor(add_1, 1)
return pytree.tree_unflatten((add_1, sub), self._out_spec)
```
Before the change, split graph returns the following graphs and subgraphs (notice the change from `add` -> `add_tensor`, `sin` -> `sin_default`:
```python
def forward(self, arg_0):
arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
submod_0 = self.submod_0(arg0_1); arg0_1 = None
submod_1 = self.submod_1(submod_0); submod_0 = None
submod_2 = self.submod_2(submod_1)
return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)
# submod_0
def forward(self, arg0_1):
add_tensor = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
sin_default = torch.ops.aten.sin.default(add_tensor); add_tensor = None
sum_default = torch.ops.aten.sum.default(sin_default); sin_default = None
return sum_default
# submod_1
def forward(self, sum_1):
_set_grad_enabled = torch._C._set_grad_enabled(False)
add_tensor = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
return add_tensor
# submod_2
def forward(self, add_1):
_set_grad_enabled = torch._C._set_grad_enabled(True)
sub_tensor = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
return sub_tensor
""")
```
After the change, the test produce the following graph, all the node names in original graph module are preserved in sub_modules.
```python
def forward(self, arg_0):
sub, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
submod_0 = self.submod_0(sub); sub = None
submod_1 = self.submod_1(submod_0); submod_0 = None
submod_2 = self.submod_2(submod_1)
return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)
# submod_0
def forward(self, arg0_1):
add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
return sum_1
# submod_1
def forward(self, sum_1):
_set_grad_enabled = torch._C._set_grad_enabled(False)
add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
return add_1
# submod_2
def forward(self, add_1):
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
return sub
```
Note that currently, we call split_module on the graph after pre-dispatch aot. The difference is even larger if we `split_module` the graph module produced by dynamo, where all the original variables names in user program are preserved after dynamo but lost after `split_module` without this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119732
Approved by: https://github.com/tugsbayasgalan
Summary:
X-link: https://github.com/pytorch/executorch/pull/1817
Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` make the final state dict correctly align with the original module's state dict.
This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.
As a side effect, this diff tightened up quite a few sloppy behaviors around state dict handling:
- Tensor attributes were getting promoted to be buffers—bad!
- Tracing through a module not in the children of the root module would add its parameters/buffers to the state dict—bad!
This behavior is unlikely to show up in user code since the model would be totally broken, but did show up in a bunch of tests.
#buildmore
Test Plan:
unit tests
sandcastle
Differential Revision: D53340041
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118969
Approved by: https://github.com/guangy10, https://github.com/huydhn, https://github.com/titaiwangms
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
This PR rewrites two paths to use the newly-added keypaths API in pytree:
First: we were hand-rolling a tree_map during fakification because we wanted to track sources. This PR uses keypaths instead, which can do the same thing without needing custom code.
Second: our constraint error formatting was referencing placeholder names in error messages. These placeholder names are not otherwise user-visible, so they are super confusing to users (e.g. "which input does arg1_3 correspond to?"). This diff uses the `keystr` API to format the error message.
This necessitated some small refactors—generating the keystr is expensive so doing it in an f-string was very bad.
It can also be further improved—we can inspect the signature so that instead of `*args[0]` we can give people the actual argument name, which would be the ideal UX. But leaving that for later.
Differential Revision: [D53139358](https://our.internmc.facebook.com/intern/diff/D53139358/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118609
Approved by: https://github.com/zhxchen17
ghstack dependencies: #118607, #118608
Summary: While turning on .module() for all the export tests, I uncovered some bugs with .module() and while fixing them I ended up rewriting some of the code... Some of the bugs were:
* bad kwargs support on the unlifted module
* no support for user input mutations
* (at the commit hash i was working off of) no support for custom objects
* there were no tests on unlifting weights from cond/map submodules
Test Plan: CI
Differential Revision: D53075380
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118272
Approved by: https://github.com/suo
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Previous discussion: https://github.com/pytorch/pytorch/pull/109476
In this PR, I made following additions to the original PR:
1) Unlifted graph module now runs the runtime assertions in its' forward call.
2) When we retrace, we make sure we run the assertions to make sure user is tracing the module with correct inputs with respect to the assumptions we made during first tracing. The way I do is that I create new graph module type with modified call method. And the runtime assertions happen under torchdynamo.disable so that it is just run in eager directly. The reason is we don't this to be traced part of the graph.
3) Both ep.module and capture_pre_autograd now returns _UnliftedGraphModule.
Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222
Approved by: https://github.com/zhxchen17
Custom classes that are serialized with pytree are serialized by default with `f”{class.__module__}.{class.__name__}”`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program will be unable to be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.
Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536
Serializing to json is more stable, and renamed the API:
```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```
If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```
We will serialize to the following dataclass, and then json.dump this it to string.
```
class TreeSpec
type: Optional[str] # a string name of the type. null for the case of a LeafSpec
context: Optional[Any] # optional, a json dumpable format of the context
children_specs: List[TreeSpec],
}
```
If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed.
## Motivation:
HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable.
This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow princinpled extensibility, we think extending pytree may be a good option.
## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](https://github.com/google/jax/issues/2371#issuecomment-805361566), which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export.
We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export.
Also added some tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106160
Approved by: https://github.com/zhxchen17