diff --git a/scripts/export/update_schema.py b/scripts/export/update_schema.py index 798e792cce8..bc76e4b7bfc 100644 --- a/scripts/export/update_schema.py +++ b/scripts/export/update_schema.py @@ -58,7 +58,7 @@ if __name__ == "__main__": first_line = ( "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" ) - checksum = f"checksum<<{commit.checksum_result}>>" + checksum = f"checksum<<{commit.checksum_next}>>" yaml_header = "# " + first_line yaml_header += "\n# " + checksum yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False) @@ -73,7 +73,7 @@ if __name__ == "__main__": yaml_content = yaml_header + "\n" + yaml_payload thrift_schema = "// " + first_line - thrift_schema += "\n// " + checksum + thrift_schema += f"\n// checksum<<{commit.thrift_checksum_next}>>" thrift_schema += "\n" + commit.thrift_schema if args.dry_run: diff --git a/setup.py b/setup.py index bac2a7ac534..24c8c19b9bb 100644 --- a/setup.py +++ b/setup.py @@ -1338,6 +1338,7 @@ def main(): "_inductor/codegen/*.h", "_inductor/codegen/aoti_runtime/*.cpp", "_export/serde/*.yaml", + "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", "share/cmake/Caffe2/*.cmake", "share/cmake/Caffe2/public/*.cmake", diff --git a/test/export/test_schema.py b/test/export/test_schema.py index b8b72656c94..fef9ee796d5 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -26,7 +26,27 @@ Example(s): except SchemaUpdateError as e: self.fail(f"Failed to update schema: {e}\n{msg}") - self.assertEqual(commit.checksum_base, commit.checksum_result, msg) + self.assertEqual(commit.checksum_head, commit.checksum_next, msg) + + def test_thrift_schema_unchanged(self): + msg = """ +Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script: +Example(s): + python scripts/export/update_schema.py --prefix + """ + + if IS_FBCODE: + msg += """or + buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ + """ + + try: + commit = update_schema() + except SchemaUpdateError as e: + self.fail(f"Failed to update schema: {e}\n{msg}") + + self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg) + self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg) def test_schema_diff(self): additions, subtractions = _diff_schema( @@ -105,14 +125,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -141,14 +164,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -180,14 +206,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -242,14 +271,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -274,14 +306,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -313,14 +348,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) @@ -349,14 +387,17 @@ Example(s): commit = _Commit( result=src, - checksum_result="", + checksum_next="", yaml_path="", additions=additions, subtractions=subtractions, base=dst, - checksum_base="", + checksum_head="", cpp_header="", cpp_header_path="", + thrift_checksum_head="", + thrift_checksum_real="", + thrift_checksum_next="", thrift_schema="", thrift_schema_path="", ) diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index b9fa2552a6a..ad6d3a57893 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -58,8 +58,8 @@ class Device: @dataclass(repr=False) class SymExprHint(_Union): as_int: Annotated[int, 10] - as_float: Annotated[float, 20] - as_bool: Annotated[bool, 30] + as_bool: Annotated[bool, 20] + as_float: Annotated[float, 30] # This is for storing the symbolic expressions behind symints/symfloats/symbools diff --git a/torch/_export/serde/schema.thrift b/torch/_export/serde/schema.thrift index 6a61e36402a..8d986f06c9e 100644 --- a/torch/_export/serde/schema.thrift +++ b/torch/_export/serde/schema.thrift @@ -1,7 +1,7 @@ // @generated by update_schema.py -// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> +// checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>> -namespace py3 torch._export.schema +namespace py3 torch._export namespace cpp2 torch._export.schema enum Layout { @@ -51,8 +51,8 @@ struct Device { union SymExprHint { 10: i64 as_int; - 20: double as_float; - 30: bool as_bool; + 20: bool as_bool; + 30: double as_float; } struct SymExpr { diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 87fb73c3762..8719f014fcd 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> +# checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>> Argument: kind: union fields: @@ -380,10 +380,10 @@ SymExprHint: fields: as_int: type: int - as_float: - type: float as_bool: type: bool + as_float: + type: float SymFloat: kind: union fields: diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 80b28f21e15..979a90d838f 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -31,7 +31,7 @@ def _staged_schema(): thrift_type_defs: Dict[str, str] = {} def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: - def dump_type(t) -> Tuple[str, str, str]: + def dump_type(t, level: int) -> Tuple[str, str, str]: CPP_TYPE_MAP = { str: "std::string", int: "int64_t", @@ -90,20 +90,21 @@ def _staged_schema(): "", ) elif o == Union: + assert level == 0, "Optional is only supported at the top level." args = typing.get_args(t) assert len(args) == 2 and args[1] == type(None) - yaml_type, cpp_type, thrift_type = dump_type(args[0]) + yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1) return ( f"Optional[{yaml_type}]", f"std::optional<{cpp_type}>", f"optional {thrift_type}", ) elif o == Annotated: - return dump_type(t.__origin__) + return dump_type(t.__origin__, level) else: raise AssertionError(f"Type {t} is not supported in export schema.") yaml_arg_types, cpp_arg_types, thrift_arg_types = zip( - *[dump_type(x) for x in typing.get_args(t)] + *[dump_type(x, level + 1) for x in typing.get_args(t)] ) return ( (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), @@ -136,7 +137,7 @@ def _staged_schema(): ) def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]: - t, cpp_type, thrift_type = dump_type(f.type) + t, cpp_type, thrift_type = dump_type(f.type, 0) ret = {"type": t} cpp_default: Optional[str] = None assert ( @@ -455,7 +456,7 @@ void from_json(const nlohmann::json& j, ForwardRef& p) {{ }} // namespace torch """ thrift_schema = f""" -namespace py3 torch._export.schema +namespace py3 torch._export namespace cpp2 torch._export.schema {chr(10).join(thrift_enum_defs)} {chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())} @@ -528,21 +529,24 @@ def _diff_schema(dst, src): return additions, subtractions -def _hash_schema(s): - return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() +def _hash_content(s: str): + return hashlib.sha256(s.strip().encode("utf-8")).hexdigest() @dataclasses.dataclass class _Commit: result: Dict[str, Any] - checksum_result: str + checksum_next: str yaml_path: str additions: Dict[str, Any] subtractions: Dict[str, Any] base: Dict[str, Any] - checksum_base: Optional[str] + checksum_head: Optional[str] cpp_header: str cpp_header_path: str + thrift_checksum_head: Optional[str] + thrift_checksum_real: Optional[str] + thrift_checksum_next: str thrift_schema: str thrift_schema_path: str @@ -555,13 +559,26 @@ def update_schema(): match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) _check(match is not None, "checksum not found in schema.yaml") assert match is not None - checksum_base = match.group(1) + checksum_head = match.group(1) + + thrift_content = importlib.resources.read_text(__package__, "schema.thrift") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content) + _check(match is not None, "checksum not found in schema.thrift") + assert match is not None + thrift_checksum_head = match.group(1) + thrift_content = thrift_content.splitlines() + assert thrift_content[0].startswith("// @" + "generated") + assert thrift_content[1].startswith("// checksum<<") + thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) + from yaml import load, Loader dst = load(content, Loader=Loader) assert isinstance(dst, dict) else: - checksum_base = None + checksum_head = None + thrift_checksum_head = None + thrift_checksum_real = None dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} src, cpp_header, thrift_schema = _staged_schema() @@ -574,14 +591,17 @@ def update_schema(): return _Commit( result=src, - checksum_result=_hash_schema(src), + checksum_next=_hash_content(repr(src)), yaml_path=yaml_path, additions=additions, subtractions=subtractions, base=dst, - checksum_base=checksum_base, + checksum_head=checksum_head, cpp_header=cpp_header, cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", + thrift_checksum_head=thrift_checksum_head, + thrift_checksum_real=thrift_checksum_real, + thrift_checksum_next=_hash_content(thrift_schema), thrift_schema=thrift_schema, thrift_schema_path=thrift_schema_path, ) diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 6a7069881ba..8cf3cb86da1 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> +// checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>> // clang-format off #pragma once @@ -191,11 +191,11 @@ class SymExprHint { public: enum class Tag { - AS_INT, AS_FLOAT, AS_BOOL + AS_INT, AS_BOOL, AS_FLOAT }; private: - std::variant variant_; + std::variant variant_; Tag tag_; public: @@ -207,11 +207,11 @@ class SymExprHint { return std::get<1>(variant_); } - const double& get_as_float() const { + const bool& get_as_bool() const { return std::get<2>(variant_); } - const bool& get_as_bool() const { + const double& get_as_float() const { return std::get<3>(variant_); } @@ -221,14 +221,14 @@ class SymExprHint { nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); return; } - if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { - nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); - return; - } if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); return; } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) { + nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float(); + return; + } } friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) { @@ -238,14 +238,14 @@ class SymExprHint { nlohmann_json_t.tag_ = Tag::AS_INT; return; } - if (nlohmann_json_j.contains("as_float")) { - nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get()); - nlohmann_json_t.tag_ = Tag::AS_FLOAT; + if (nlohmann_json_j.contains("as_bool")) { + nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get()); + nlohmann_json_t.tag_ = Tag::AS_BOOL; return; } - if (nlohmann_json_j.contains("as_bool")) { - nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get()); - nlohmann_json_t.tag_ = Tag::AS_BOOL; + if (nlohmann_json_j.contains("as_float")) { + nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT; return; } }