Summary:
## Context
When `const_fold.split_const_subgraphs` sees a `call_module` node whose target is a GraphModule, the existing implementation can mark the node as const-foldable when it shouldn't.
For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contains impure ops inside.
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
When the `submodule` graph is fed to `const_fold.split_const_subgraphs` on its own, it comes out unmodified since `randn` is impure.
But when `submodule` is called from a `parent` graph and `parent` is fed to `const_fold.split_const_subgraphs`, it comes out folded.
```
parent after fold graph():
%_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
return (_fx_const_folded_attrs,)
```
This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through, leaving the `call_module` node marked as pure.
## Fix
We can update the `fx.node.Node.is_impure` function to check the ops inside a `call_module` node with an additional `subgraph_has_impure_ops` check (see the sketch after this list):
- if a `call_module` node calls a GraphModule,
- check whether any of its `call_function` nodes are impure ops,
- recursively check any `call_module` nodes that call a GraphModule.
If the called subgraph has impure ops, `is_impure` returns True.
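A minimal sketch of the check described above (illustrative, not the exact upstream implementation), assuming `torch.fx` and the existing `Node.is_impure` semantics:
```
import torch
import torch.fx


def subgraph_has_impure_ops(module: torch.nn.Module) -> bool:
    # Return True if a GraphModule (or any GraphModule it calls) contains impure ops.
    if not isinstance(module, torch.fx.GraphModule):
        return False
    for node in module.graph.nodes:
        # An impure call_function (e.g. aten.randn) makes the whole subgraph impure.
        if node.op == "call_function" and node.is_impure():
            return True
        # Recurse into nested GraphModules reached via call_module.
        if node.op == "call_module" and subgraph_has_impure_ops(
            module.get_submodule(node.target)
        ):
            return True
    return False
```
With such a helper, `is_impure` can return True for a `call_module` node whose target GraphModule contains impure ops, so `const_fold.split_const_subgraphs` no longer folds the `parent` graph above.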
Test Plan: added tests to test_fx_const_fold.py
Differential Revision: D85798483
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609
Approved by: https://github.com/blaine-rister
This PR is part of a series attempting to re-submit #134592 as smaller PRs.
In fx tests:
- Add and use a common `raise_on_run_directly` method for when a user directly runs a test file that should not be run this way. Print the file the user should have run instead (a sketch follows below).
- Raise a RuntimeError on tests which have been disabled (not run).
- Remove any remaining uses of `unittest.main()`.
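A hypothetical sketch of such a helper (the name, location, and message are illustrative, not the exact upstream API):
```
def raise_on_run_directly(file_to_run: str) -> None:
    # Fail fast when a support file is executed directly and point the user
    # at the test file they should run instead.
    raise RuntimeError(
        "This file is not meant to be run directly. "
        f"Run it via: python {file_to_run}"
    )


if __name__ == "__main__":
    raise_on_run_directly("test/test_fx.py")
```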
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715
Approved by: https://github.com/Skylion007
Applies some more harmless pyupgrade fixes. This one gets rid of deprecated aliases in unit tests and upgrades more `yield`-inside-a-for-loop patterns into `yield from` delegations, which are more performant and propagate more information and exceptions from the original generator. This is the modern recommended way of forwarding generators.
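For illustration, the `yield from` rewrite described above looks roughly like this:
```
def forward_items_old(gen):
    # Older style: re-yields each value but does not delegate send()/throw()
    # or the generator's return value.
    for item in gen:
        yield item


def forward_items_new(gen):
    # Modern style: delegates values, exceptions, and the return value.
    yield from gen
```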
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94309
Approved by: https://github.com/albanD
Preparation for the next PR in this stack: #89559.
I replaced
- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- likewise `self.assertFalse(torch.equal(...))` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).
There were a few instances where the result of `torch.equal` is used directly. In those cases I replaced it with `(... == ...).all().item()`, sometimes also dropping the `.item()` depending on the context.
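A hedged illustration of these replacements (assuming `self` is a `torch.testing._internal.common_utils.TestCase`, shown in comments where needed):
```
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

# Before: self.assertTrue(torch.equal(a, b))
# After:  self.assertEqual(a, b, rtol=0, atol=0, exact_device=True)

# Before: assert torch.equal(a, b)
# After (exact, elementwise comparison):
torch.testing.assert_close(a, b, rtol=0, atol=0)

# When the boolean result of torch.equal was used directly:
are_equal = (a == b).all().item()
```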
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
Summary:
This is an un-backout but with a small change to set the default device `device_for_folded_attrs="cuda"` instead of `"cpu"`, which should avoid BC issues for TRT lowering.
Original commit changeset: 4ae1863e28ff
Original Phabricator Diff: D37192230 (24c2aff1b2)
Test Plan: Added unit test
Differential Revision: D37205432
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79696
Approved by: https://github.com/dborkovic
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68614
We need to copy modules over to the `split` graph during const folding. We were previously only doing so from the non-constant submod, but we need to do this for the constant one as well in case some `call_module` is const folded.
Test Plan: Added unit test
Reviewed By: wushirong, 842974287
Differential Revision: D32543289
fbshipit-source-id: 80d1d0ce2c18a665b00e1343d6c55d939390ab10
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65933
We use `split_module` to split the input model that we want to const fold into const and non-const subgraphs. Previously we were taking the non-const graph and trying to hack it back into the same signature as the input model. However, this was complex and buggy.
Instead, refactor to just keep using the base split module that contains both const and non-const graphs. This means we:
- Inline the non-const graph into the split module
- Remove the const graph from the module and replace it with a getattr whose attribute is filled in when we run `run_folding`
Test Plan: Added test coverage to cover newly supported folding, and updated other tests for new strategy.
Reviewed By: yinghai
Differential Revision: D31293307
fbshipit-source-id: 6e283a8c7222cf07b14e30e74dffc8ae5ee8b55f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65223
If there are unused inputs, they won't appear in `submod_1`. We need to add all the unused inputs so that the model after const folding has the same inputs as the original model.
Reviewed By: jfix71
Differential Revision: D31021217
fbshipit-source-id: b7452c90d133b747e0699936a81d3fee14af9cc9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64945
In the const folding pass, we try to create `get_attr` nodes in submod_1 for `get_attr` nodes that are in the main graph, but we don't have the real attributes in submod_1. To fix this we assign the main module as the owning module of the submod_1 graph.
The fix above would cause problems for `call_module` nodes in submod_1 because, during the split, modules get inlined into submod_1 (the target changes from "mod.a.b" to "mod_a_b"). Changing the owning module would make those `call_module` nodes unable to find the module they refer to. To fix this, we set the target module to the main module.
Reviewed By: jfix71
Differential Revision: D30905949
fbshipit-source-id: cd67bc8fe4b8ad4344ae97b8e36753fdce3ece6d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64782
Previously, `get_attr` nodes that were added to the graph did not retain `node.meta` after folding. Add such support, and improve coverage in general here.
Test Plan: Added test coverage.
Reviewed By: protonu
Differential Revision: D30852704
fbshipit-source-id: ece87a61c69b2e68982964c6adc4dde14dae12c7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64342
Previously we weren't handling the case where an attribute was in a module that wasn't the root.
Test Plan: Added unit test coverage.
Reviewed By: yinghai
Differential Revision: D30691730
fbshipit-source-id: b39b5cf748c4c882f315a4f32b51ad88cc7a43ed
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48443
Add a constant folding pass in FX:
- Iterate over the input graph and tag which nodes are fully constant, i.e. either `get_attr` nodes or nodes whose inputs are all either `get_attr` nodes or constants
- Use `model_transform.split_by_tags()` to split the graph into two
- Look for the `output` node in the constant graph to get names of attrs that will be folded
- Iterate over the non-constant graph and replace placeholders that use the same names as those attrs with `get_attr` nodes, and add a dummy attr on the module
- Return these two graphs in a new `FoldedGraphModule`, which is a normal GraphModule but also stores the constant graph on the side along with a `run_folding()` method that will run const folding and update the dummy parameters with the actual folded parameters (see the usage sketch after this list)
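A hedged end-to-end usage sketch using the current `torch.fx.experimental.const_fold` API (the module is illustrative; the folded-attribute device is pinned to CPU so the sketch runs without CUDA):
```
import torch
from torch.fx.experimental.const_fold import split_const_subgraphs


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        folded = self.w + 1  # depends only on an attribute -> tagged constant
        return x + folded    # depends on the input -> stays in the non-constant graph


folded_mod = split_const_subgraphs(MyModule(), device_for_folded_attrs="cpu")
folded_mod.run_folding()              # runs the constant graph and fills in the folded attr
out = folded_mod(torch.randn(4, 4))   # the main graph now reads the folded attr via get_attr
print(folded_mod.graph)
```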
Test Plan: Added a couple tests
Reviewed By: 842974287
Differential Revision: D25033996
fbshipit-source-id: 589c036751ea91bb8155d9be98af7dbc0552ea19