mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add decomposition for dynamo_export + ExportedProgram and remove None from input (#112444)
This PR introduces the ability to produce GraphModules with Core ATen IR only through decompositions. It also removes `None` from user inputs as ONNX does not supports them Tests for these features will be executed when #112289 is merged, but for reference, they are as below: ```python def test_log_sigmoid(self): # This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more # conventional `torch.ops.aten.log_sigmoid`. class Model(torch.nn.Module): def __init__(self): super().__init__() self.m = torch.nn.LogSigmoid() def forward(self, x): return self.m(x) input = torch.randn(2) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( Model(), (input,), model_type=self.model_type ) def test_none_input(self): class NoneInputModel(torch.nn.Module): def forward( self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor ): if y is None: return x + z return x + y + z self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( NoneInputModel(), (torch.randn(1, 2), None, torch.randn(1, 2)), model_type=self.model_type, ) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/112444 Approved by: https://github.com/BowenBao
This commit is contained in:
parent
6c19de07cd
commit
01e4984bac
|
|
@ -43,6 +43,8 @@ class TorchExport(exporter.FXGraphExtractor):
|
|||
# kwargs=model_kwargs, # type: ignore[arg-type]
|
||||
# )
|
||||
|
||||
model = model.run_decompositions(options.decomposition_table)
|
||||
|
||||
# Export FX graph to ONNX ModelProto.
|
||||
self.input_adapter.append_step(
|
||||
io_adapter.FlattenInputWithTreeSpecValidationInputStep()
|
||||
|
|
@ -50,6 +52,11 @@ class TorchExport(exporter.FXGraphExtractor):
|
|||
self.input_adapter.append_step(
|
||||
io_adapter.PrependParamsAndBuffersAotAutogradInputStep(model)
|
||||
)
|
||||
|
||||
# ONNX does not support None inputs. During graph building, all None inputs
|
||||
# are removed. Here we register this step to input adapter.
|
||||
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
|
||||
|
||||
updated_model_args = self.input_adapter.apply(*model_args, **model_kwargs)
|
||||
|
||||
# ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user