From 17002d25c52f20c89b155881f8923f09bccfc2d0 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 17 Oct 2023 21:01:37 +0000 Subject: [PATCH] [export] Remove call_spec argument from ExportedProgram ctor. (#111407) Summary: call_spec arg is not used anymore. Test Plan: CI Reviewed By: SherlockNoMad, tugsbayasgalan Differential Revision: D50335365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111407 Approved by: https://github.com/izaitsevfb --- torch/_export/__init__.py | 3 --- torch/_export/serde/schema.py | 2 +- torch/_export/serde/serialize.py | 1 - torch/export/exported_program.py | 3 --- 4 files changed, 1 insertion(+), 8 deletions(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 9984d7def2f..e61e4339db7 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -748,8 +748,6 @@ def _export( gm, gm.graph, export_graph_signature, - # TODO(zhxchen17) Remove this field. - CallSpec(in_spec, orig_out_spec), # TODO(zhxchen17) Return empty state_dict for functions. params_buffers, range_constraints, @@ -940,7 +938,6 @@ def aot_compile( ) for o in user_outputs ] ), - call_spec=copy.deepcopy(ep.call_spec), state_dict={}, range_constraints=copy.deepcopy(ep.range_constraints), equality_constraints=copy.deepcopy(ep.equality_constraints), diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 041e48143eb..f2a7dc8bdab 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -257,7 +257,7 @@ class GradientToUserInputSpec: @dataclass class OutputSpec(_Union): user_output: UserOutputSpec - loss_outout: LossOutputSpec + loss_output: LossOutputSpec buffer_mutation: BufferMutationSpec gradient_to_parameter: GradientToParameterSpec gradient_to_user_input: GradientToUserInputSpec diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 98a5ddafd38..496f86e23c7 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1524,7 +1524,6 @@ class ExportedProgramDeserializer: res.graph_module, res.graph_module.graph, res.signature, - None, # TODO(zhxchen17) Remove this. state_dict, # type: ignore[arg-type] range_constraints, equality_constraints, diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f0079f2e581..21cc8389916 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -425,7 +425,6 @@ class ExportedProgram: root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, graph_signature: ExportGraphSignature, - call_spec: Any, state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]], range_constraints: Dict[sympy.Symbol, Any], equality_constraints: List[Tuple[Any, Any]], @@ -772,7 +771,6 @@ class ExportedProgram: gm, gm.graph, new_graph_signature, - copy.deepcopy(self.call_spec), self.state_dict, new_range_constraints, new_equality_constraints, @@ -889,7 +887,6 @@ class ExportedProgram: transformed_gm, transformed_gm.graph, _get_updated_graph_signature(self.graph_signature, transformed_gm), - copy.deepcopy(self.call_spec), self.state_dict, _get_updated_range_constraints(transformed_gm), copy.deepcopy(self.equality_constraints),