Summary:
Sometimes we only want to generate a replacement for a matched pattern
once we know some information about the nodes in the pattern.
So far, we have found this the most useful to do matches based on specific
shapes of tensors flowing into functions.
Use a callback function similar to `match_filters`. By default this isn't used.
Had to make `replacement` a None-able parameter because Callable was
already used to detect a case where a graph needed to be traced.
Differential Revision: D62412628
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135553
Approved by: https://github.com/SherlockNoMad
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Summary:
Allow doing pattern replacement with just an fx.Graph instead of a fx.GraphModule,
which can let callers avoid paying the cost of `recompile()` for a small graph if they
don't need the module.
This is a significant speedup if you use hundreds of small patterns for replacement.
Test Plan: Tested in a diff stacked on top of this: {D50756722}
Reviewed By: SherlockNoMad, angelayi
Differential Revision: D50756723
@diff-train-skip-merge
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112409
Approved by: https://github.com/ZainRizvi
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Summary: This commit adds support for conv + BN fusion for the
case where conv has no bias. Since the replacement patterns with
and without conv bias are substantially different, we perform the
replacement for each of these two cases separately.
Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_prepare_qat_conv_bn_fusion_no_conv_bias
Reviewers: jerryzh168, kimishpatel
Differential Revision: [D45743510](https://our.internmc.facebook.com/intern/diff/D45743510)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100610
Approved by: https://github.com/jerryzh168
Summary:
This PR adds support for folding bn weights into conv for QAT flow, this is equivalent
to the QAT branch of `from_float` in eager mode quantized conv module: https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/modules/conv.py#L223
Items that needs followup:
* there are some workaround I did because quantize_per_tensor is using float/int args and dynamo does not support these args, need to fix after we change the quantized model representation and also change these args to Tensor
Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_convert_qat_conv_bn_fusion (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
Reviewed By: andrewor14
Differential Revision: D45344281
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100442
Approved by: https://github.com/kimishpatel
Fixes#98974
When `torch.fx.subgraph_rewriter._replace_pattern` is used to remove nodes from a graph, if there are two adjacent matches then after the first removal, the nodes in `InternalMatch.nodes_map` and `placeholder_nodes` become outdated because they contain nodes that were just removed from the graph.
This fix is to update the `match.nodes_map` and `match.placeholder_nodes` using the node changes stored in `match_changed_node`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99039
Approved by: https://github.com/angelayi
Pattern replacement behaves incorrectly when the replacement pattern maps inputs to outputs (such a pattern can be used to replace redundant code). However, current code in `torch.fx.subgraph_rewriter._replace_pattern` causes the list of replacement nodes to include the entire graph before that node, resulting in an exponential slowdown due to recursive calls traversing the entire graph multiple times.
The proposed fix is to add a check in `_replace_pattern` to prevent the call to `get_replacement_nodes`:
```python
for ret_node in copied_returning_nodes:
if ret_node in match.placeholder_nodes:
replacement_nodes.append(ret_node)
else:
get_replacement_nodes(ret_node)
```
Fixes#97817
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97903
Approved by: https://github.com/angelayi
Summary: Modified replace_pattern in the subgraph rewriter to return a list of pairs of matches along with their corresponding replacement nodes in the modified graph (`List[Tuple[Match, List[Node]]]`). This allows us to easily modify the replaced nodes, including setting the metadata.
Test Plan: CI
Differential Revision: D41737056
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90244
Approved by: https://github.com/SherlockNoMad
Continuation after https://github.com/pytorch/pytorch/pull/90163.
Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):
_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.
``` python
import ast
import os
import docstring_parser
for root, dirs, files in os.walk('.'):
for name in files:
if root.startswith("./.git/") or root.startswith("./third_party/"):
continue
if name.endswith(".py"):
full_name = os.path.join(root, name)
with open(full_name, "r") as source:
tree = ast.parse(source.read())
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
all_node_args = node.args.args
if node.args.vararg is not None:
all_node_args.append(node.args.vararg)
if node.args.kwarg is not None:
all_node_args.append(node.args.kwarg)
if node.args.posonlyargs is not None:
all_node_args.extend(node.args.posonlyargs)
if node.args.kwonlyargs is not None:
all_node_args.extend(node.args.kwonlyargs)
args = [a.arg for a in all_node_args]
docstring = docstring_parser.parse(ast.get_docstring(node))
doc_args = [a.arg_name for a in docstring.params]
clean_doc_args = []
for a in doc_args:
clean_a = ""
for c in a.split()[0]:
if c.isalnum() or c == '_':
clean_a += c
if clean_a:
clean_doc_args.append(clean_a)
doc_args = clean_doc_args
for a in doc_args:
if a not in args:
print(full_name, node.lineno, args, doc_args)
break
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
Summary:
att, this is experimental api so not marking it as bc-breaking.
The match will be accepted only if all the filters in the list passes.
Changing the filter arg to be list also allows us to pass in empty list that means no filter, which makes user code cleaner.
Test Plan:
python test/test_fx.py -k test_replace_pattern_with_filters
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D40810943](https://our.internmc.facebook.com/intern/diff/D40810943)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87998
Approved by: https://github.com/SherlockNoMad
Summary:
att, this is experimental api so not marking it as bc-breaking.
The match will be accepted only if all the filters in the list passes.
Changing the filter arg to be list also allows us to pass in empty list that means no filter, which makes user code cleaner.
Test Plan:
python test/test_fx.py -k test_replace_pattern_with_filters
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87257
Approved by: https://github.com/SherlockNoMad
This PR introduces an interface for user defined function that filters the matches in SubgraphRewriter. The function will have the following signature.
callable(match: InternalMatch, original_graph: Graph, pattern_graph: Graph) -> bool
This filter is applied after SubgraphMatcher returns the matches, and before replacement takes place.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86430
Approved by: https://github.com/jerryzh168
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR should do the change for every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD