pytorch/test/export/test_schema.py
Aaron Orenstein cd8d0fa20c Tweak schema_check to handle annotated builtin types (#145154)
As of Python 3.9, annotated lists can be written as `list[T]`, and `List[T]` has been deprecated. However, schema_check was converting `list[T]` to plain `list`. This change teaches it to handle `list[T]` the same way as `List[T]`.

A couple small drive-by changes I noticed as well:
- Path concatenation should use `os.path.join`, not `+`
- Spelling in error message

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145154
Approved by: https://github.com/bobrenjc93
2025-01-19 18:48:35 +00:00

410 lines
11 KiB
Python

# Owner(s): ["oncall: export"]
from torch._export.serde.schema_check import (
_Commit,
_diff_schema,
check,
SchemaUpdateError,
update_schema,
)
from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase
class TestSchema(TestCase):
    """Tests for the export schema checker (``schema_check``).

    Covers three things:

    * the checked-in schema artifacts agree with ``schema.py``;
    * ``_diff_schema`` reports additions and subtractions correctly;
    * ``check`` picks the right next ``SCHEMA_VERSION`` for each kind of edit.
    """

    @staticmethod
    def _update_instructions(problem):
        """Build the failure message telling a developer how to regenerate the schema.

        Args:
            problem: One-line description of what went wrong. It is placed
                before the instructions.

        Returns:
            The message text. When ``IS_FBCODE`` is set, the buck invocation
            is appended as well.
        """
        msg = (
            f"\n{problem}\n"
            "Example(s):\n"
            "    python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>\n"
        )
        if IS_FBCODE:
            msg += (
                "or\n"
                "    buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/\n"
            )
        return msg

    def _update_schema_or_fail(self, msg):
        """Return ``update_schema()``'s commit.

        If ``update_schema()`` raises ``SchemaUpdateError``, fail the test
        with *msg* instead of letting the error escape.
        """
        try:
            return update_schema()
        except SchemaUpdateError as e:
            # self.fail raises, so callers never receive None.
            self.fail(f"Failed to update schema: {e}\n{msg}")

    @staticmethod
    def _make_commit(dst, src):
        """Diff *dst* (base) against *src* (result) and wrap the diff in a ``_Commit``.

        Only the schema contents matter to ``check``. All checksum, path and
        generated-artifact fields are therefore left as empty placeholders.

        Raises:
            SchemaUpdateError: propagated from ``_diff_schema``, e.g. when a
                field's type changes.
        """
        additions, subtractions = _diff_schema(dst, src)
        return _Commit(
            result=src,
            checksum_next="",
            yaml_path="",
            additions=additions,
            subtractions=subtractions,
            base=dst,
            checksum_head="",
            cpp_header="",
            cpp_header_path="",
            thrift_checksum_head="",
            thrift_checksum_real="",
            thrift_checksum_next="",
            thrift_schema="",
            thrift_schema_path="",
        )

    def _assert_next_version(self, dst, src, expected):
        """Assert that evolving *dst* into *src* bumps SCHEMA_VERSION to *expected*."""
        next_version, _ = check(self._make_commit(dst, src))
        self.assertEqual(next_version, expected)

    def test_schema_compatibility(self):
        msg = self._update_instructions(
            "Detected an invalidated change to export schema. "
            "Please run the following script to update the schema:"
        )
        commit = self._update_schema_or_fail(msg)
        self.assertEqual(commit.checksum_head, commit.checksum_next, msg)

    def test_thrift_schema_unchanged(self):
        msg = self._update_instructions(
            "Detected an unexpected change to schema.thrift. "
            "Please update schema.py instead and run the following script:"
        )
        commit = self._update_schema_or_fail(msg)
        self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_real, msg)
        self.assertEqual(commit.thrift_checksum_head, commit.thrift_checksum_next, msg)

    def test_schema_diff(self):
        additions, subtractions = _diff_schema(
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field0": {"type": ""},
                        "field2": {"type": ""},
                        "field3": {"type": "", "default": "[]"},
                    },
                },
            },
            {
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        "field2": {"type": "", "default": "[]"},
                        "field3": {"type": ""},
                    },
                },
                "Type1": {"kind": "struct", "fields": {}},
            },
        )

        self.assertEqual(
            additions,
            {
                "Type1": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        "field2": {"default": "[]"},
                    },
                },
            },
        )
        self.assertEqual(
            subtractions,
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field0": {"type": ""},
                        "field3": {"default": "[]"},
                    },
                },
            },
        )

    def test_schema_check(self):
        # Adding field without default value -> major bump.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                    "field1": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [4, 1])

        # Removing field -> major bump.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "struct",
                "fields": {},
            },
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [4, 1])

        # Adding field with default value -> minor bump.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                    "field1": {"type": "", "default": "[]"},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [3, 3])

        # Changing field type -> rejected outright by _diff_schema.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": "int"},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        with self.assertRaises(SchemaUpdateError):
            _diff_schema(dst, src)

        # Adding new type -> minor bump.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "Type1": {"kind": "struct", "fields": {}},
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [3, 3])

        # Removing a type -> minor bump.
        dst = {
            "Type2": {
                "kind": "struct",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [3, 3])

        # Adding new field in union -> minor bump.
        dst = {
            "Type2": {
                "kind": "union",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "union",
                "fields": {
                    "field0": {"type": ""},
                    "field1": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [3, 3])

        # Removing a field in union -> major bump.
        dst = {
            "Type2": {
                "kind": "union",
                "fields": {
                    "field0": {"type": ""},
                },
            },
            "SCHEMA_VERSION": [3, 2],
        }
        src = {
            "Type2": {
                "kind": "union",
                "fields": {},
            },
            "SCHEMA_VERSION": [3, 2],
        }
        self._assert_next_version(dst, src, [4, 1])
# Allow running this file directly; delegates to run_tests from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()