diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index f3c1ec6cd36..a08ede5fcd8 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -2315,16 +2315,19 @@ class ExportedProgramDeserializer(metaclass=Final): state_dict: Union[Dict[str, torch.Tensor], bytes], constants: Union[Dict[str, torch.Tensor], bytes], example_inputs: Optional[Union[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, + *, + _unsafe_skip_version_check=False, ) -> ep.ExportedProgram: assert isinstance(exported_program, ExportedProgram) version = exported_program.schema_version # TODO(zhxchen17) blocked on thrift schema refactor if version.major != SCHEMA_VERSION[0] and not (version.major == 0 and version.minor == 0): - raise SerializeError( - f"Serialized schema version {exported_program.schema_version} " - f"does not match our current schema version {SCHEMA_VERSION}." - ) + if not _unsafe_skip_version_check: + raise SerializeError( + f"Serialized schema version {exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) symbol_name_to_range = { k: symbolic_shapes.ValueRanges( @@ -2448,6 +2451,8 @@ def _dict_to_dataclass(cls, data): def deserialize( artifact: SerializedArtifact, expected_opset_version: Optional[Dict[str, int]] = None, + *, + _unsafe_skip_version_check=False, ) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) exported_program_str = artifact.exported_program.decode("utf-8") @@ -2460,6 +2465,7 @@ def deserialize( artifact.state_dict, artifact.constants, artifact.example_inputs, + _unsafe_skip_version_check=_unsafe_skip_version_check, ) )