diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index f664e7e84a4..0be906fada7 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort(
     ref_input_args = input_args
     ref_input_kwargs = input_kwargs

-    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
-        ref_model(*ref_input_args, **ref_input_kwargs)
-    )
-
+    # ONNXProgram holds a reference (not a copy) to the original ref_model, including its state_dict.
+    # Thus, ONNXProgram() must run before ref_model() so that ref_model.forward() cannot first mutate
+    # buffers in the shared state_dict that ONNXProgram.__call__() would then consume.
    ort_outputs = onnx_program(*input_args, **input_kwargs)
+    ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
+    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)

     if len(ref_outputs) != len(ort_outputs):
         raise AssertionError(
             f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
         )
+
     for ref_output, ort_output in zip(ref_outputs, ort_outputs):
         torch.testing.assert_close(
             ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 26fa6f215be..dcede4718a8 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -942,6 +942,37 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
             loaded_exported_program, (x,), skip_dynamic_shapes_check=True
         )

+    @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
+        "Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. "
+        "github issue: https://github.com/pytorch/pytorch/issues/114406"
+    )
+    def test_exported_program_as_input_lifting_buffers_mutation(self):
+        for persistent in (True, False):
+
+            class CustomModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.register_buffer(
+                        "my_buffer", torch.tensor(4.0), persistent=persistent
+                    )
+
+                def forward(self, x, b):
+                    output = x + b + (
+                        self.my_buffer.add_(1.0) + 3.0
+                    )  # Mutate buffer through in-place addition
+                    return output
+
+            inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3))
+            model = CustomModule()
+            self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+                model, inputs, skip_dynamic_shapes_check=True
+            )
+            # Buffer will be mutated after the first iteration
+            self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
+                model, inputs, skip_dynamic_shapes_check=True
+            )
+

 def _parameterized_class_attrs_and_values_with_fake_options():
     input_values = []
diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py
index 51c31560b14..5f1fbb5c748 100644
--- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py
+++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py
@@ -63,6 +63,10 @@ class TorchExport(exporter.FXGraphExtractor):
         # tensor, etc), we flatten the collection and register each element as output.
         options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())

+        options.fx_tracer.output_adapter.append_step(
+            io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model)
+        )
+
         # Export FX graph to ONNX ModelProto.
         return self.pre_export_passes(options, model, model.graph_module, updated_model_args)  # type: ignore[return-value]
diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py
index 45134505000..28db50a5b58 100644
--- a/torch/onnx/_internal/io_adapter.py
+++ b/torch/onnx/_internal/io_adapter.py
@@ -550,3 +550,39 @@ class PrependParamsAndBuffersAotAutogradInputStep(InputAdaptStep):
         if model_kwargs:
             return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs)
         return updated_args, {}
+
+
+class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
+    """Prepend the model's mutated buffers to the user outputs.
+
+    :func:`torch.export.export` lifts a model's mutated buffers into graph outputs;
+    thus, they must be prepended to the user outputs after the model is executed.
+
+    Args:
+        model: The PyTorch model with mutated buffers.
+    """
+
+    def __init__(self, model: torch_export.ExportedProgram):
+        assert isinstance(
+            model, torch_export.ExportedProgram
+        ), "'model' must be a torch.export.ExportedProgram."
+        self.model = model
+
+    def apply(self, model_outputs: Any) -> Sequence[Any]:
+        """Prepend the model's mutated buffers to the model outputs.
+
+        Args:
+            model_outputs: The model outputs to prepend the mutated buffers to.
+
+        Returns:
+            The model outputs with the mutated buffers prepended.
+        """
+
+        ordered_buffers = tuple(
+            self.model.state_dict[name]
+            for name in self.model.graph_signature.buffers_to_mutate.values()
+        )
+
+        # NOTE: The calling convention is: first the mutated buffers, then the output args as the model returned them.
+        updated_outputs = (*ordered_buffers, *model_outputs)
+        return updated_outputs
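
Editor's note, for readers outside the PR context: the mechanism the new output-adapt step relies on can be reproduced in isolation. The sketch below is not part of the patch; the module name `M` is a made-up example, while `torch.export.export` and `graph_signature.buffers_to_mutate` are the same public APIs the patch uses. It shows an in-place buffer update being lifted into an extra graph output, which is exactly what `PrependParamsAndBuffersAotAutogradOutputStep` mirrors on the eager side so reference and ONNX outputs can be compared element-wise:

# Illustrative sketch only (not part of the patch).
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("my_buffer", torch.tensor(4.0))

    def forward(self, x):
        # In-place mutation: torch.export lifts the new buffer value
        # into an additional graph output alongside the user output.
        return x + self.my_buffer.add_(1.0)


ep = torch.export.export(M(), (torch.randn(3),))
# Maps mutated graph-output names to buffer names, e.g. {'add': 'my_buffer'}
# (exact key names depend on the traced graph). Iterating .values(), as the
# new OutputAdaptStep does, yields the buffers to prepend to the user outputs.
print(ep.graph_signature.buffers_to_mutate)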