Summary: Adds an experimental API to FX GraphModule to register "hooks" that fire whenever we change or replace nodes in a graph, so that we can properly propagate the new name to the graph signature and potentially other places.
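A minimal sketch of the idea, assuming a registration method on GraphModule; the name `_register_replace_node_hook` and the hook signature are assumptions, not the confirmed API:
```python
import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

gm = fx.symbolic_trace(M())

# Hypothetical hook: fires whenever a node is replaced, letting callers keep
# external records such as the graph signature in sync with the new name.
def replace_hook(old_node: fx.Node, new_node_name: str) -> None:
    print(f"{old_node.name} -> {new_node_name}")

# gm._register_replace_node_hook(replace_hook)  # assumed registration entry point
```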
Test Plan:
buck test mode/opt -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform
buck test mode/opt caffe2/test:test_export -- -r test_replace_hook
Differential Revision: D52896531
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825
Approved by: https://github.com/avikchaudhuri
Inductor codegen for `_assert_async` is currently disabled because we don't really understand how to codegen `scalar_to_tensor` on a Sympy expression. I initially tried to see if I could get this to work, but I got into some weird problem involving stride sorting, so I decided to fix it properly by not going through a tensor.
So we introduce an `_assert_scalar` which takes a scalar as an argument, avoiding needing to turn a SymBool into a tensor before asserting on it. I also add `_functional_assert_scalar` for good luck, although this doesn't do anything right now because https://github.com/pytorch/pytorch/pull/104203 still hasn't been landed.
I need to customize the codegen for this operator, so I decided to implement it directly in Inductor rather than trying to treat it as a generic ExternKernel. This leads to the new AssertScalar IR node, which is written carefully so that it doesn't get DCE'd by Inductor.
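A minimal sketch of the difference, assuming the new `_assert_scalar(value, msg)` op is exposed under `torch.ops.aten` and callable in eager mode (the exact Python spelling is an assumption):
```python
import torch

# Scalar variant: assert directly on a bool/SymBool value, no tensor needed.
torch.ops.aten._assert_scalar(True, "expected condition to hold")

# Older tensor-based variant: the condition must first be materialized as a tensor.
torch._assert_async(torch.tensor(True))
```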
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114148
Approved by: https://github.com/jansel
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.:
- `assert(a == b)` -> `assert a == b`
- `if(x > y or y < z):`->`if x > y or y < z:`
- And `return('...')` -> `return '...'`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
Applies PLW0108, which removes unnecessary lambda wrappers in Python. The rule is in preview, so it is not ready to be enabled by default just yet; these are the autofixes from the rule.
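For illustration, the kind of rewrite the rule performs (a generic example, not taken from this PR):
```python
# A lambda that only forwards its argument to another callable is replaced
# by that callable directly.
before = sorted([-3, 1, -2], key=lambda x: abs(x))
after = sorted([-3, 1, -2], key=abs)
assert before == after
```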
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
Summary:
Traditionally, when users want to update the arguments of an FX node, the only way is to call the setter of the `.args` property on the node. This may be problematic when we insert a lot of arguments, because the semantics of the setter method give it a worst-case O(n) complexity.
Adding a new `insert_arg` API provides two benefits:
1. The operation is guaranteed to have O(1) cost.
2. Users can express their intention more directly, instead of writing code like `node.args = (arg,) + node.args` (see the sketch after this list).
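A minimal sketch, assuming the API is exposed as `Node.insert_arg(index, arg)` as described above; the traced function and the constant argument are purely illustrative:
```python
import operator
import torch
from torch import fx

def f(x, y):
    return operator.add(x, y)

gm = fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Pre-existing O(n) way to prepend an argument:
#   add_node.args = (new_arg,) + add_node.args
# New O(1) way:
add_node.insert_arg(0, 1)  # insert the constant 1 at position 0
gm.recompile()
print(gm.code)
```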
Test Plan: caffe2/test:fx -- -r test_insert_arg
Reviewed By: suo
Differential Revision: D50574435
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111974
Approved by: https://github.com/angelayi
Some notable changes:
1. `constrain_as_size` now allows the min value to be less than 2, since the compiler will unconditionally assume min >= 2 for compilation purposes. Instead, we add an additional check to make sure the max value is always greater than 2.
2. Previously, we used to runtime-assert on the unbacked symint's value range, which was always [2, max]. I modified this logic to assert on [0, max] unless the user explicitly specifies the min range (a plain-Python sketch of the new check follows).
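A plain-Python sketch of the assertion-range change, under the assumption that only the check's bounds changed; this is not the actual export machinery:
```python
# Runtime range check on a data-dependent size: the lower bound now defaults
# to 0 instead of 2 unless the user explicitly supplies a min.
def assert_size_in_range(val, user_min=None, user_max=None):
    lo = 0 if user_min is None else user_min   # previously the implicit default was 2
    hi = 2**63 - 1 if user_max is None else user_max
    assert lo <= val <= hi, f"size {val} outside [{lo}, {hi}]"

assert_size_in_range(0)        # now passes; used to fail against the implicit min of 2
assert_size_in_range(5, 4, 7)  # user-specified bounds are still enforced
```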
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
At a high level, the current implementation of the constraint functions (constrain_as_**) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the whole node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on the range constraints extracted from each symbol's shape env during another interpretation round.
4) However, because of 1), the range-constraint logic won't run for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the issue.
5) As a result of the above, it fails at `torch.empty((a, 4))` (there is no constraint that `a` must be positive).
The fix here is simply to implement the range-constraint logic as a native aten op (with a no-op CPU implementation) so that it survives AOT export.
**NOTE:**
[Logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If it directly called `sym_constrain_range`, the C++ version would be called, which is a no-op.
* So in this case it calls `constrain_range_int` instead, to be able to catch issues such as the user providing an input whose value is out of range during export, like the following variant of the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
Introduces two higher order operators
* run_and_save_rng_state - Saves the current rng state and then runs the op.
* run_with_rng_state - Runs the op with the rng state supplied as an input
Ideally, we would like to use torch.compile for these operators. But currently the plan is to introduce these operators at the partitioner level, obviating the need to support them fully through the torch.compile stack. To ensure that we have good enough debugging with minifiers, we have ensured that they work with make_fx. In the future, we can move them onto torch.compile. A plain-Python sketch of the intended semantics follows.
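A plain-Python sketch of the two operators' semantics (the real ones are higher-order ops inserted by the partitioner, not these helper functions):
```python
import torch

def run_and_save_rng_state(op, *args):
    state = torch.get_rng_state()      # save the CPU RNG state before running op
    return state, op(*args)

def run_with_rng_state(state, op, *args):
    torch.set_rng_state(state)         # replay op under the saved RNG state
    return op(*args)

x = torch.randn(8)
state, y1 = run_and_save_rng_state(torch.dropout, x, 0.5, True)
y2 = run_with_rng_state(state, torch.dropout, x, 0.5, True)
assert torch.equal(y1, y2)             # the recomputation is bit-identical
```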
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102934
Approved by: https://github.com/jansel, https://github.com/zou3519
This PR introduces a new operator called aten._assert_async.msg, which allows passing a tensor value and an assertion message as inputs. As part of TorchDynamo, we're replacing the use of torch._assert with this new operator so that make_fx also knows how to handle assertions. This is a subset of https://github.com/pytorch/pytorch/pull/98878; refer there for historical reviews.
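A minimal usage sketch, assuming the overload is reachable as `torch.ops.aten._assert_async.msg` (the Python spelling is an assumption):
```python
import torch

# Assert on a tensor condition with a human-readable message attached.
cond = torch.tensor(True)
torch.ops.aten._assert_async.msg(cond, "condition must hold")
```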
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100101
Approved by: https://github.com/jansel
Summary: There are some customized functions that we would also like to keep during the eliminate-dead-code pass. Add a helper function that lets us do so.
Test Plan: Added a unit test
Differential Revision: D44273630
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97288
Approved by: https://github.com/houseroad
I added a bunch of asserts to verify that I didn't accidentally kill copy_ in the graph; hopefully this, combined with our existing tests, is good enough.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97275
Approved by: https://github.com/bdhirsh
As found in #92709 (thanks to @ngimel and @jansel), `torch.Tensor.fn` currently maps to `UserDefinedObjectVariable` rather than `TorchVariable`. The root cause is https://github.com/pytorch/pytorch/pull/92709#pullrequestreview-1273357406. To prevent this, we build a `TorchVariable` for `torch.Tensor.fn` that points to `torch.ops.aten.fn`.
This issue causes a graph break for `torch.Tensor.fn` with `nopython=True`:
```python
import torch
import torch._dynamo as dynamo

# op = torch.ops.aten.abs_  # no graph break
op = torch.Tensor.abs_  # graph break

args = torch.empty(10)

def foo(args):
    return op(args)

opt_foo = dynamo.optimize("inductor", nopython=True)(foo)
y_ = opt_foo(args)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93243
Approved by: https://github.com/jansel
Use Prims to implement group_norm, group_norm_backward and mean_var
Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages in order to make them importable from `torch/backends/mps/__init__.py`, as the `torch.ops` alias, defined in 15af4b1cee/torch/__init__.py (L1095), is executed last during the init process.
Add `__all__` to `torch/backends/mps/__init__.py` as well as alias all imports as private
Add `TestNNMPS.test_group_norm_backward` that validates no NaNs are generated during the backward pass
Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
Fixes https://github.com/pytorch/torchdynamo/issues/1708
Our FX subgraph partitioner works by taking all of the original output nodes from a subgraph and replacing them with a new `call_module` node in the graph.
If the original subgraph outputs had fake tensors and other metadata stored in their `.meta` attribute, though, then this information was getting lost when we spliced in the subgraph.
Losing metadata on an FX graph also seems like an easy trap to fall into, so I'm wondering if there are any better guardrails that we can add. I ended up fixing it in this PR by adding an optional kwarg to propagate meta info directly in `fx.Node.replace_all_uses_with`, just because propagating metadata seems like a pretty core thing.
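A minimal sketch of the new kwarg, assuming the signature is `Node.replace_all_uses_with(new_node, propagate_meta=...)` as described above; the traced function and the `"note"` key are purely illustrative:
```python
import operator
import torch
from torch import fx

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
old = next(n for n in gm.graph.nodes if n.op == "call_function")
old.meta["note"] = "metadata we do not want to lose"

# Splice in a replacement node and reroute all users, copying .meta along the way.
with gm.graph.inserting_after(old):
    new = gm.graph.call_function(operator.add, (old.args[0], 2))
old.replace_all_uses_with(new, propagate_meta=True)
gm.graph.erase_node(old)
gm.recompile()
print(new.meta["note"])  # "metadata we do not want to lose"
```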
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87255
Approved by: https://github.com/wconstab, https://github.com/SherlockNoMad
It seems like the [torch.fx.Node docs](https://pytorch.org/docs/stable/fx.html#torch.fx.Node) are incorrect regarding the inclusion of the self argument for module call nodes.
While the docs state that self (the module) is included in `args`, it is in fact not, as demonstrated by this code:
```python
import torch
from torch import fx, nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.submod = nn.Linear(10, 10)

    def forward(self, x):
        x = x.flatten()
        return self.submod(x)

graph_module = fx.symbolic_trace(Net())
print(graph_module.graph)  # doesn't show self for the submodule call

submod_node = list(graph_module.graph.nodes)[2]
print(submod_node.op)    # call_module
print(submod_node.args)  # (flatten,) => would need to have len 2 if self was included

flatten_node = list(graph_module.graph.nodes)[1]
print(flatten_node.op)    # call_method
print(flatten_node.args)  # (x,) => here self is included (and docs are correct)
```
Since [torch.fx.Interpreter also uses `args` as if self were not included](2fe5808590/torch/fx/interpreter.py (L288)), I assume the docs are incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86685
Approved by: https://github.com/soulitzer
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72301
First step in resolving #35026.
This adds `PythonRecordFunction`, which is a `torch::CustomClassHolder`
for `at::RecordFunction`, to keep the ATen code free of torch includes.
It also adds a new, currently unused internal API function,
`_record_function_enter_new`, which returns the torchbind object.
Once the FC period has expired, `torch.profiler.record_function` will
be updated to use this new internal API. Then, once the BC period has
expired, the cpp_custom_type_hack-based API can be removed.
Test Plan: Imported from OSS
Reviewed By: dagitses
Differential Revision: D34586311
Pulled By: robieta
fbshipit-source-id: d3eb9ffad7b348548a2b22c75203a92d1cb5115b
(cherry picked from commit 92d2ca808e5fbd20c9d6645dcabc3f059f9ef2d3)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73763
The test that is enabled generates a graph like this:
```
linear_25 --> sigmoid_14 --> output_1
\--> output_2
```
Before this diff, (unpadding) layout_transform nodes would be added as follows:
```
linear_25 --> layout_xform1 --> sigmoid_14 --> layout_xform2--> output_1
\--> output_2
```
This causes an assertion to fail for the sigmoid node where the input and output types
don't match due to padding differences.
This diff modifies the replacement algorithm to not affect users of an output's parent node
when the user requires padded inputs. This yields the following graph instead:
```
linear_25 --> sigmoid_14 --> layout_xform2--> output_1
\--> layout_xform1 --> output_2
```
Test Plan: Manually and CI
Reviewed By: jfix71, dborkovic
Differential Revision: D34623590
fbshipit-source-id: 3834b06c95fc5626eccc282216cbe039ac5a3242
(cherry picked from commit af012372ae1a6bb654b0ed9b765993960d5251e4)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73464
- Improve formatting of graph by centering everything
- Add num_users
- Add args/kwargs
- Don't print more than 10 of any list/tuple by default (this is necessary for very large concats)
Test Plan: tested locally
Reviewed By: khabinov
Differential Revision: D34492256
fbshipit-source-id: 8073992edb3efddcf8bfd72e2d3db49cc242db10
(cherry picked from commit b1b802965c143fdb0d308b70f51aa741f7d90f78)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73198
Previously, if an arg to an FX node was a subclass of tuple, it got sanitized essentially back to that base class. For example, when setting an arg to a TensorMetadata object, which is a NamedTuple, it would be stored as a plain tuple instead.
- Change `map_aggregate` to repack the tuple as `type(a)` when it's not directly a tuple (best-effort, via try/except); see the sketch after this list.
- During codegen, call `add_global` for `type(a)` if it's not directly a tuple.
- Add an option for an arg to provide a `_custom_fx_repr_fn` for use when stringifying via `_format_arg`
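A small illustration of the repacking behavior, assuming `map_aggregate` from `torch.fx.node`; the `Point` NamedTuple is purely illustrative:
```python
from typing import NamedTuple

from torch.fx.node import map_aggregate

class Point(NamedTuple):
    x: int
    y: int

# After this change, the NamedTuple type is preserved instead of collapsing
# to a plain tuple.
out = map_aggregate(Point(1, 2), lambda v: v * 10)
print(type(out).__name__, out)  # Point Point(x=10, y=20)
```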
Test Plan: Added unit test coverage, where we inline the named tuple into arg/kwarg.
Reviewed By: jamesr66a
Differential Revision: D34381888
fbshipit-source-id: bd672a8542e2bba5aa604b448bec920efc256440
(cherry picked from commit 68f99c12dd)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69368
Before this PR, copying a node would lose the stack trace. This PR
ensures that the stack trace is preserved across copies.
This is useful because quantization passes would like to start
allowing the user to preserve stack traces, and we use the copy
behavior.
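An illustration of the preserved field, assuming the stack trace is carried through `Graph.node_copy` as described (the trace strings are set by hand here for brevity):
```python
import torch
from torch import fx

def f(x):
    return x.relu()

gm = fx.symbolic_trace(f)
new_graph = fx.Graph()
env = {}
for node in gm.graph.nodes:
    node.stack_trace = f"trace for {node.name}"
    env[node] = new_graph.node_copy(node, lambda n: env[n])
    # The copied node keeps the original node's stack trace.
    assert env[node].stack_trace == node.stack_trace
```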
Test Plan:
```
python test/test_fx.py TestFX.test_stack_traces
```
Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D32835248
fbshipit-source-id: 91610fd8d05f5683cfa5e11fb6f9f3feacb8e241
Summary:
Fixes [issue#67](https://github.com/MLH-Fellowship/pyre-check/issues/67)
This PR fixes the type checking errors in Pytorch torch/fx/node.py .
The variables at 363:20 and 364:20 were declared to have type `List[str]` but were assigned a value of `None`. This caused an incompatible variable type error. I changed the type from `List[str]` to `Optional[List[str]]`, which fixed the incompatible variable type error.
Signed-off-by: Onyemowo Agbo
onionymous
0xedward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68124
Reviewed By: gmagogsfm
Differential Revision: D32322414
Pulled By: onionymous
fbshipit-source-id: be11bbbd463715ddf28a5ba78fb4adbf62878c80
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67068
Prepending a node to itself results in the node getting removed from the graph.
Usually people won't deliberately prepend a node to itself. But people can accidentally append a node that is already right after `self`, which amounts to prepending the node to itself.
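An illustration of the pitfall; the assumption that the fixed behavior simply leaves the graph untouched is mine, not stated in this summary:
```python
import torch
from torch import fx

def f(x):
    return x.relu()

gm = fx.symbolic_trace(f)
placeholder, relu = list(gm.graph.nodes)[:2]

# relu already follows placeholder, so append() boils down to relu.prepend(relu).
# Before this fix, that silently dropped relu from the graph.
placeholder.append(relu)
assert relu in gm.graph.nodes
```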
Test Plan: Added a unit test
Reviewed By: jamesr66a
Differential Revision: D31849030
fbshipit-source-id: b0fdfbb893f785f268595acd823b426d57c15e61
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66048
Previously, create_arg would fail if it encountered a non-`None` layout argument. Adding it to the `BaseArgumentTypes` list should be enough to fix that.
Test Plan: Added unittest
Reviewed By: jamesr66a
Differential Revision: D31362662
fbshipit-source-id: 20049971e18c17e9c75e50540500c567266daa55
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55995
Normalization is kind of broken currently, but making default arguments visible still appears to work and is nice functionality to be able to rely on. This adds an option to `NormalizeArgs`'s `__init__` called `normalize_to_only_use_kwargs`, which defaults to True; if set to False, the pass keeps the same signature as provided but additionally makes the default arguments explicit in kwargs.
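A minimal sketch of the option, assuming `NormalizeArgs` lives in `torch.fx.experimental.normalize` and that normalization is best-effort (nodes it cannot normalize are left untouched):
```python
import torch
from torch import fx
from torch.fx.experimental.normalize import NormalizeArgs

def f(x):
    return torch.add(x, x)

gm = fx.symbolic_trace(f)

# Keep the positional call signature, but surface default arguments in kwargs.
normalized = NormalizeArgs(gm, normalize_to_only_use_kwargs=False).transform()
print(normalized.code)
```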
Test Plan: Added test to `test_fx_experimental`.
Reviewed By: 842974287
Differential Revision: D27759448
fbshipit-source-id: 620061fcf46d8549ac70b62aede8b6740aee3778
Summary:
Commandeered from https://github.com/pytorch/pytorch/pull/54563
Primary changes from first PR:
1. Refactored the primary `normalize_function` logic into `operator_schemas.py` so that non-FX users can use it (see the sketch after this list).
2. Refactored tests a bit, and added a path to call `normalize_function` directly.
3. Moved check for `boolean_dispatch` so that `torch.lu` also gets properly handled.
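A minimal sketch of calling `normalize_function` directly, assuming it is importable from `torch.fx.operator_schemas` and that `F.dropout` normalizes via its Python signature:
```python
import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function

# Map positional args onto explicit keyword arguments, no FX graph required.
x = torch.randn(4)
pair = normalize_function(F.dropout, (x, 0.3), normalize_to_only_use_kwargs=True)
print(list(pair.kwargs))  # ['input', 'p', 'training', 'inplace']
```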
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55992
Reviewed By: mruberry
Differential Revision: D27774396
Pulled By: Chillee
fbshipit-source-id: 7f65632e1d608e4abd55aec5ccbfdc3f67f52b8e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52658
DCE will reverse-iterate over the graph looking for nodes without users and delete them. It will skip unused placeholders (since removing them would affect the signature of the method) and outputs (which never have users but which we want to keep :) ).
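A quick illustration of that behavior with `Graph.eliminate_dead_code()` (the traced function is purely illustrative):
```python
import torch
from torch import fx

def f(x, y):
    x.relu()        # dead code, but still traced into the graph
    return x + 1

gm = fx.symbolic_trace(f)
gm.graph.eliminate_dead_code()
gm.recompile()
# The unused relu node is gone; the unused placeholder y and the output remain.
print(gm.code)
```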
Test Plan: Added unit tests
Reviewed By: jamesr66a, khabinov, chenccfb
Differential Revision: D26602212
fbshipit-source-id: f4f196973e40546076636090bb0008c24f33795e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51974
Right now, when an FX `Graph` references an external object, we will emit
code like:
```
import foo

def forward(input: foo.bar.baz):
    ...
```
This is problematic in a world with `torch.package`, since the name
`foo.bar.baz` may reference a name from any number of packages.
This PR lays the groundwork for FX-package integration by separating the
resolution of external references from the generation of the function
code.
When generating a Graph's Python source, we keep track of all external
references and assign them unique names. At the end, we have a
dictionary mapping names -> actual objects. This becomes the `globals`
namespace we pass to `exec` when installing the forward function in a
`GraphModule`. This is nice because we can always be sure that `exec` is
seeing the same objects that were referenced from the `Graph`, no import
statements needed.
At serialization time, we use a `ModuleEnv` to resolve the globals dict
to a set of import statements that can be run to reproduce the `globals`
namespace. This is only used on serialization/deserialization, and those
functions are expected to check that the import statements are producing
the correct results.
Concretely, the code above will now look like:
```
from foo.bar import baz as foo_bar_baz

def forward(input: foo_bar_baz):
    ...
```
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D26340593
Pulled By: suo
fbshipit-source-id: fe247f75205d0a03fd067bdd0f95491e8edf1436
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46325
Otherwise, mutating a node's `.args`/`.kwargs` containers in place would make the uses/users lists inaccurate.
You can still mutate the node by assigning a new value to .args or .kwargs
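A short sketch of the supported mutation pattern: in-place edits to a node's argument containers are rejected, so reassign `.args` wholesale (the exact exception type for in-place mutation is an assumption, so both plausible ones are caught):
```python
import torch
from torch import fx

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

try:
    add_node.args[1] = 2                    # in-place mutation is rejected
except (TypeError, NotImplementedError):
    pass

add_node.args = (add_node.args[0], 2)       # reassignment is the supported path
gm.recompile()
```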
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D24308672
Pulled By: zdevito
fbshipit-source-id: a5305e1d82668b36e46876c3bc517f6f1d03dd78
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46121
Otherwise, mutating them would make the uses/users lists inaccurate.
You can still mutate the node by assigning a new value to .args or .kwargs
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D24232288
Pulled By: zdevito
fbshipit-source-id: c95b1a73ae55ad9bdb922ca960c8f744ff732100
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45708
This makes it possible to define reasonable semantics for what happens
when a node in the list is deleted. In particular the iteration over nodes
will continue at the node that was after the deleted node _when it was deleted_.
If that node has also been deleted, we skip it and continue to the node after it.
Eventually we either reach a node still in the list or we reach the end of the list.
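A small illustration of iterating while erasing (the traced function is purely illustrative):
```python
import torch
from torch import fx

def f(x):
    x.relu()        # dead node that we will erase mid-iteration
    return x + 1

gm = fx.symbolic_trace(f)
seen = []
for node in gm.graph.nodes:
    seen.append(node.op)
    if node.op == "call_method" and not node.users:
        gm.graph.erase_node(node)           # iteration continues at the next node

print(seen)  # ['placeholder', 'call_method', 'call_function', 'output']
```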
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D24089516
Pulled By: zdevito
fbshipit-source-id: d01312d11fe381c8d910a83a08582a2219f47dda
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43083
This adds type annotations to all classes, arguments, and returns
for fx. This should make it easier to understand the code, and
encourage users of the library to also write typed code.
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D23145853
Pulled By: zdevito
fbshipit-source-id: 648d91df3f9620578c1c51408003cd5152e34514
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43082
Fixes all present errors in mypy. Does not try to add annotations everywhere.
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D23145854
Pulled By: zdevito
fbshipit-source-id: 18e483ed605e89ed8125971e84da1a83128765b7
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42991
Having Node be both a record of the operator in the graph and the
way we _build_ the graph made it difficult to keep the IR data structure
separate from the proxying logic in the builder.
Among other issues this means that typos when using nodes would add
things to the graph:
```
for node in graph.nodes:
    node.grph  # does not error, returns a node.Attribute object!
```
This separates the builder into a Proxy object. Graph/Node no longer
need to understand `delegate` objects since they are now just pure IR.
This separates the `symbolic_trace` (proxy.py/symbolic_trace.py) from
the IR (node.py, graph.py).
This also allows us to add `create_arg` to the delegate object,
allowing the customization of how aggregate arguments are handled
when converting to a graph.
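With the IR and the proxying split apart, the same typo now fails loudly; a quick check against current FX (purely illustrative):
```python
import torch
from torch import fx

def f(x):
    return x.relu()

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    try:
        node.grph          # after the split, a typo is a plain AttributeError
    except AttributeError:
        pass
```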
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D23099786
Pulled By: zdevito
fbshipit-source-id: 6f207a8c237e5eb2f326b63b0d702c3ebcb254e4