mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[aps] skip version check for export IR. (#140573)
Summary: mitigating potential export compatibility issue for production (temporarily). Test Plan: CI Differential Revision: D65890958 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140573 Approved by: https://github.com/desertfire
This commit is contained in:
parent
dcf22fa58c
commit
add6bb2e96
|
|
@ -2315,16 +2315,19 @@ class ExportedProgramDeserializer(metaclass=Final):
|
||||||
state_dict: Union[Dict[str, torch.Tensor], bytes],
|
state_dict: Union[Dict[str, torch.Tensor], bytes],
|
||||||
constants: Union[Dict[str, torch.Tensor], bytes],
|
constants: Union[Dict[str, torch.Tensor], bytes],
|
||||||
example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
|
example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
|
||||||
|
*,
|
||||||
|
_unsafe_skip_version_check=False,
|
||||||
) -> ep.ExportedProgram:
|
) -> ep.ExportedProgram:
|
||||||
assert isinstance(exported_program, ExportedProgram)
|
assert isinstance(exported_program, ExportedProgram)
|
||||||
version = exported_program.schema_version
|
version = exported_program.schema_version
|
||||||
|
|
||||||
# TODO(zhxchen17) blocked on thrift schema refactor
|
# TODO(zhxchen17) blocked on thrift schema refactor
|
||||||
if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0):
|
if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0):
|
||||||
raise SerializeError(
|
if not _unsafe_skip_version_check:
|
||||||
f"Serialized schema version {exported_program.schema_version} "
|
raise SerializeError(
|
||||||
f"does not match our current schema version {SCHEMA_VERSION}."
|
f"Serialized schema version {exported_program.schema_version} "
|
||||||
)
|
f"does not match our current schema version {SCHEMA_VERSION}."
|
||||||
|
)
|
||||||
|
|
||||||
symbol_name_to_range = {
|
symbol_name_to_range = {
|
||||||
k: symbolic_shapes.ValueRanges(
|
k: symbolic_shapes.ValueRanges(
|
||||||
|
|
@ -2448,6 +2451,8 @@ def _dict_to_dataclass(cls, data):
|
||||||
def deserialize(
|
def deserialize(
|
||||||
artifact: SerializedArtifact,
|
artifact: SerializedArtifact,
|
||||||
expected_opset_version: Optional[Dict[str, int]] = None,
|
expected_opset_version: Optional[Dict[str, int]] = None,
|
||||||
|
*,
|
||||||
|
_unsafe_skip_version_check=False,
|
||||||
) -> ep.ExportedProgram:
|
) -> ep.ExportedProgram:
|
||||||
assert isinstance(artifact.exported_program, bytes)
|
assert isinstance(artifact.exported_program, bytes)
|
||||||
exported_program_str = artifact.exported_program.decode("utf-8")
|
exported_program_str = artifact.exported_program.decode("utf-8")
|
||||||
|
|
@ -2460,6 +2465,7 @@ def deserialize(
|
||||||
artifact.state_dict,
|
artifact.state_dict,
|
||||||
artifact.constants,
|
artifact.constants,
|
||||||
artifact.example_inputs,
|
artifact.example_inputs,
|
||||||
|
_unsafe_skip_version_check=_unsafe_skip_version_check,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user