[export] Add test to enforce consistency between synced thrift and generated thrift from schema.py (#141989)

Summary:
In this diff we implement a way to ensure the internal thrift schema from cfgr (configerator/structs/caffe2/torch/export/schema.thrift) and the schema in OSS (torch/_export/serde/schema.thrift) are in sync, by adding a unittest to reflect on the type names and fields from each schema and compare them field by field.

When we detect new fields/types from torch/_export/serde/schema.thrift, there'll be a test failure on the trunk and the error message hints people to add the missing field/type to the thrift schema from cfgr, so that they are always in sync in practice.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_thrift_schema_in_sync

Differential Revision: D66716834

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141989
Approved by: https://github.com/yiming0416
This commit is contained in:
Zhengxu Chen 2024-12-06 18:42:18 +00:00 committed by PyTorch MergeBot
parent bab15df40a
commit 1a7da6e7e9
8 changed files with 117 additions and 55 deletions

View File

@ -58,7 +58,7 @@ if __name__ == "__main__":
first_line = ( first_line = (
"@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py" "@" + "generated by " + os.path.basename(__file__).rsplit(".", 1)[0] + ".py"
) )
checksum = f"checksum<<{commit.checksum_result}>>" checksum = f"checksum<<{commit.checksum_next}>>"
yaml_header = "# " + first_line yaml_header = "# " + first_line
yaml_header += "\n# " + checksum yaml_header += "\n# " + checksum
yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False) yaml_payload = dump(commit.result, Dumper=Dumper, sort_keys=False)
@ -73,7 +73,7 @@ if __name__ == "__main__":
yaml_content = yaml_header + "\n" + yaml_payload yaml_content = yaml_header + "\n" + yaml_payload
thrift_schema = "// " + first_line thrift_schema = "// " + first_line
thrift_schema += "\n// " + checksum thrift_schema += f"\n// checksum<<{commit.thrift_checksum_next}>>"
thrift_schema += "\n" + commit.thrift_schema thrift_schema += "\n" + commit.thrift_schema
if args.dry_run: if args.dry_run:

View File

@ -1338,6 +1338,7 @@ def main():
"_inductor/codegen/*.h", "_inductor/codegen/*.h",
"_inductor/codegen/aoti_runtime/*.cpp", "_inductor/codegen/aoti_runtime/*.cpp",
"_export/serde/*.yaml", "_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake", "share/cmake/ATen/*.cmake",
"share/cmake/Caffe2/*.cmake", "share/cmake/Caffe2/*.cmake",
"share/cmake/Caffe2/public/*.cmake", "share/cmake/Caffe2/public/*.cmake",

View File

