Did some easy fixes found by enabling TRY200. Most of these look like oversights rather than intentional omissions. The proper way to silence the rule when the omission is intentional is `raise ... from None`, which records that you considered whether the exception should carry its cause and decided against it.
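For context, a minimal sketch of the pattern the rule flags and the two sanctioned fixes (illustrative only, not a hunk from this PR):
```
import json

def parse_config(raw: str) -> dict:
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        # TRY200 flags a plain `raise ValueError("bad config")` here because it
        # silently drops the cause. Either chain it explicitly...
        raise ValueError(f"bad config: {raw!r}") from e
        # ...or, if hiding the cause is intentional, say so with `from None`:
        # raise ValueError(f"bad config: {raw!r}") from None
```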
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
Summary: Implement an on-disk cache to save and reuse compiled FX Graphs. This implementation does not handle tensors with symbolic shapes; that will be handled in a follow-up PR.
Test Plan:
* New unit tests exercising saving to and loading from the cache.
* New unit tests exercising the cache key calculations.
* Ran several benchmarks to confirm cache hits and measure the resulting compilation times.
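A rough sketch of the kind of lookup this describes: compute a key from the graph plus anything that affects codegen, then check disk before compiling. The names, cache layout, and location here are assumptions for illustration, not the actual implementation:
```
import hashlib
import os
import pickle

CACHE_DIR = "/tmp/fx_graph_cache"  # hypothetical location

def cache_key(graph_code: str, config: dict) -> str:
    # Key the cache on everything that can change the compiled output.
    payload = pickle.dumps((graph_code, sorted(config.items())))
    return hashlib.sha256(payload).hexdigest()

def load_or_compile(graph_code: str, config: dict, compile_fn):
    key = cache_key(graph_code, config)
    path = os.path.join(CACHE_DIR, key)
    if os.path.exists(path):                    # cache hit: reuse the compiled artifact
        with open(path, "rb") as f:
            return pickle.load(f)
    compiled = compile_fn(graph_code, config)   # cache miss: compile and persist
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(compiled, f)
    return compiled
```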
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison, https://github.com/Chillee
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case where we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and then want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`, which is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the new functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- Pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- Index into the original node metadata with that node index ~using the counter variable's tracked value.~
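A minimal sketch of the line-number lookup described above (the function and variable names are illustrative, not the actual dynamo internals):
```
from typing import Any, Dict, List

def build_lineno_map(node_linenos: List[int]) -> Dict[int, int]:
    # Built while the GraphModule source is generated:
    # generated line number -> index of the node that produced that line.
    return {lineno: idx for idx, lineno in enumerate(node_linenos)}

def metadata_for_new_proxy(cur_lineno: int,
                           lineno_map: Dict[int, int],
                           orig_node_meta: List[Dict[str, Any]]) -> Dict[str, Any]:
    # When a new proxy is created, take the current instruction's line number,
    # recover the original node index, and copy that node's metadata.
    idx = lineno_map.get(cur_lineno)
    return dict(orig_node_meta[idx]) if idx is not None else {}
```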
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`
I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
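As a hedged illustration of the idea (the exact keys and plumbing are assumptions here, not necessarily the real `co_fields` contents), a small dict of code-object-like fields is enough to produce the label shown above:
```
# Hypothetical shape; the real co_fields structure may differ.
co_fields = {
    "co_name": "forward",
    "co_filename": "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py",
    "co_firstlineno": 5683,
}

def qualified_key(base_key: str, co_fields: dict) -> str:
    # Turns "<eval_with_key>.0" into
    # "<eval_with_key>.0 from <file>:<line> in <name>" when co_fields is provided.
    if not co_fields:
        return base_key
    return (f"{base_key} from {co_fields['co_filename']}:"
            f"{co_fields['co_firstlineno']} in {co_fields['co_name']}")

print(qualified_key("<eval_with_key>.0", co_fields))
```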
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
Summary:
When we pickle/unpickle a graph module in multipy, we lose modules/attributes that are not referred to in the graph. This is because, when unpickling an FX graph module, we use the stored `__dict__` and the FX graph to create a new graph module, and in `GraphModule.__init__` we drop any attribute that is not referred to in the graph.
This behavior is not ideal: we expect the graph module to come back exactly the same after unpickling.
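A small sketch of the expected behavior after this change (illustrative, not the test added in this PR):
```
import pickle
import torch
from torch import fx

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
# Attach a submodule and an attribute that the graph itself never references.
gm.unused_linear = torch.nn.Linear(2, 2)
gm.extra_flag = 42

gm2 = pickle.loads(pickle.dumps(gm))
# With this change, both should survive the pickle round trip.
assert hasattr(gm2, "unused_linear") and gm2.extra_flag == 42
```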
Test Plan:
```
buck test mode/opt caffe2/test:fx -- test_preserve_unused_attr_after_unpickle
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D46976230
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104115
Approved by: https://github.com/houseroad
Summary:
Previously `prepare_fx` returned an ObservedGraphModule and `convert_fx` returned a QuantizedGraphModule. These subclasses existed to preserve attributes that `torch.fx.GraphModule` did not preserve. After https://github.com/pytorch/pytorch/pull/92062 we preserve `model.meta`, so we can now store those attributes in `model.meta` instead.
With this, we no longer need to create a new type of GraphModule in these functions and can use GraphModule directly. This is useful for quantization in the PyTorch 2.0 flow: if other transformations also use GraphModule, the quantization passes will be composable with them.
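A minimal sketch of the pattern this enables, stashing extra state on `model.meta` so a plain `GraphModule` can carry it through later passes (the key name and payload below are made up for illustration):
```
import torch
from torch import fx

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
gm.meta["_quantization_attrs"] = {"example_key": "example_value"}  # hypothetical payload

# A later pass that rebuilds the module can carry the metadata along explicitly.
new_gm = fx.GraphModule(gm, gm.graph)
new_gm.meta.update(gm.meta)
assert "_quantization_attrs" in new_gm.meta
```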
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
python test/test_quantization.py TestQuantizePT2E
Imported from OSS
Differential Revision: D42979722
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94412
Approved by: https://github.com/vkuzo
Summary:
One such place where a circular reference can occur: `_load_state_dict_pre_hooks` contains a `_WrappedHook`, and that `_WrappedHook` holds a weakref back to the same module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93038
Approved by: https://github.com/jerryzh168
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR should make the change in every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
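For reference, the shape of the change (a representative example, not a specific hunk from this PR):
```
# Before: calling the constructor goes through a name lookup and a function call.
opts = dict()
kwargs = dict(stride=1, padding=0)

# After: the literal forms are faster and are what pylint's use-dict-literal suggests.
opts = {}
kwargs = {"stride": 1, "padding": 0}
```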
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
Before:
```
Traceback (most recent call last):
File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
return self.executor.run(*executor_args)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
return super().run(*args, initial_env=initial_env)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
self.env[node] = self.run_node(node)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
return super().call_module(target, args, kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
return submod(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e.with_traceback(None)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
After:
```
Traceback (most recent call last):
File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
return self.executor.run(*executor_args)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
return super().run(*args, initial_env=initial_env)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
self.env[node] = self.run_node(node)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
return super().call_module(target, args, kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
return submod(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 622, in wrapped_call
return super(cls, self).__call__(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "<eval_with_key>.42", line 74, in forward
File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/utils/fx.py", line 180, in wrapper
return func(*args, **kwargs)
File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/modeling_utils.py", line 256, in create_extended_attention_mask_for_decoder
causal_mask = causal_mask.to(attention_mask.dtype)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
The last lines of the stack trace now show where the problem actually is.
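A simplified sketch of the change in `wrapped_call` that produces the difference above (illustrative only; the real code dispatches through the class's `__call__`, and `_real_call` here is a hypothetical stand-in):
```
def wrapped_call(self, *args, **kwargs):
    try:
        return self._real_call(*args, **kwargs)  # hypothetical stand-in for the actual dispatch
    except Exception as e:
        # Before: `raise e.with_traceback(None)`, which wiped the frames pointing into the
        # generated <eval_with_key> code and hid the real source of the error.
        raise e  # keep the full traceback down to the failing generated line
```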
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74655
Approved by: https://github.com/ansley, https://github.com/rohan-varma
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.
```
class CodeGen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """

    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """

    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """

    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.
        See ``process_inputs`` for more details.
        """

    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return [('List', typing.List)] if you need ``List`` in the global context.
        """
```
So, for example, the `ListCodeGen` we want for AOTAutograd looks like this:
```
class ListCodeGen(CodeGen):
    def generate_prologue(self, free_vars, maybe_return_annotation):
        lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
        return lst_unpack

    def additional_globals(self):
        return [('List', typing.List)]

    def process_inputs(self, *inputs):
        assert(len(inputs) == 1)
        return inputs[0]
```
and
```
def f(a, b):
    return a + b

nf = fx.symbolic_trace(f)
nf.graph.set_codegen(ListCodeGen())
nf.recompile()
print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b; a = b = None
    return add
```
Backwards-compatibility changes: I added `process_outputs` and `process_inputs` to `fx.Graph` while removing `flatten_inputs` and `flatten_outputs`; those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566
Reviewed By: desertfire
Differential Revision: D34160424
Pulled By: Chillee
fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66430
On the whole, I'm not totally satisfied with this approach. I think we should be building a prefix-tree data structure during the initial iteration over the submodules and querying it when deleting submodules. But I think this approach works, and I want to see if we can get it in before 1.10.
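For reference, a rough sketch of the prefix-tree idea floated above (purely illustrative; the names and structure are assumptions, not code from this PR):
```
from typing import Dict

class QualNameTrie:
    """Prefix tree over dotted qualified names like 'layers.0.linear'."""
    def __init__(self) -> None:
        self.children: Dict[str, "QualNameTrie"] = {}
        self.used = False  # marks a qualified name the graph actually references

    def insert(self, qualname: str) -> None:
        node = self
        for atom in qualname.split("."):
            node = node.children.setdefault(atom, QualNameTrie())
        node.used = True

    def is_used(self, qualname: str) -> bool:
        # A submodule should be kept if it, or anything underneath it, is referenced.
        node = self
        for atom in qualname.split("."):
            if atom not in node.children:
                return False
            node = node.children[atom]
        return node.used or bool(node.children)

# Insert every target referenced by the graph once, then query when deciding
# which submodules are safe to delete.
trie = QualNameTrie()
trie.insert("layers.0.linear")
assert trie.is_used("layers.0")       # ancestor of a used leaf: keep
assert not trie.is_used("layers.1")   # nothing under it is used: safe to delete
```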
Test Plan: Imported from OSS
Reviewed By: Chillee
Differential Revision: D31546137
Pulled By: jamesr66a
fbshipit-source-id: f08b8409a3cf511277017ccccb916097b7c4c4fe