pytorch/test/expect
Horace He d635d0f86e Refactor FX codegen into extensible Codegen object (#72566)
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.

```
class CodeGen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """
    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """
    def process_inputs(self, *args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """
    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details.
        """
    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return [('List', typing.List)] if you need ``List`` in the global context.
        """
```
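The `process_inputs` / `process_outputs` invariant from the docstring above can be illustrated without FX. This is a minimal pure-Python sketch, not the FX implementation: `ToyListCodegen`, `graph_fn`, `generated_forward`, and `invariant_holds` are hypothetical names standing in for a codegen, the directly runnable graph, the generated function, and the check itself.

```python
from typing import Any, List, Tuple


class ToyListCodegen:
    """Hypothetical codegen whose generated function takes a single list."""

    def process_inputs(self, *inputs: Any) -> Any:
        # The generated function receives one list; the graph expects the
        # elements of that list as separate positional arguments.
        assert len(inputs) == 1
        return inputs[0]

    def process_outputs(self, outputs: Any) -> Any:
        # This toy codegen leaves the outputs unchanged.
        return outputs


def graph_fn(a: int, b: int) -> int:
    """Stands in for the graph, which takes its free variables directly."""
    return a + b


def generated_forward(args_list: List[int]) -> int:
    """Stands in for the code emitted by the non-default codegen."""
    a, b = args_list
    return a + b


def invariant_holds(codegen: ToyListCodegen, inputs: Tuple[Any, ...]) -> bool:
    # process_outputs(graph(*process_inputs(*inputs))) == f(*inputs)
    via_graph = codegen.process_outputs(graph_fn(*codegen.process_inputs(*inputs)))
    return via_graph == generated_forward(*inputs)


print(invariant_holds(ToyListCodegen(), ([1, 2],)))  # True
```

Here `inputs` is the argument tuple of the generated function, so a single list argument is written as `([1, 2],)`.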

So, for example, the `ListCodeGen` we want for AOTAutograd looks like this:
```
        class ListCodeGen(CodeGen):
            def generate_prologue(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]
```
and
```
        def f(a, b):
            return a + b

        nf = fx.symbolic_trace(f)
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b;  a = b = None
    return add
```
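To see how a custom prologue and `additional_globals` fit together, here is a rough sketch that does not depend on FX. It is a toy under stated assumptions, not how `torch.fx` compiles code: `ToyCodeGen`, `ToyListCodeGen`, and `compile_forward` are hypothetical names, the body lines are passed in by hand, and the generated `forward` is a plain function without `self`.

```python
import typing
from typing import Any, Callable, List, Tuple


class ToyCodeGen:
    """Hypothetical base class with two of the extensibility points."""

    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:"

    def additional_globals(self) -> List[Tuple[str, Any]]:
        return []


class ToyListCodeGen(ToyCodeGen):
    """Takes all arguments as one list and unpacks them in the prologue."""

    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        return (
            f"def forward(args_list: List[int]){maybe_return_annotation}:\n"
            f"    {', '.join(free_vars)} = args_list"
        )

    def additional_globals(self) -> List[Tuple[str, Any]]:
        # The prologue mentions ``List``, so it must exist in the globals
        # that the generated source is executed against.
        return [("List", typing.List)]


def compile_forward(
    codegen: ToyCodeGen, free_vars: List[str], body_lines: List[str]
) -> Tuple[str, Callable[..., Any]]:
    # Assemble prologue + indented body, then exec with the extra globals.
    src = codegen.generate_prologue(free_vars, "") + "\n"
    src += "\n".join("    " + line for line in body_lines) + "\n"
    globals_dict = dict(codegen.additional_globals())
    exec(src, globals_dict)
    return src, globals_dict["forward"]


src, forward = compile_forward(
    ToyListCodeGen(), ["a", "b"], ["add = a + b", "return add"]
)
print(src)
print(forward([1, 2]))  # 3
```

Only the prologue and the globals differ between the two toy codegens; the body lines are identical, which is the point of keeping these hooks separate.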

Backwards compatibility changes: I added `process_outputs` and `process_inputs` to `fx.Graph` and removed `flatten_inputs` and `flatten_outputs`. Those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566

Reviewed By: desertfire

Differential Revision: D34160424

Pulled By: Chillee

fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
2022-02-11 18:13:29 +00:00
__init__.py remediation of S205607 2020-07-17 17:19:47 -07:00
TestAutograd.test_function-x_grad_desc.expect simplify op name determination into a single forward pass (#64261) 2021-09-02 07:32:11 -07:00
TestAutograd.test_function-y_grad_desc.expect simplify op name determination into a single forward pass (#64261) 2021-09-02 07:32:11 -07:00
TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect Refactor FX codegen into extensible Codegen object (#72566) 2022-02-11 18:13:29 +00:00
TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect [FX] Add a default_value arg to Graph.placeholder and fix split_module (#71016) 2022-01-12 14:03:17 -08:00
TestJit.test_cu_escaped_number.expect [JIT][write path] Make NoneType annotation_str emit NoneType instead of None (#54746) 2021-04-12 17:36:45 -07:00
TestJit.test_import_method.expect Enable backward/forward compatibility for TS runtime (#57498) 2021-05-07 15:41:45 -07:00
TestJit.test_non_ascii_string.expect Generalize constant_table from tensor only to ivalue (#40718) 2020-07-09 09:09:40 -07:00
TestJit.test_pretty_printer-empty_float_list_test.expect
TestJit.test_pretty_printer-empty_int_list_test.expect
TestJit.test_pretty_printer-if_one.expect
TestJit.test_pretty_printer-if_test.expect
TestJit.test_pretty_printer-loop_use_test.expect Enable backward/forward compatibility for TS runtime (#57498) 2021-05-07 15:41:45 -07:00
TestJit.test_pretty_printer-print_weird_test.expect [JIT][write path] Make NoneType annotation_str emit NoneType instead of None (#54746) 2021-04-12 17:36:45 -07:00
TestJit.test_pretty_printer-python_op_name_test.expect
TestJit.test_pretty_printer-while_if_test.expect Enable backward/forward compatibility for TS runtime (#57498) 2021-05-07 15:41:45 -07:00
TestJit.test_pretty_printer-while_test.expect Enable backward/forward compatibility for TS runtime (#57498) 2021-05-07 15:41:45 -07:00
TestPytorchExportModes.test_aten_fallback.expect
TestPytorchExportModes.test_onnx_aten.expect
TestScript.test_annot_ast_mypy_fn.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_ast_mypy_method.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_ast_py3_fn.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_ast_py3_method.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_string_mypy_fn.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_string_mypy_method.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_string_py3_fn.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annot_string_py3_method.expect Support Union in TorchScript (#64234) 2021-09-03 06:12:24 -07:00
TestScript.test_annotated_script_fn.expect
TestScript.test_annotated_script_method.expect Serialize ClassType as its qualname 2019-11-20 16:17:26 -08:00
TestScript.test_format-stdout.expect
TestScript.test_listconstruct_erasure.expect
TestScript.test_parser_type_annotations_comment.expect
TestScript.test_parser_type_annotations.expect
TestScript.test_print-stdout.expect Revert "Revert D18171156: Merge Tensor and Variable." (#29299) 2019-11-08 09:11:20 -08:00
TestScript.test_python_frontend_py2.expect
TestScript.test_python_frontend_py3.expect Support custom exception message (#41907) 2020-08-01 13:03:45 -07:00
TestScript.test_python_frontend.expect add support for multiple assignment statements (#24477) 2019-08-22 10:17:14 -07:00
TestScript.test_string_print-stdout.expect Revert "Revert D18171156: Merge Tensor and Variable." (#29299) 2019-11-08 09:11:20 -08:00
TestScript.test_torch_dot_tensor_annotation.expect
TestSparseCPU.test_print_coalesced_cpu_float64.expect Eliminate global usage of torch.set_default_dtype in sparse test (#56393) 2021-04-27 15:23:14 -07:00
TestSparseCPU.test_print_uncoalesced_cpu_float64.expect Eliminate global usage of torch.set_default_dtype in sparse test (#56393) 2021-04-27 15:23:14 -07:00
TestSparseCSRCPU.test_sparse_csr_print_cpu.expect CUDA support in the CSR layout: constructors (#59010) 2021-05-26 16:39:43 -07:00
TestSparseCSRCUDA.test_sparse_csr_print_cuda.expect CUDA support in the CSR layout: constructors (#59010) 2021-05-26 16:39:43 -07:00
TestSparseCUDA.test_print_coalesced_cuda_float64.expect Eliminate global usage of torch.set_default_dtype in sparse test (#56393) 2021-04-27 15:23:14 -07:00
TestSparseCUDA.test_print_uncoalesced_cuda_float64.expect Eliminate global usage of torch.set_default_dtype in sparse test (#56393) 2021-04-27 15:23:14 -07:00
TestTensorBoard.test_audio.expect
TestTensorBoard.test_caffe2_simple_cnnmodel.expect
TestTensorBoard.test_caffe2_simple_model.expect
TestTensorBoard.test_histogram_auto.expect
TestTensorBoard.test_histogram_doane.expect
TestTensorBoard.test_histogram_fd.expect
TestTensorBoard.test_hparams_bool.expect [tensorboard] Let hparam render values correctly (#31544) 2020-05-08 00:05:16 -07:00
TestTensorBoard.test_hparams_number.expect [tensorboard] Let hparam render values correctly (#31544) 2020-05-08 00:05:16 -07:00
TestTensorBoard.test_hparams_string.expect [tensorboard] Let hparam render values correctly (#31544) 2020-05-08 00:05:16 -07:00
TestTensorBoard.test_image_with_3_channel_batched.expect
TestTensorBoard.test_image_with_boxes.expect
TestTensorBoard.test_image_with_one_channel_batched.expect
TestTensorBoard.test_image_with_one_channel.expect
TestTensorBoard.test_image_without_channel.expect
TestTensorBoard.test_mesh.expect added mesh plugin (#24039) 2019-08-09 10:22:43 -07:00
TestTensorBoard.test_nested_nn_squential.expect Remove some unnecessary python functional wrappers (#61608) 2022-02-01 16:59:26 +00:00
TestTensorBoard.test_pr_curve_raw.expect
TestTensorBoard.test_pr_curve.expect
TestTensorBoard.test_pytorch_graph.expect [jit] Set debug name for value coming out of GetAttr nodes. (#59123) 2021-06-09 12:24:55 -07:00
TestTensorBoard.test_scalar_new_style.expect [TB] Support writing new style scalar (#53496) 2021-03-12 19:03:13 -08:00
TestTensorBoard.test_text.expect
TestTensorBoard.test_video.expect
TestTorch.test_is_nonzero-empty.expect Clean up error handling in is_nonzero and where in TensorCompare.cpp (#38150) 2020-05-13 20:19:40 -07:00
TestTorch.test_is_nonzero-multiple.expect Clean up error handling in is_nonzero and where in TensorCompare.cpp (#38150) 2020-05-13 20:19:40 -07:00
TestTorch.test_print-non_contiguous.expect