[aps] skip version check for export IR. (#140573)

Summary: mitigating potential export compatibility issue for production (temporarily).

Test Plan: CI

Differential Revision: D65890958

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140573
Approved by: https://github.com/desertfire
This commit is contained in:
Zhengxu Chen 2024-11-14 17:13:42 +00:00 committed by PyTorch MergeBot
parent dcf22fa58c
commit add6bb2e96

View File

@ -2315,16 +2315,19 @@ class ExportedProgramDeserializer(metaclass=Final):
state_dict: Union[Dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, torch.Tensor], bytes],
example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
assert isinstance(exported_program, ExportedProgram)
version = exported_program.schema_version
# TODO(zhxchen17) blocked on thrift schema refactor
if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0):
raise SerializeError(
f"Serialized schema version {exported_program.schema_version} "
f"does not match our current schema version {SCHEMA_VERSION}."
)
if not _unsafe_skip_version_check:
raise SerializeError(
f"Serialized schema version {exported_program.schema_version} "
f"does not match our current schema version {SCHEMA_VERSION}."
)
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(
@ -2448,6 +2451,8 @@ def _dict_to_dataclass(cls, data):
def deserialize(
artifact: SerializedArtifact,
expected_opset_version: Optional[Dict[str, int]] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
assert isinstance(artifact.exported_program, bytes)
exported_program_str = artifact.exported_program.decode("utf-8")
@ -2460,6 +2465,7 @@ def deserialize(
artifact.state_dict,
artifact.constants,
artifact.example_inputs,
_unsafe_skip_version_check=_unsafe_skip_version_check,
)
)