mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
The goal of this change is to make FX's codegen extensible. I've refactored it into a class with five extensibility points:
```
class CodeGen(object):
def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
"""
Given the free variables and a return annotation, generates the beginning of the FX function.
By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
"""
def generate_output(self, output_args: Argument) -> str:
"""
Given the output arguments, generates the return statement of the FX function.
"""
def process_inputs(self, *args: Any) -> Any:
"""
Transforms the inputs so that the graph can take them as arguments, as
non-default codegen may result in the inputs to the function being
different from the inputs to the graph.
If the graph is directly runnable, the following invariant should hold:
`f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
"""
def process_outputs(self, outputs: Any) -> Any:
"""
Transforms the outputs of the graph so that they match the outputs of the generated code.
See ``process_inputs`` for more details.
"""
def additional_globals(self) -> List[Tuple[str, Any]]:
"""
If your codegen uses extra global values, add them here.
For example, return [('List', typing.List)] if you need ``List`` in the global context.
"""
```
So, for example, the `ListCodeGen` we want for AOTAutograd looks like this:
```
class ListCodeGen(CodeGen):
def generate_prologue(self, free_vars, maybe_return_annotation):
lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
{', '.join(free_vars)} = args_list"""
return lst_unpack
def additional_globals(self):
return [('List', typing.List)]
def process_inputs(self, *inputs):
assert(len(inputs) == 1)
return inputs[0]
```
and
```
def f(a, b):
return a + b
nf = fx.symbolic_trace(f)
nf.graph.set_codegen(ListCodeGen())
nf.recompile()
print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
a, b = args_list
add = a + b; a = b = None
return add
```
Backward-compatibility changes: I added `process_outputs` and `process_inputs` to `fx.Graph` and removed `flatten_inputs` and `flatten_outputs`. Those two methods were not marked with `backwards_compatibility`, so I *think* removing them is probably fine?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566
Reviewed By: desertfire
Differential Revision: D34160424
Pulled By: Chillee
fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit
|
||
|---|---|---|
| .. | ||
| __init__.py | ||
| TestAutograd.test_function-x_grad_desc.expect | ||
| TestAutograd.test_function-y_grad_desc.expect | ||
| TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect | ||
| TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect | ||
| TestJit.test_cu_escaped_number.expect | ||
| TestJit.test_import_method.expect | ||
| TestJit.test_non_ascii_string.expect | ||
| TestJit.test_pretty_printer-empty_float_list_test.expect | ||
| TestJit.test_pretty_printer-empty_int_list_test.expect | ||
| TestJit.test_pretty_printer-if_one.expect | ||
| TestJit.test_pretty_printer-if_test.expect | ||
| TestJit.test_pretty_printer-loop_use_test.expect | ||
| TestJit.test_pretty_printer-print_weird_test.expect | ||
| TestJit.test_pretty_printer-python_op_name_test.expect | ||
| TestJit.test_pretty_printer-while_if_test.expect | ||
| TestJit.test_pretty_printer-while_test.expect | ||
| TestPytorchExportModes.test_aten_fallback.expect | ||
| TestPytorchExportModes.test_onnx_aten.expect | ||
| TestScript.test_annot_ast_mypy_fn.expect | ||
| TestScript.test_annot_ast_mypy_method.expect | ||
| TestScript.test_annot_ast_py3_fn.expect | ||
| TestScript.test_annot_ast_py3_method.expect | ||
| TestScript.test_annot_string_mypy_fn.expect | ||
| TestScript.test_annot_string_mypy_method.expect | ||
| TestScript.test_annot_string_py3_fn.expect | ||
| TestScript.test_annot_string_py3_method.expect | ||
| TestScript.test_annotated_script_fn.expect | ||
| TestScript.test_annotated_script_method.expect | ||
| TestScript.test_format-stdout.expect | ||
| TestScript.test_listconstruct_erasure.expect | ||
| TestScript.test_parser_type_annotations_comment.expect | ||
| TestScript.test_parser_type_annotations.expect | ||
| TestScript.test_print-stdout.expect | ||
| TestScript.test_python_frontend_py2.expect | ||
| TestScript.test_python_frontend_py3.expect | ||
| TestScript.test_python_frontend.expect | ||
| TestScript.test_string_print-stdout.expect | ||
| TestScript.test_torch_dot_tensor_annotation.expect | ||
| TestSparseCPU.test_print_coalesced_cpu_float64.expect | ||
| TestSparseCPU.test_print_uncoalesced_cpu_float64.expect | ||
| TestSparseCSRCPU.test_sparse_csr_print_cpu.expect | ||
| TestSparseCSRCUDA.test_sparse_csr_print_cuda.expect | ||
| TestSparseCUDA.test_print_coalesced_cuda_float64.expect | ||
| TestSparseCUDA.test_print_uncoalesced_cuda_float64.expect | ||
| TestTensorBoard.test_audio.expect | ||
| TestTensorBoard.test_caffe2_simple_cnnmodel.expect | ||
| TestTensorBoard.test_caffe2_simple_model.expect | ||
| TestTensorBoard.test_histogram_auto.expect | ||
| TestTensorBoard.test_histogram_doane.expect | ||
| TestTensorBoard.test_histogram_fd.expect | ||
| TestTensorBoard.test_hparams_bool.expect | ||
| TestTensorBoard.test_hparams_number.expect | ||
| TestTensorBoard.test_hparams_string.expect | ||
| TestTensorBoard.test_image_with_3_channel_batched.expect | ||
| TestTensorBoard.test_image_with_boxes.expect | ||
| TestTensorBoard.test_image_with_one_channel_batched.expect | ||
| TestTensorBoard.test_image_with_one_channel.expect | ||
| TestTensorBoard.test_image_without_channel.expect | ||
| TestTensorBoard.test_mesh.expect | ||
| TestTensorBoard.test_nested_nn_squential.expect | ||
| TestTensorBoard.test_pr_curve_raw.expect | ||
| TestTensorBoard.test_pr_curve.expect | ||
| TestTensorBoard.test_pytorch_graph.expect | ||
| TestTensorBoard.test_scalar_new_style.expect | ||
| TestTensorBoard.test_text.expect | ||
| TestTensorBoard.test_video.expect | ||
| TestTorch.test_is_nonzero-empty.expect | ||
| TestTorch.test_is_nonzero-multiple.expect | ||
| TestTorch.test_print-non_contiguous.expect | ||