@ -26,7 +26,27 @@ Example(s):
except SchemaUpdateError as e: except SchemaUpdateError as e:
self.fail(f"Failed to update schema: {e}\n{msg}") self.fail(f"Failed to update schema: {e}\n{msg}")
self.assertEqual(commit.checksum_base, commit.checksum_result, msg) self.assertEqual(commit.checksum_head, commit.checksum_next, msg)
def test_thrift_schema_unchanged(self):
msg = """
Detected an unexpected change to schema.thrift. Please update schema.py instead and run the following script:
Example(s):
python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory>
"""
if IS_FBCODE:
msg += """or
buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
"""
try:
commit = update_schema()
except SchemaUpdateError as e:
self.fail(f"Failed to update schema: {e}\n{msg}")
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg)
self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg)
def test_schema_diff(self): def test_schema_diff(self):
additions, subtractions = _diff_schema( additions, subtractions = _diff_schema(
@ -105,14 +125,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -141,14 +164,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -180,14 +206,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -242,14 +271,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -274,14 +306,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -313,14 +348,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )
@ -349,14 +387,17 @@ Example(s):
commit = _Commit( commit = _Commit(
result=src, result=src,
checksum_result="", checksum_next="",
yaml_path="", yaml_path="",
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base="", checksum_head="",
cpp_header="", cpp_header="",
cpp_header_path="", cpp_header_path="",
thrift_checksum_head="",
thrift_checksum_real="",
thrift_checksum_next="",
thrift_schema="", thrift_schema="",
thrift_schema_path="", thrift_schema_path="",
) )

View File

@ -58,8 +58,8 @@ class Device:
@dataclass(repr=False) @dataclass(repr=False)
class SymExprHint(_Union): class SymExprHint(_Union):
as_int: Annotated[int, 10] as_int: Annotated[int, 10]
as_float: Annotated[float, 20] as_bool: Annotated[bool, 20]
as_bool: Annotated[bool, 30] as_float: Annotated[float, 30]
# This is for storing the symbolic expressions behind symints/symfloats/symbools # This is for storing the symbolic expressions behind symints/symfloats/symbools

View File

@ -1,7 +1,7 @@
// @generated by update_schema.py // @generated by update_schema.py
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> // checksum<<0e89c5e620ad16c05bfe4fa2060ad43dcb0938dc31d77faad36b92f216c2c903>>
namespace py3 torch._export.schema namespace py3 torch._export
namespace cpp2 torch._export.schema namespace cpp2 torch._export.schema
enum Layout { enum Layout {
@ -51,8 +51,8 @@ struct Device {
union SymExprHint { union SymExprHint {
10: i64 as_int; 10: i64 as_int;
20: double as_float; 20: bool as_bool;
30: bool as_bool; 30: double as_float;
} }
struct SymExpr { struct SymExpr {

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py # @generated by update_schema.py
# checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> # checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
Argument: Argument:
kind: union kind: union
fields: fields:
@ -380,10 +380,10 @@ SymExprHint:
fields: fields:
as_int: as_int:
type: int type: int
as_float:
type: float
as_bool: as_bool:
type: bool type: bool
as_float:
type: float
SymFloat: SymFloat:
kind: union kind: union
fields: fields:

View File

@ -31,7 +31,7 @@ def _staged_schema():
thrift_type_defs: Dict[str, str] = {} thrift_type_defs: Dict[str, str] = {}
def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: def _handle_aggregate(ty) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def dump_type(t) -> Tuple[str, str, str]: def dump_type(t, level: int) -> Tuple[str, str, str]:
CPP_TYPE_MAP = { CPP_TYPE_MAP = {
str: "std::string", str: "std::string",
int: "int64_t", int: "int64_t",
@ -90,20 +90,21 @@ def _staged_schema():
"", "",
) )
elif o == Union: elif o == Union:
assert level == 0, "Optional is only supported at the top level."
args = typing.get_args(t) args = typing.get_args(t)
assert len(args) == 2 and args[1] == type(None) assert len(args) == 2 and args[1] == type(None)
yaml_type, cpp_type, thrift_type = dump_type(args[0]) yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
return ( return (
f"Optional[{yaml_type}]", f"Optional[{yaml_type}]",
f"std::optional<{cpp_type}>", f"std::optional<{cpp_type}>",
f"optional {thrift_type}", f"optional {thrift_type}",
) )
elif o == Annotated: elif o == Annotated:
return dump_type(t.__origin__) return dump_type(t.__origin__, level)
else: else:
raise AssertionError(f"Type {t} is not supported in export schema.") raise AssertionError(f"Type {t} is not supported in export schema.")
yaml_arg_types, cpp_arg_types, thrift_arg_types = zip( yaml_arg_types, cpp_arg_types, thrift_arg_types = zip(
*[dump_type(x) for x in typing.get_args(t)] *[dump_type(x, level + 1) for x in typing.get_args(t)]
) )
return ( return (
(f"{yaml_head}[{', '.join(yaml_arg_types)}]"), (f"{yaml_head}[{', '.join(yaml_arg_types)}]"),
@ -136,7 +137,7 @@ def _staged_schema():
) )
def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]: def dump_field(f) -> Tuple[Dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type) t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t} ret = {"type": t}
cpp_default: Optional[str] = None cpp_default: Optional[str] = None
assert ( assert (
@ -455,7 +456,7 @@ void from_json(const nlohmann::json& j, ForwardRef<T>& p) {{
}} // namespace torch }} // namespace torch
""" """
thrift_schema = f""" thrift_schema = f"""
namespace py3 torch._export.schema namespace py3 torch._export
namespace cpp2 torch._export.schema namespace cpp2 torch._export.schema
{chr(10).join(thrift_enum_defs)} {chr(10).join(thrift_enum_defs)}
{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())} {chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())}
@ -528,21 +529,24 @@ def _diff_schema(dst, src):
return additions, subtractions return additions, subtractions
def _hash_schema(s): def _hash_content(s: str):
return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() return hashlib.sha256(s.strip().encode("utf-8")).hexdigest()
@dataclasses.dataclass @dataclasses.dataclass
class _Commit: class _Commit:
result: Dict[str, Any] result: Dict[str, Any]
checksum_result: str checksum_next: str
yaml_path: str yaml_path: str
additions: Dict[str, Any] additions: Dict[str, Any]
subtractions: Dict[str, Any] subtractions: Dict[str, Any]
base: Dict[str, Any] base: Dict[str, Any]
checksum_base: Optional[str] checksum_head: Optional[str]
cpp_header: str cpp_header: str
cpp_header_path: str cpp_header_path: str
thrift_checksum_head: Optional[str]
thrift_checksum_real: Optional[str]
thrift_checksum_next: str
thrift_schema: str thrift_schema: str
thrift_schema_path: str thrift_schema_path: str
@ -555,13 +559,26 @@ def update_schema():
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
_check(match is not None, "checksum not found in schema.yaml") _check(match is not None, "checksum not found in schema.yaml")
assert match is not None assert match is not None
checksum_base = match.group(1) checksum_head = match.group(1)
thrift_content = importlib.resources.read_text(__package__, "schema.thrift")
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
_check(match is not None, "checksum not found in schema.thrift")
assert match is not None
thrift_checksum_head = match.group(1)
thrift_content = thrift_content.splitlines()
assert thrift_content[0].startswith("// @" + "generated")
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
from yaml import load, Loader from yaml import load, Loader
dst = load(content, Loader=Loader) dst = load(content, Loader=Loader)
assert isinstance(dst, dict) assert isinstance(dst, dict)
else: else:
checksum_base = None checksum_head = None
thrift_checksum_head = None
thrift_checksum_real = None
dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
src, cpp_header, thrift_schema = _staged_schema() src, cpp_header, thrift_schema = _staged_schema()
@ -574,14 +591,17 @@ def update_schema():
return _Commit( return _Commit(
result=src, result=src,
checksum_result=_hash_schema(src), checksum_next=_hash_content(repr(src)),
yaml_path=yaml_path, yaml_path=yaml_path,
additions=additions, additions=additions,
subtractions=subtractions, subtractions=subtractions,
base=dst, base=dst,
checksum_base=checksum_base, checksum_head=checksum_head,
cpp_header=cpp_header, cpp_header=cpp_header,
cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h",
thrift_checksum_head=thrift_checksum_head,
thrift_checksum_real=thrift_checksum_real,
thrift_checksum_next=_hash_content(thrift_schema),
thrift_schema=thrift_schema, thrift_schema=thrift_schema,
thrift_schema_path=thrift_schema_path, thrift_schema_path=thrift_schema_path,
) )

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py // @generated by update_schema.py
// checksum<<4d7fed9eff0dc31422e15dc73bd5d4d31b2feba660d85d9d0a35881670166ebb>> // checksum<<0335ca6e44a8a815ea638d538de0ad4f78a644af2689f6e93c0e8219117466e7>>
// clang-format off // clang-format off
#pragma once #pragma once
@ -191,11 +191,11 @@ class SymExprHint {
public: public:
enum class Tag { enum class Tag {
AS_INT, AS_FLOAT, AS_BOOL AS_INT, AS_BOOL, AS_FLOAT
}; };
private: private:
std::variant<Void, int64_t, double, bool> variant_; std::variant<Void, int64_t, bool, double> variant_;
Tag tag_; Tag tag_;
public: public:
@ -207,11 +207,11 @@ class SymExprHint {
return std::get<1>(variant_); return std::get<1>(variant_);
} }
const double& get_as_float() const { const bool& get_as_bool() const {
return std::get<2>(variant_); return std::get<2>(variant_);
} }
const bool& get_as_bool() const { const double& get_as_float() const {
return std::get<3>(variant_); return std::get<3>(variant_);
} }
@ -221,14 +221,14 @@ class SymExprHint {
nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int(); nlohmann_json_j["as_int"] = nlohmann_json_t.get_as_int();
return; return;
} }
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
if (nlohmann_json_t.tag_ == Tag::AS_BOOL) { if (nlohmann_json_t.tag_ == Tag::AS_BOOL) {
nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool(); nlohmann_json_j["as_bool"] = nlohmann_json_t.get_as_bool();
return; return;
} }
if (nlohmann_json_t.tag_ == Tag::AS_FLOAT) {
nlohmann_json_j["as_float"] = nlohmann_json_t.get_as_float();
return;
}
} }
friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) { friend void from_json(const nlohmann::json& nlohmann_json_j, SymExprHint& nlohmann_json_t) {
@ -238,14 +238,14 @@ class SymExprHint {
nlohmann_json_t.tag_ = Tag::AS_INT; nlohmann_json_t.tag_ = Tag::AS_INT;
return; return;
} }
if (nlohmann_json_j.contains("as_float")) { if (nlohmann_json_j.contains("as_bool")) {
nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_float").template get<double>()); nlohmann_json_t.variant_.emplace<2>(nlohmann_json_j.at("as_bool").template get<bool>());
nlohmann_json_t.tag_ = Tag::AS_FLOAT; nlohmann_json_t.tag_ = Tag::AS_BOOL;
return; return;
} }
if (nlohmann_json_j.contains("as_bool")) { if (nlohmann_json_j.contains("as_float")) {
nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_bool").template get<bool>()); nlohmann_json_t.variant_.emplace<3>(nlohmann_json_j.at("as_float").template get<double>());
nlohmann_json_t.tag_ = Tag::AS_BOOL; nlohmann_json_t.tag_ = Tag::AS_FLOAT;
return; return;
} }
} }