Summary:
Implements feature request https://github.com/pytorch/pytorch/issues/62021
Test it out with
```python
from torch import fx
from torch import nn

def fx_int(x):
    return int(x)

class MyModule(nn.Module):
    def forward(self, x):
        return fx_int(x.shape[0] / 2)

tracer = fx.Tracer(autowrap_functions=(fx_int,))  # or remove the kwarg to demonstrate the symbolic trace error
tracer.trace(MyModule())
```
First-time contributor, so please advise if there is anything I could have done to make lives easier for next time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62106
Reviewed By: SplitInfinity, driazati
Differential Revision: D30080834
Pulled By: jamesr66a
fbshipit-source-id: 68fadf8c881ea7930e7afd62b642874010fe4903
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436
## Problem
Given two modules and a tracer that indiscriminately marks all modules as a leaf:
```
class InnerModule(torch.nn.Module):
    def forward(self, t):
        return t + t

class MyModule(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, t):
        x = self.inner(t)
        y = self.inner(t)
        return x + y

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name):
        return True
```
One might generally expect the following behavior (note the call_module nodes):
```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Outer GraphModule (with inner module as nn.Module):
opcode name target args kwargs
------------- ------- ----------------------- ---------------- --------
placeholder t t () {}
call_module inner inner (t,) {}
call_module inner_1 inner (t,) {}
call_function add <built-in function add> (inner, inner_1) {}
output output output (add,) {}
None
```
However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_module` entirely and traces through the whole module (note the call_function nodes).
```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())
print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Inner module as GraphModule:
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
output output output (add,) {}
None
>> Outer GraphModule (with inner module as GraphModule):
opcode name target args kwargs
------------- ------ ----------------------- ------------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
call_function add_1 <built-in function add> (t, t) {}
call_function add_2 <built-in function add> (add, add_1) {}
output output output (add_2,) {}
None
```
This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, and otherwise executes `torch.fx._symbolic_trace._orig_module_call` (which is set to `torch.nn.Module.__call__` at module load time).
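Roughly, the patching looks like this (a simplified sketch of my reading of `Tracer.trace`, not the verbatim implementation):
```python
# Simplified sketch of what Tracer.trace() installs while tracing; the real
# code handles proxying, autowrapping, and cleanup that are omitted here.
def module_call_wrapper(mod, *args, **kwargs):
    module_qualified_name = tracer.path_of_module(mod)
    if tracer.is_leaf_module(mod, module_qualified_name):
        # Leaf module: record a call_module node instead of tracing into it.
        return tracer.create_proxy('call_module', module_qualified_name, args, kwargs)
    # Non-leaf: run the original __call__ so tracing descends into forward().
    return torch.fx._symbolic_trace._orig_module_call(mod, *args, **kwargs)

torch.nn.Module.__call__ = module_call_wrapper  # restored when tracing finishes
```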
**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**
Upon further inspection, I found:
- The constructor for GraphModule includes a path to `GraphModule.recompile()` via the setter for an `fx.Graph`:
```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
  File "/torch/fx/graph_module.py", line 252, in __init__
    self.graph = graph
  File "/torch/nn/modules/module.py", line 1183, in __setattr__
    object.__setattr__(self, name, value)
  File "/torch/fx/graph_module.py", line 277, in graph
    self.recompile()
```
- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:
```
cls = type(self)
cls_call = cls.__call__
...

def wrapped_call(self, *args, **kwargs):
    try:
        return cls_call(self, *args, **kwargs)
    except Exception as e:
        ...

cls.__call__ = wrapped_call
```
- Recompilation of the inner GraphModule happens on initialization, before creation or tracing of the outer module. Adding some old-fashioned print debug statements gives:
```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>
Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>
outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```
- The outer module is patched correctly, but the inner module's first element in its MRO is the `wrapped_call` from `recompile` that still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.
## In Practice
In practice, this behavior affects the ability of `torch.package` to package `GraphModules` whose submodules are `GraphModules`. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModules`.
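A minimal sketch of that pattern, reusing the toy modules and `MyTracer` from above:
```python
inner = InnerModule()
root = MyModule(inner)
root_gm = torch.fx.GraphModule(root, MyTracer().trace(root))  # graph contains call_module nodes

# Later, the submodule is traced separately and installed via setattr.  The
# existing call_module nodes still point at "inner", so nothing looks wrong yet.
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
setattr(root_gm, "inner", inner_gm)

# Any subsequent re-trace that relies on is_leaf_module (such as
# torch.package's deserialization path described below) now inlines inner_gm
# instead of emitting a call_module node for it.
```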
When unpackaging such a model with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that marks all submodules as leaves; the unpackaged module is nevertheless implicitly and surprisingly inlined in the process.
## Potential Solution
We previously did not understand this behavior, so the current workaround is a gnarly process of wrapping every submodule in an `nn.Module` with a manually installed forward method.
Changing `wrapped_call` to call `super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach?
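For concreteness, a sketch of what that change to `recompile()` could look like (the precise condition for "inherited" here is my assumption, not a final diff):
```python
cls = type(self)
# Assumption for this sketch: treat __call__ as "inherited" when the class does
# not define it in its own __dict__; in that case defer to the MRO at call time
# so a patched torch.nn.Module.__call__ (e.g. installed by Tracer.trace)
# actually takes effect.
cls_call = cls.__call__ if '__call__' in vars(cls) else None

def wrapped_call(self, *args, **kwargs):
    try:
        if cls_call is not None:
            return cls_call(self, *args, **kwargs)
        return super(type(self), self).__call__(*args, **kwargs)
    except Exception as e:
        ...  # existing error handling unchanged

cls.__call__ = wrapped_call
```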
## Other Thoughts
- Repeated calls to `recompile` create nested `wrapped_call`s, all for the purpose of error handling. This seems unnecessary ¯\\_(ツ)\_/¯
- If a root module with an overridden `__call__` method is symbolically traced, the override is ignored
Test Plan:
```
buck test:
✓ ListingSuccess: caffe2/test:fx - main (12.570)
✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```
Reviewed By: ansley
Differential Revision: D29997935
fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61733
Allow the FX tracer to trace through control-flow (if/while) statements whose conditions depend on parameter shapes.
If the user specifies the new `param_shapes_constant` option when constructing a tracer, a parameter's shape attribute is evaluated during tracing and the resulting constant is emitted into the IR.
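A minimal usage sketch (the module here is made up; it assumes `param_shapes_constant` is passed as a `Tracer` constructor kwarg, as described above):
```python
import torch
from torch import fx, nn

class CondOnShape(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 8))

    def forward(self, x):
        # Branching on a parameter's shape: normally a trace error, because the
        # proxied comparison cannot be converted to a concrete bool.
        if self.w.shape[0] < self.w.shape[1]:
            return x @ self.w
        return x @ self.w.t()

tracer = fx.Tracer(param_shapes_constant=True)
graph = tracer.trace(CondOnShape())
print(graph)  # the taken branch is baked into the IR as a constant
```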
Also added a new test
```
python test/fx/test_fx_param_shape_control_flow.py
```
The test also performs some white-box-style checks on the Python code generated from the IR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61820
Reviewed By: bdhirsh
Differential Revision: D29969299
Pulled By: jerryzhenleicai
fbshipit-source-id: 99aae824bdfec880be69258de7ead5c8cd59eddc
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292
This PR adds pytree support for namedtuples. The challenge with namedtuples is that each namedtuple class is a distinct type. This PR does the following:
- It adds a namedtuple flatten/unflatten (sketched after this list). The flatten function returns a context that is the actual type of the namedtuple subclass; the unflatten function uses that type to reconstruct the namedtuple.
- It special-cases all pytree logic to treat all namedtuples the same. This is done by creating a `_get_node_type(pytree)` helper function that returns `namedtuple` if `pytree` is any namedtuple subclass. The effect is that all namedtuple subclasses go through the namedtuple flatten/unflatten functions.
- It adds a `_namedtuple_flatten_spec` function for FX pytrees. This function flattens the namedtuple based on the spec and is equivalent to `_tuple_flatten_spec`.
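A rough sketch of that namedtuple handling (names simplified; the real registration lives in the pytree module):
```python
from collections import namedtuple
from typing import Any, List, Tuple, Type

def _is_namedtuple_instance(pytree: Any) -> bool:
    typ = type(pytree)
    return issubclass(typ, tuple) and hasattr(typ, "_fields")

def _get_node_type(pytree: Any) -> Any:
    # Collapse every namedtuple subclass onto the single `namedtuple` key so
    # they all share one flatten/unflatten implementation.
    return namedtuple if _is_namedtuple_instance(pytree) else type(pytree)

def _namedtuple_flatten(nt: Any) -> Tuple[List[Any], Type]:
    # The context is the concrete namedtuple subclass itself.
    return list(nt), type(nt)

def _namedtuple_unflatten(values: List[Any], context: Type) -> Any:
    # Reconstruct the original subclass from the flattened values.
    return context(*values)
```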
Test Plan
- new tests in test/test_pytree.py and test/test_fx.py
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29947302
Pulled By: zou3519
fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
Summary:
### Issue
Build PyTorch wheel packages during the build stage for pull requests and install them during the test stage.
### Fix
Update all tests that load lib*.so from the `./build` folder to instead load lib*.so from `{ent}/pytorch/lib/python3.8/site-packages/torch`.
### Diff
This diff updates test_fx, test_backend, and test_torchbind first to check whether the current CI passes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61960
Test Plan: check that all CI workflows pass
Reviewed By: malfet, saketh-are
Differential Revision: D29823235
Pulled By: tktrungna
fbshipit-source-id: e7f652def698e303d4843fbaedf4859f5eca2fd9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61780
These changes allow objects to control, from within their own source, how they are handled when they appear as an argument to a torch.fx call_module node. Previously, we had been using a custom Tracer with an overridden create_arg() method, branching based on class name to handle unusual args (data classes, etc.).
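For context, the old pattern looked roughly like this (`MyDataClass` and the conversion are hypothetical):
```python
import torch.fx

class BranchyTracer(torch.fx.Tracer):
    # Hypothetical sketch of the approach this change replaces: special-case
    # unusual argument types by class name inside an overridden create_arg().
    def create_arg(self, a):
        if type(a).__name__ == "MyDataClass":
            # Hand-rolled conversion of the data class into FX-friendly args.
            return super().create_arg(vars(a))
        return super().create_arg(a)
```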
Reviewed By: suo, houseroad
Differential Revision: D27976120
fbshipit-source-id: 0c5249c5f8398368ca0fbec0ad8a07ccf99b7da4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61463
Seems like a small oversight(?): the current test fails when warnings are recorded. Discovered this when calling `graph.call_module(existing_call_module_node.target)`, which raised a warning.
Test Plan: `buck test //caffe2/test:fx`
Reviewed By: ansley
Differential Revision: D29637799
fbshipit-source-id: 2305629863230235f76a926fe2e4de480cbf853c
Summary:
Reference https://github.com/pytorch/pytorch/issues/50345
`zeta` was already present in the codebase to support computation of `polygamma`.
However, `zeta` only had a `double(double, double)` signature **for CPU** before this PR (which meant that `polygamma` computations were always upcast to `double` for the zeta part).
With this PR, float computations take place in `float` and double computations in `double`.
Also refactored the code and moved the duplicated code from `Math.cuh` to `Math.h`.
**Note**: For SciPy, `q` is optional, and if it is `None` it defaults to `1`, which corresponds to the Riemann zeta function. However, for `torch.special.zeta` I made `q` mandatory, because it feels odd that without `q` this is the Riemann zeta and with `q` it is the general Hurwitz zeta. I think sticking to just the general form makes more sense, as passing `1` for `q` is trivial.
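For example (a quick sanity check; printed values are approximate):
```python
import torch

x = torch.tensor([2.0, 4.0])           # float32 inputs, now computed in float on CPU
q = torch.ones_like(x)                  # q = 1 recovers the Riemann zeta function
print(torch.special.zeta(x, q))         # ~ [1.6449, 1.0823], i.e. pi^2/6 and pi^4/90
print(torch.special.zeta(x, q).dtype)   # torch.float32
```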
Verify:
* [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59623
Reviewed By: ngimel
Differential Revision: D29348269
Pulled By: mruberry
fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60057
This ensures that if a function was `wrap`'d before symbolic tracing and the traced module is then passed into the transformer, the function will still be wrapped.
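A rough sketch of the scenario this covers (the function and module here are made up for illustration):
```python
import torch
import torch.fx

def first_dim(x):
    return x.shape[0]

# Mark the free function so calls to it are recorded as call_function nodes.
torch.fx.wrap("first_dim")

class M(torch.nn.Module):
    def forward(self, x):
        return x + first_dim(x)

gm = torch.fx.symbolic_trace(M())
# Passing the traced module through a Transformer should keep the wrapped call
# as a call_function node targeting first_dim, not trace into it.
transformed = torch.fx.Transformer(gm).transform()
print(transformed.graph)
```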
Test Plan: Added test to `test_fx.py`
Reviewed By: jamesr66a
Differential Revision: D29151191
fbshipit-source-id: 93560be59505bdcfe8d4f013e21d4719788afd59
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57749
Adds to an FX test.
Test Plan: Imported from OSS
Reviewed By: huiguoo
Differential Revision: D28425974
fbshipit-source-id: 195c7a1944decb7a2a99c2831cab38485f32be17
Summary:
Fixes https://github.com/pytorch/pytorch/issues/57719.
This PR fixes `torch.Tensor{__rsub__, __rdiv__, __rtruediv__, __pow__, __rmatmul__}` to return `NotImplemented` instead of raising a `TypeError`.
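A minimal sketch of why returning `NotImplemented` matters, with a made-up right-hand operand that relies on Python's reflected-method fallback:
```python
import torch

class Scalarish:
    # A right-hand operand that knows how to handle Tensor ** Scalarish.
    def __rpow__(self, base):
        return "handled by Scalarish.__rpow__"

t = torch.arange(3.0)
# Previously torch.Tensor.__pow__ raised a TypeError for the unrecognized
# operand, so Python never fell back to Scalarish.__rpow__.  With the fix it
# returns NotImplemented and the reflected method is tried:
print(t ** Scalarish())  # expected: "handled by Scalarish.__rpow__"
```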
cc/ mruberry: The first commit of this PR is the same as 1d209db1cc except for the commit message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57934
Reviewed By: mruberry
Differential Revision: D28351931
Pulled By: albanD
fbshipit-source-id: 985457a44dba24d2496794dfb8c1661cbcd4ff8f
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})
fx.symbolic_trace(new_f, concrete_args=example_dict)
```
prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```
Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.
Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888
Reviewed By: jamesr66a
Differential Revision: D27884694
Pulled By: Chillee
fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57383
Notes: I picked up an activation from https://github.com/pytorch/pytorch/issues/56969. You can look at the [activations.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/Activation.cpp#L429) file which has both forward and backward kernel code to help you write the NNC lowering and the symbolic gradient.
I added a test in test_jit_fuser_te for the fusion, and I added an OpInfo and asserted that we expect to see autodiffable nodes to test the symbolic gradient.
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D28197820
Pulled By: eellison
fbshipit-source-id: 05305d85c5bb0847c8f911b95ba47b137dca7e90
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45687
The fix changes the input size check for `InstanceNorm*d` to be more restrictive and to correctly reject sizes with only a single spatial element, regardless of batch size, to avoid infinite variance.
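For illustration, a sketch of the expected behavior after the fix (assuming the new check raises a `ValueError` like the analogous batch-size check):
```python
import torch
import torch.nn as nn

m = nn.InstanceNorm1d(3)
out = m(torch.randn(2, 3, 10))  # several spatial elements per channel: fine

try:
    m(torch.randn(2, 3, 1))     # a single spatial element has no variance
except ValueError as e:
    print("rejected as expected:", e)
```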
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56659
Reviewed By: pbelevich
Differential Revision: D27948060
Pulled By: jbschlosser
fbshipit-source-id: 21cfea391a609c0774568b89fd241efea72516bb
Summary:
Fixes https://github.com/pytorch/pytorch/issues/55398
Generates tests that call `symbolic_trace` on torchvision models and verify parity of the outputs from the eager model, the `fx.GraphModule`, and the `jit.ScriptModule`.
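The parity check amounts to something like this (resnet18 used here as an arbitrary traceable example):
```python
import torch
import torchvision.models as models
from torch import fx

model = models.resnet18().eval()
gm = fx.symbolic_trace(model)      # eager -> fx.GraphModule
scripted = torch.jit.script(gm)    # fx.GraphModule -> jit.ScriptModule

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), gm(x))
    assert torch.allclose(model(x), scripted(x))
```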
Test errors: GoogLeNet and Inception models throw a type mismatch when scripting the traced `fx.GraphModule`.
```
Return value was annotated as having type __torch__.torchvision.models.googlenet.GoogLeNetOutputs but is actually of type Tensor:
dropout = self.dropout(flatten); flatten = None
fc = self.fc(dropout); dropout = None
return fc
~~~~~~~~~ <--- HERE
```
Relevant type inconsistency: 512ea299d4/torchvision/models/googlenet.py (L200)
```
@torch.jit.unused
def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
    if self.training and self.aux_logits:
        return _GoogLeNetOutputs(x, aux2, aux1)
    else:
        return x  # type: ignore[return-value]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55744
Reviewed By: albanD
Differential Revision: D27920595
Pulled By: suraj813
fbshipit-source-id: 01f6f2aef7badbde29b5162a7787b5af9398090d