Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new `Buffer` type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the `Buffer` type can be used as a drop in replacement for `register_buffer` as it just leads to `register_buffer` being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.
Fixes#35735
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`
I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
Summary:
When we pickle/unpickle graph module in multipy, we would lost modules/attributes that are not referred in the graph. This is because when unpickle fx graph module, we use the stored `__dict__` and the fx graph to create a new graph module. In GraphModule init, we drop any attribute that is not referred in the graph.
This behavior is not ideal because we actually expect a graph module that's exactly the same after unpickling.
Test Plan:
```
buck test mode/opt caffe2/test:fx -- test_preserve_unused_attr_after_unpickle
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D46976230
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104115
Approved by: https://github.com/houseroad
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`
I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
Summary:
# Context
In TorchRec's train pipeline, we need to fx trace a module to analyze the arguments on the forward call. In order to do this, we need to preserve some sort of meaning with each argument (a key or name of sorts that lets us identify the argument).
The issue is, when you use concrete args, internally, fx will unflatten the arg into it's constituents (to locate PHs).
Given a function that looks like this:
```
def process(batch: Dict[str, torch.Tensor]):
....
symbolic_trace(process, concrete_args: {"batch": {"f1": PH, "f2": PH}})
# function will be rewritten to look like:
def process(batch_1, batch_2): # batch_1 -> "f1", batch_2->"f2"
...
```
When you traverse through the nodes of the graph, the names of the argument nodes to the function are batch_1 and batch_2. **This doesn't mean anything to the user who is fx tracing.** There isn't anything indicating that batch_1 corresponds to key "f1" in the batch input.
# Solution
When fx sees a "PH", it creates a proxy node.
The user does not have direct access to proxy creation, but only through the PH structure.
Attach a piece of metadata, `ph_key`, to the PH when you set it in the concrete args, it will get passed into proxy + node creation. So when you traverse the graph, this metadata sticks onto the node as an attribute. This way you have a way of tagging that "batch_1" as "f1".
Test Plan: added a unit test
Reviewed By: dstaay-fb
Differential Revision: D44947653
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102195
Approved by: https://github.com/PaliC
Summary: Change placeholder check from singleton to instanceof PHBase so you can create your own PH class with metadata
Test Plan: added unit test
Reviewed By: joshuadeng
Differential Revision: D46085128
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102008
Approved by: https://github.com/PaliC
Added helper functions to match nodes in the graph that are decomposed from their source (leaf modules, or functional ops), as a result of dynamo tracing.
`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]`
Args:
* graph: The graph we want to partition
* wanted_sources: List of sources of nodes that were decomposed from this source. This can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)
Returns:
* Dictionary mapping sources (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions that correspond to the list of nodes that were flattened from a module of that type.
```
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
```
Example:
Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
.graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
```
Result of `get_module_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```
Also added helper function to check if two module partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
Summary: There're some customized functions that we would also like to keep during eliminate dead code pass. Add a function to help us to do.
Test Plan: Added a unit test
Differential Revision: D44273630
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97288
Approved by: https://github.com/houseroad
Summary:
Currently torch.fx support Modules with input of namedtuple/dataclass, return as namedtuple, but does not allow Module.forward to return a dataclass, running `test_trace_return_dataclass` without this change will have following error:
NotImplementedError: argument of type: <class 'test_fx.TestFX.test_trace_return_dataclass.<locals>.MyOutput'>
File "test_trace_return_dataclass
traced_graph = symbolic_trace(module).graph
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 1114, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 783, in trace
(self.create_arg(fn(*args)),),
File "test/__fx__/fx#link-tree/torch/fx/_symbolic_trace.py", line 378, in create_arg
return super().create_arg(a)
File "test/__fx__/fx#link-tree/torch/fx/proxy.py", line 269, in create_arg
raise NotImplementedError(f"argument of type: {type(a)}")
this diff handle dataclass type.
Test Plan:
buck test @//mode/opt @//mode/inplace //caffe2/test:fx -- test_trace_
graph():
%d : torch.Tensor [#users=1] = placeholder[target=d]
%my_output : [#users=1] = call_function[target=test_fx.MyOutput](args = (), kwargs = {foo: %d, bar: %d})
return my_output
Differential Revision: D44916519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99576
Approved by: https://github.com/suo
Twice this week I have had people confuse "operator defined with Python
operator registration aka torch.library" and "PyOperator which is used
to define control flow operators and other operators that cannot be
represented in JIT schema." Renaming PyOperator for clarity.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97493
Approved by: https://github.com/SherlockNoMad
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
To match nodes within the graph, the matcher currently flattens the arguments and compares each argument against each other. However, if it believes that a list input contains all literals, it will not flatten the list and will instead compare the list directly against each other. It determines if a list is a literal by checking if the first element is a node. However this doesn't work in some cases (like the test cases I added).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94375
Approved by: https://github.com/SherlockNoMad
Fixes https://github.com/pytorch/pytorch/issues/89421
The strategy is to patch the given function wrapped with `@torch.fx.wrap` so that if a tensor tracer is active, we will `proxy_call` the function.
`proxy_call` will also skip certain checks if the function to proxy call is not a torch op (checked with `isinstance(.., OpOverload)`.
@IvanYashchuk @ezyang @Chillee
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93273
Approved by: https://github.com/ezyang
Summary:
One of such places where circular reference can occur is: _load_state_dict_pre_hooks contains a _WrappedHook, _WrappedHook has a weakref to the same module.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93038
Approved by: https://github.com/jerryzh168
In 3.11 bytecode size is not constant, so in order to get from `f_lasti` to opcode index, one need to search for the closes offset in disassembled instructions.
Update `_patch_function` to construct code with all the properties that exist in 3.11 runtime.
Update `_torchscript_schema_to_signature` to mark `from` named arg as positional argument only, as this is a reserved keyword in Python and as such checked by `inspect` package in 3.11
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92895
Approved by: https://github.com/albanD
# Summary
In preparation for pt 2.0 launch this PR updates SDPA's API and makes the function a nn.funcitonal public function.
## Changes
### API
Previously the the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels e.g. (FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights and it is unlikely that a fast fused implementation will enable this feature so we are removing.
Discussed with folks at FAIR/Xformers and +1 this API change.
#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
Summary:
This PR supports the following feature for QConfigMapping:
```
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
backend_config = get_qnnpack_pt2e_backend_config()
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
```
which means users want to set the qconfig for all calls to `torch.nn.Conv2d` to use `qconfig`, note this is only verified for the case when the module is broken down to a single aten op right now, e.g. torch.nn.Conv2d will be torch.ops.aten.convolution op when traced through. will need to support more complicated modules that is broken down to multiple operators later, e.g. (MaxPool)
Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_qconfig_module_type
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92355
Approved by: https://github.com/jcaip
This PR:
- Updates the docs to say it is deprecated
- Raises a UserWarning
- Changes most of the callsites inside PyTorch to use
torch.func.functional_call, minus the test_stateless testing.
The motivation behind this is that we can now align behind a single
functional_call API in PyTorch.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92280
Approved by: https://github.com/albanD