Summary:
This change fixes split_module's interaction with dead code. Previously, if a dead region was split out, split_module would throw an error while attempting to access the outputs of the partition, even though the partition has no outputs.
This change adds a new unit test to cover the dead-code case and changes the output check to allow partitions with no outputs. A split submodule with no outputs now returns None, like a normal Python function.
Unit Test Added:
test_split_module_dead_code
A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
    def forward(self, x):
        output = x * 2  # we want this
        dead_line = x + 2  # this is dead
        return output
```
Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```
After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x); x = None
        return submod_1

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        add = x + 2; x = None
        return None

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        mul = x * 2; x = None
        return mul
```
Submod 2 is correctly extracted
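For reference, a minimal sketch of how such a split can be reproduced (the partitioning callback below is made up for illustration; it is not the exact body of the new test):
```
import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class ModWithDeadCode(torch.nn.Module):
    def forward(self, x):
        output = x * 2  # we want this
        dead_line = x + 2  # this is dead
        return output

mod = ModWithDeadCode()
traced = symbolic_trace(mod)

# Put the dead `add` into its own partition; everything else goes to partition 0,
# so one submodule ends up with no outputs at all.
def split_callback(node):
    return 1 if node.target is operator.add else 0

split = split_module(traced, mod, split_callback)
print(split.code)  # the dead submodule now returns None instead of raising IndexError
```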
Test Plan: Tested with new unit test
Differential Revision: D47196732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104554
Approved by: https://github.com/yf225
Summary: Title. The mapping currently has lots of unused keys because the `or` condition always returns True, but this does not affect correctness.
Test Plan: N/A
Differential Revision: D43579510
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95493
Approved by: https://github.com/Skylion007
Summary: One common cause of jit unscriptability issues is the loss of node type annotations on local names after one or several FX transforms. One way to improve type coverage is to eagerly annotate the type of `getitem` nodes from their parent sequence node. This diff introduces an FX pass to do that.
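For illustration, a minimal sketch of what such a pass can look like (an approximation, not necessarily the exact pass added by this diff):
```
import operator
from typing import get_args, get_origin

import torch.fx

def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """Copy element types from a parent node annotated as Tuple[...] onto its getitem users."""
    for node in graph.nodes:
        if node.target is operator.getitem and node.type is None:
            parent, index = node.args
            parent_type = getattr(parent, "type", None)
            # Only handle fixed-length tuple annotations indexed by a constant int.
            if parent_type is not None and get_origin(parent_type) is tuple and isinstance(index, int):
                element_types = get_args(parent_type)
                if index < len(element_types):
                    node.type = element_types[index]
```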
Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```
Reviewed By: xush6528
Differential Revision: D41749744
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90237
Approved by: https://github.com/xush6528
Summary: Some nodes lost their type annotations during `split_module`, causing the submodules to be un-scriptable. This is because the compiler always infers Tensor type, which is wrong for non-Tensor values. We attempt to infer the type annotation for `getitem` nodes to improve scriptability.
Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```
Differential Revision: D41037819
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88510
Approved by: https://github.com/xush6528
Summary:
{F770932209}
Given the original execution order and the node dependency relationships (note that the same dependencies can admit multiple valid execution orders, i.e. topological orders), after the reunion the new GraphModule's execution order can differ from the original one, which is not what we want.
For example, assume NewLeaf_1 is EmbeddingLookup (the lookup call is awaitable: we keep executing the following nodes rather than waiting for the result until we actually need it), and NewLeaf_4 is the node where we HAVE to have the lookup result in order to interact with NewLeaf_3. NewLeaf_1 launches a lookup kernel and an all2all communication stream to distribute the result to all ranks; in the meantime, we want to keep executing NewLeaf_2 and NewLeaf_3 to avoid pointless waiting. With the new execution order, however, we have to wait for the lookup kernel and the all2all communication to finish because the next node, NewLeaf_4, needs the result; only then can we execute NewLeaf_2, etc. This loses the overlap between computation and the communication stream and hurts QPS a lot.
So while constructing the GraphModule, we have to use the original order rather than the topological order.
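A minimal usage sketch, assuming the knob introduced here is split_module's `keep_original_order` flag (the module and partitioning callback below are made up for illustration):
```
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class TwoBranches(torch.nn.Module):
    def forward(self, x):
        a = x + 1  # think: the awaitable lookup branch
        b = x * 2  # independent work we want to overlap with it
        return a + b

mod = TwoBranches()
traced = symbolic_trace(mod)

# Made-up partitioning: every op goes to its own partition.
counter = 0
def split_callback(node):
    global counter
    counter += 1
    return counter

split = split_module(traced, mod, split_callback, keep_original_order=True)
print(split.code)  # submodule calls follow the original execution order
```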
Test Plan:
Unit test
Not sure how to add tests in FX as there's no TARGETS, so I added it in the TorchRec folder.
Differential Revision: D39567314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85188
Approved by: https://github.com/SherlockNoMad
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73564
While maintaining API backward compatibility, add an optional output parameter to split_module() that returns a mapping from the new qualified names in the modules after the split to the old qualified names in the original module.
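A hedged usage sketch, assuming the new optional parameter is the `qualname_map` dict: pass an empty dict in and read the new-name to old-name mapping back after the split (module and callback below are made up):
```
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

mod = MyModule()
traced = symbolic_trace(mod)

qualname_map = {}
split = split_module(
    traced,
    mod,
    lambda node: 0 if node.op == "call_module" else 1,
    qualname_map=qualname_map,
)
print(qualname_map)  # e.g. {'submod_0.linear': 'linear'}
```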
Test Plan:
1. Added a test (test_split_qualname_mapping) to test_fx_experimental.py to check the returned qualname mapping
```
$ python test_fx_experimental.py
...
Ran 1084 tests in 73.464s
OK (skipped=531, expected failures=4)
```
2. Ask test_fx.py to accept split_module's new signature
```
$ python test_fx.py --accept
```
Reviewed By: jamesr66a
Differential Revision: D34541792
fbshipit-source-id: e8ec7e77ec884e4db7cad0c0593e31861c76e42d
(cherry picked from commit d2e5a95a353ee5fb52cdba065f127489e9df47ae)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71790
If a leaf module is specified, we should treat it as a black box and avoid rewriting it as well.
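The acc_tracer rewriter itself isn't shown here; as a generic illustration of the "leaf module as a black box" idea, this is roughly what the stock torch.fx tracing hook looks like (the class name below is made up):
```
import torch
import torch.fx

class LeafAwareTracer(torch.fx.Tracer):
    """Treat the given module types as opaque: emit a call_module node instead of tracing inside."""

    def __init__(self, leaf_module_types):
        super().__init__()
        self.leaf_module_types = tuple(leaf_module_types)

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, self.leaf_module_types):
            return True
        return super().is_leaf_module(m, module_qualified_name)
```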
Test Plan:
```
buck test caffe2/test:test_fx_acc_tracer
```
with a new unit test.
Reviewed By: jfix71, houseroad, wushirong
Differential Revision: D33731903
fbshipit-source-id: 0560d9e8435b40f30d9b99dc3b2f47d1a04eb38b
(cherry picked from commit 747e9e44ee)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71016
I found out that `split_module` doesn't preserve default values for arguments. In trying to fix that, I noticed that `Graph.placeholder` doesn't make it easy to add a default value when creating a placeholder. This PR addresses both of those issues.
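A minimal sketch, assuming the knob added to `Graph.placeholder` is a `default_value` argument that flows into the generated forward() signature:
```
import torch
import torch.fx

graph = torch.fx.Graph()
# Placeholder with a default, so the generated forward is roughly `def forward(self, x = None)`.
x = graph.placeholder("x", default_value=None)
graph.output(x)

gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm())               # falls back to the default: None
print(gm(torch.ones(2)))  # tensor([1., 1.])
```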
Test Plan: Imported from OSS
Reviewed By: ansley
Differential Revision: D33482218
Pulled By: jamesr66a
fbshipit-source-id: 57ebcdab25d267333fb1034994e08fc1bdb128ee
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65542
Add docstring for torch.fx.passes.split_module that conforms to Google Python Style conventions.
Changed original example to the example from this diff:
https://www.internalfb.com/diff/D24925283 (9734c042b8)
Test Plan:
Ran buck test //caffe2/test:fx. No errors detected
https://pxl.cl/1QCch
Reviewed By: jamesr66a
Differential Revision: D31145694
fbshipit-source-id: 8e54f3b1be3dca1c4d414fdeeab71b9f2b5d9f3e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65933
We use `split_module` to split the input model that we want to const fold into const and non-const subgraphs. Previously we were taking the non-const graph and trying to hack it back into the same signature as the input model. However, this was complex and buggy.
Instead, refactor to just keep using the base split module that contains both const and non-const graphs. This means we:
- Inline the non-const graph into the split module
- Remove the const graph from the module and replace it with a getattr that will be run to insert that attr when we `run_folding`
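A hedged usage sketch of the flow described above, using the `torch.fx.experimental.const_fold` entry points as I understand them (`split_const_subgraphs` and `run_folding`); the example module is made up:
```
import torch
from torch.fx.experimental import const_fold

class ConstFoldable(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        folded = self.weight + 1.0  # depends only on attrs, so it is const-foldable
        return x + folded

mod = ConstFoldable()
folded_mod = const_fold.split_const_subgraphs(mod)
folded_mod.run_folding()  # runs the const subgraph once and stores the result as an attr
print(folded_mod(torch.randn(4, 4)).shape)
```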
Test Plan: Added test coverage to cover newly supported folding, and updated other tests for new strategy.
Reviewed By: yinghai
Differential Revision: D31293307
fbshipit-source-id: 6e283a8c7222cf07b14e30e74dffc8ae5ee8b55f
Summary:
During development it is common practice to put `type: ignore` comments on lines that are correct, but that `mypy` doesn't recognize as such. This often stems from the fact that the `mypy` version in use wasn't able to handle the pattern.
With every new release `mypy` gets better at handling complex code. In addition to fixing all the previously accepted but now failing patterns, we should also revisit all `type: ignore` comments to see if they are still needed. Fortunately, we don't need to do this manually: by adding `warn_unused_ignores = True` to the configuration, `mypy` will error out whenever it encounters a `type: ignore` that is no longer needed.
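For reference, in mypy.ini (or the `[mypy]` section of setup.cfg) the relevant knob looks like this:
```
[mypy]
# Fail when a `type: ignore` comment suppresses nothing.
warn_unused_ignores = True
```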
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60006
Reviewed By: jbschlosser, malfet
Differential Revision: D29133237
Pulled By: albanD
fbshipit-source-id: 41e82edc5cd5affa7ccedad044b59b94dad4425a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56212
The current design doesn't make it easy to use `node.copy()`. Explicitly copy over the node's meta.
Test Plan: Updated `test_subgraph_creation` in `test_fx_experimental`
Reviewed By: jamesr66a
Differential Revision: D27808477
fbshipit-source-id: 7fe7b6428c830307dbd1e395f16fa2774936d3b3