mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129760 Approved by: https://github.com/ezyang
341 lines
9.0 KiB
Python
341 lines
9.0 KiB
Python
# Owner(s): ["oncall: export"]
|
|
from torch._export.serde.schema_check import (
|
|
_Commit,
|
|
_diff_schema,
|
|
check,
|
|
SchemaUpdateError,
|
|
update_schema,
|
|
)
|
|
from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase
|
|
|
|
|
|
class TestSchema(TestCase):
    """Tests for the export schema diffing and version-bump logic in
    ``torch._export.serde.schema_check``."""

    @staticmethod
    def _next_version(base, result):
        """Diff ``base`` -> ``result`` and return the next version ``check`` proposes.

        Args:
            base: schema dict before the change.
            result: schema dict after the change.

        Returns:
            The first element of ``check``'s return value for a ``_Commit`` built
            from the diff. Checksums and path are empty placeholders.
        """
        additions, subtractions = _diff_schema(base, result)
        commit = _Commit(
            result=result,
            checksum_result="",
            path="",
            additions=additions,
            subtractions=subtractions,
            base=base,
            checksum_base="",
        )
        next_version, _ = check(commit)
        return next_version

    def test_schema_compatibility(self):
        """Fail if ``update_schema()`` yields a commit whose base and result
        checksums differ, pointing the developer at the schema update script."""
        msg = """
Detected an invalidated change to export schema. Please run the following script to update the schema:
Example(s):
    python scripts/export/update_schema.py --prefix <path_to_torch_development_directory>
    """

        if IS_FBCODE:
            # Under fbcode, also suggest the buck target for the same update.
            msg += """or
    buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/
"""
        try:
            commit = update_schema()
        except SchemaUpdateError as e:
            # self.fail raises, so `commit` is always bound below.
            self.fail(f"Failed to update schema: {e}\n{msg}")

        self.assertEqual(commit.checksum_base, commit.checksum_result, msg)

    def test_schema_diff(self):
        """``_diff_schema(base, result)`` reports, at type / field / attribute
        granularity, what exists only in ``result`` (additions) and what exists
        only in ``base`` (subtractions)."""
        additions, subtractions = _diff_schema(
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field0": {"type": ""},
                        "field2": {"type": ""},
                        "field3": {"type": "", "default": "[]"},
                    },
                },
            },
            {
                "Type2": {
                    "kind": "struct",
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        "field2": {"type": "", "default": "[]"},
                        "field3": {"type": ""},
                    },
                },
                "Type1": {"kind": "struct", "fields": {}},
            },
        )

        self.assertEqual(
            additions,
            {
                "Type1": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field1": {"type": "", "default": "0"},
                        # Only the newly added attribute of an existing field.
                        "field2": {"default": "[]"},
                    },
                },
            },
        )
        self.assertEqual(
            subtractions,
            {
                "Type0": {"kind": "struct", "fields": {}},
                "Type2": {
                    "fields": {
                        "field0": {"type": ""},
                        # Only the removed attribute of a surviving field.
                        "field3": {"default": "[]"},
                    },
                },
            },
        )

    def test_schema_check(self):
        """Starting from version [3, 2], the changes expected below either bump
        the first component (-> [4, 1]) or the second (-> [3, 3]); changing a
        field's type is rejected outright by ``_diff_schema``."""

        def struct(fields):
            return {"kind": "struct", "fields": fields}

        def union(fields):
            return {"kind": "union", "fields": fields}

        def schema(**types):
            # Every synthetic schema in this test starts at version [3, 2].
            return {**types, "SCHEMA_VERSION": [3, 2]}

        next_major = [4, 1]
        next_minor = [3, 3]

        # (description, base schema, result schema, expected next version).
        # Each entry builds fresh dicts so no state is shared between cases.
        cases = [
            (
                "add struct field without default",
                schema(Type2=struct({"field0": {"type": ""}})),
                schema(
                    Type2=struct({"field0": {"type": ""}, "field1": {"type": ""}})
                ),
                next_major,
            ),
            (
                "remove struct field",
                schema(Type2=struct({"field0": {"type": ""}})),
                schema(Type2=struct({})),
                next_major,
            ),
            (
                "add struct field with default",
                schema(Type2=struct({"field0": {"type": ""}})),
                schema(
                    Type2=struct(
                        {
                            "field0": {"type": ""},
                            "field1": {"type": "", "default": "[]"},
                        }
                    )
                ),
                next_minor,
            ),
            (
                "add new type",
                schema(Type2=struct({"field0": {"type": ""}})),
                schema(Type2=struct({"field0": {"type": ""}}), Type1=struct({})),
                next_minor,
            ),
            (
                "remove type",
                schema(Type2=struct({"field0": {"type": ""}})),
                schema(),
                next_minor,
            ),
            (
                "add union field",
                schema(Type2=union({"field0": {"type": ""}})),
                schema(
                    Type2=union({"field0": {"type": ""}, "field1": {"type": ""}})
                ),
                next_minor,
            ),
            (
                "remove union field",
                schema(Type2=union({"field0": {"type": ""}})),
                schema(Type2=union({})),
                next_major,
            ),
        ]
        for description, base, result, expected in cases:
            # subTest reports each failing scenario separately instead of
            # stopping at the first one.
            with self.subTest(description):
                self.assertEqual(self._next_version(base, result), expected)

        with self.subTest("change struct field type"):
            with self.assertRaises(SchemaUpdateError):
                _diff_schema(
                    schema(Type2=struct({"field0": {"type": ""}})),
                    schema(Type2=struct({"field0": {"type": "int"}})),
                )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand off to the shared PyTorch test runner.
    run_tests()
|