Add decomposition for dynamo_export + ExportedProgram and remove None from input (#112444)

This PR introduces the ability to produce GraphModules with Core ATen IR only through decompositions. It also removes `None` from user inputs as ONNX does not supports them

Tests for these features will be executed when #112289 is merged, but for reference, they are as below:

```python
    def test_log_sigmoid(self):
        # This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more
        # conventional `torch.ops.aten.log_sigmoid`.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = torch.nn.LogSigmoid()

            def forward(self, x):
                return self.m(x)

        input = torch.randn(2)
        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            Model(), (input,), model_type=self.model_type
        )

    def test_none_input(self):
        class NoneInputModel(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor
            ):
                if y is None:
                    return x + z
                return x + y + z

        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            NoneInputModel(),
            (torch.randn(1, 2), None, torch.randn(1, 2)),
            model_type=self.model_type,
        )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112444
Approved by: https://github.com/BowenBao
This commit is contained in:
Thiago Crepaldi 2023-11-01 23:54:06 +00:00 committed by PyTorch MergeBot
parent 6c19de07cd
commit 01e4984bac

View File

@ -43,6 +43,8 @@ class TorchExport(exporter.FXGraphExtractor):
# kwargs=model_kwargs, # type: ignore[arg-type]
# )
model = model.run_decompositions(options.decomposition_table)
# Export FX graph to ONNX ModelProto.
self.input_adapter.append_step(
io_adapter.FlattenInputWithTreeSpecValidationInputStep()
@ -50,6 +52,11 @@ class TorchExport(exporter.FXGraphExtractor):
self.input_adapter.append_step(
io_adapter.PrependParamsAndBuffersAotAutogradInputStep(model)
)
# ONNX does not support None inputs. During graph building, all None inputs
# are removed. Here we register this step to input adapter.
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
updated_model_args = self.input_adapter.apply(*model_args, **model_kwargs)
# ONNX can't represent collection types (e.g., dictionary, tuple of tuple of