mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[export] Remove call_spec argument from ExportedProgram ctor. (#111407)
Summary: call_spec arg is not used anymore. Test Plan: CI Reviewed By: SherlockNoMad, tugsbayasgalan Differential Revision: D50335365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111407 Approved by: https://github.com/izaitsevfb
This commit is contained in:
parent
2bb1692334
commit
17002d25c5
|
|
@ -748,8 +748,6 @@ def _export(
|
|||
gm,
|
||||
gm.graph,
|
||||
export_graph_signature,
|
||||
# TODO(zhxchen17) Remove this field.
|
||||
CallSpec(in_spec, orig_out_spec),
|
||||
# TODO(zhxchen17) Return empty state_dict for functions.
|
||||
params_buffers,
|
||||
range_constraints,
|
||||
|
|
@ -940,7 +938,6 @@ def aot_compile(
|
|||
) for o in user_outputs
|
||||
]
|
||||
),
|
||||
call_spec=copy.deepcopy(ep.call_spec),
|
||||
state_dict={},
|
||||
range_constraints=copy.deepcopy(ep.range_constraints),
|
||||
equality_constraints=copy.deepcopy(ep.equality_constraints),
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ class GradientToUserInputSpec:
|
|||
@dataclass
|
||||
class OutputSpec(_Union):
|
||||
user_output: UserOutputSpec
|
||||
loss_outout: LossOutputSpec
|
||||
loss_output: LossOutputSpec
|
||||
buffer_mutation: BufferMutationSpec
|
||||
gradient_to_parameter: GradientToParameterSpec
|
||||
gradient_to_user_input: GradientToUserInputSpec
|
||||
|
|
|
|||
|
|
@ -1524,7 +1524,6 @@ class ExportedProgramDeserializer:
|
|||
res.graph_module,
|
||||
res.graph_module.graph,
|
||||
res.signature,
|
||||
None, # TODO(zhxchen17) Remove this.
|
||||
state_dict, # type: ignore[arg-type]
|
||||
range_constraints,
|
||||
equality_constraints,
|
||||
|
|
|
|||
|
|
@ -425,7 +425,6 @@ class ExportedProgram:
|
|||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: torch.fx.Graph,
|
||||
graph_signature: ExportGraphSignature,
|
||||
call_spec: Any,
|
||||
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
|
||||
range_constraints: Dict[sympy.Symbol, Any],
|
||||
equality_constraints: List[Tuple[Any, Any]],
|
||||
|
|
@ -772,7 +771,6 @@ class ExportedProgram:
|
|||
gm,
|
||||
gm.graph,
|
||||
new_graph_signature,
|
||||
copy.deepcopy(self.call_spec),
|
||||
self.state_dict,
|
||||
new_range_constraints,
|
||||
new_equality_constraints,
|
||||
|
|
@ -889,7 +887,6 @@ class ExportedProgram:
|
|||
transformed_gm,
|
||||
transformed_gm.graph,
|
||||
_get_updated_graph_signature(self.graph_signature, transformed_gm),
|
||||
copy.deepcopy(self.call_spec),
|
||||
self.state_dict,
|
||||
_get_updated_range_constraints(transformed_gm),
|
||||
copy.deepcopy(self.equality_constraints),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user