From a76bb5d84d017cc2cd8b95e13b4ea2637457bcc3 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Wed, 22 Nov 2023 22:13:48 +0000 Subject: [PATCH] Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272) This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model. The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model. However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output. This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272 Approved by: https://github.com/titaiwangms, https://github.com/BowenBao --- test/onnx/onnx_test_common.py | 10 +++--- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 31 ++++++++++++++++ .../fx/torch_export_graph_extractor.py | 4 +++ torch/onnx/_internal/io_adapter.py | 36 +++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index f664e7e84a4..0be906fada7 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort( ref_input_args = input_args ref_input_kwargs = input_kwargs - ref_outputs = onnx_program.adapt_torch_outputs_to_onnx( - ref_model(*ref_input_args, **ref_input_kwargs) - ) - + # ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. + # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. + # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() ort_outputs = onnx_program(*input_args, **input_kwargs) + ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs) + ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs) if len(ref_outputs) != len(ort_outputs): raise AssertionError( f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}" ) + for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close( ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 26fa6f215be..dcede4718a8 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -942,6 +942,37 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime): loaded_exported_program, (x,), skip_dynamic_shapes_check=True ) + @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( + "Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. " + "github issue: https://github.com/pytorch/pytorch/issues/114406" + ) + def test_exported_program_as_input_lifting_buffers_mutation(self): + for persistent in (True, False): + + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "my_buffer", torch.tensor(4.0), persistent=persistent + ) + + def forward(self, x, b): + output = x + b + ( + self.my_buffer.add_(1.0) + 3.0 + ) # Mutate buffer through in-place addition + return output + + inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3)) + model = CustomModule() + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + # Buffer will be mutated after the first iteration + self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( + model, inputs, skip_dynamic_shapes_check=True + ) + def _parameterized_class_attrs_and_values_with_fake_options(): input_values = [] diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index 51c31560b14..5f1fbb5c748 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -63,6 +63,10 @@ class TorchExport(exporter.FXGraphExtractor): # tensor, etc), we flatten the collection and register each element as output. options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) + options.fx_tracer.output_adapter.append_step( + io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model) + ) + # Export FX graph to ONNX ModelProto. return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value] diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 45134505000..28db50a5b58 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -550,3 +550,39 @@ class PrependParamsAndBuffersAotAutogradInputStep(InputAdaptStep): if model_kwargs: return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs) return updated_args, {} + + +class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): + """Prepend model's mutated buffers to the user output. + + :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they + must be added to the user output after the model is executed. + + Args: + model: The PyTorch model with mutated buffers. + """ + + def __init__(self, model: torch_export.ExportedProgram): + assert isinstance( + model, torch_export.ExportedProgram + ), "'model' must be a torch.export.ExportedProgram." + self.model = model + + def apply(self, model_outputs: Any) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + + Returns: + flattened_outputs: The flattened model outputs. + """ + + ordered_buffers = tuple( + self.model.state_dict[name] + for name in self.model.graph_signature.buffers_to_mutate.values() + ) + + # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. + updated_outputs = (*ordered_buffers, *model_outputs) + return updated_outputs