mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
112 lines
6.3 KiB
Python
112 lines
6.3 KiB
Python
import torch
|
|
from torch._C import parse_schema
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
|
|
|
|
class TestFunctionSchema(TestCase):
|
|
def test_serialize_and_deserialize(self):
|
|
schemas = torch._C._jit_get_all_schemas()
|
|
# so far we have around 1700 registered schemas
|
|
self.assertGreater(len(schemas), 1000)
|
|
for schema in schemas:
|
|
parsed_schema = parse_schema(str(schema))
|
|
self.assertEqual(parsed_schema, schema)
|
|
self.assertTrue(parsed_schema.is_backward_compatible_with(schema))
|
|
|
|
def test_backward_compatible_structure(self):
|
|
old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
|
|
# BC: A new schema without changes.
|
|
new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with different name.
|
|
new_schema = parse_schema("any_.over(Tensor self, *, Tensor b) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with different overload name.
|
|
new_schema = parse_schema("any.other(Tensor self, *, Tensor b) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema that adds vararg.
|
|
new_schema = parse_schema("any.over(Tensor self, *, Tensor b, ...) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with different number of outputs.
|
|
new_schema = parse_schema(
|
|
"any.over(Tensor self, *, Tensor b) -> (Tensor, Tensor)"
|
|
)
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
|
|
def test_backward_compatible_outputs(self):
|
|
old_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor")
|
|
# No-BC: A new schema with output becoming of optional type.
|
|
new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor?")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: (the opposite case) An schema where the output is not of optional type anymore.
|
|
self.assertTrue(old_schema.is_backward_compatible_with(new_schema))
|
|
# No-BC: A new schema with a different output type.
|
|
new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> int")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with a different output type.
|
|
new_schema = parse_schema("any.over(Tensor self, *, Tensor b) -> Tensor out")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
|
|
def test_backward_compatible_arguments(self):
|
|
old_schema = parse_schema("any(Tensor self, *, Tensor b, int c) -> Tensor")
|
|
# No-BC: A new schema with less arguments.
|
|
new_schema = parse_schema("any(Tensor self, *, Tensor b) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with more arguments, appended, but no default value.
|
|
new_schema = parse_schema(
|
|
"any(Tensor self, *, Tensor b, int c, int d) -> Tensor"
|
|
)
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: A new schema with more arguments, appended, that have a default value.
|
|
new_schema = parse_schema(
|
|
"any(Tensor self, *, Tensor b, int c, int d=1) -> Tensor"
|
|
)
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema with more arguments, not-appended, that have a default value.
|
|
new_schema = parse_schema(
|
|
"any(Tensor self, int d=1, *, Tensor b, int c) -> Tensor"
|
|
)
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: A new schema where old kwargs becomes positional.
|
|
new_schema = parse_schema("any(Tensor self, Tensor b, *, int c) -> Tensor")
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: (the opposite case) A new schema where an old positional argument becomes kwarg.
|
|
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
|
|
# BC: A new schema where all old kwargs become positional.
|
|
new_schema = parse_schema("any(Tensor self, Tensor b, int c) -> Tensor")
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: (the opposite case) A new schema where all old positional arguments become kwarg.
|
|
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
|
|
# No-BC: A new schema where old kwargs appear in different order.
|
|
new_schema = parse_schema("any(Tensor self, *, int c, Tensor b) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: A new schema where argument becomes of type optional.
|
|
new_schema = parse_schema("any(Tensor self, *, Tensor b, int? c) -> Tensor")
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# BC: A new schema where argument gains a default value.
|
|
new_schema = parse_schema("any(Tensor self, *, Tensor b, int c=1) -> Tensor")
|
|
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema where argument is "renamed".
|
|
new_schema = parse_schema(
|
|
"any(Tensor self, *, Tensor b, int renamed) -> Tensor"
|
|
)
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
# No-BC: A new schema where argument type changes to an incompatible type.
|
|
new_schema = parse_schema("any(Tensor self, *, Tensor b, int[] c) -> Tensor")
|
|
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
|
|
|
|
def test_string_optional_parameter_default_value(self):
|
|
schema_a = parse_schema('example::op(str? order="NCHW") -> (Tensor)')
|
|
schema_b = parse_schema(str(schema_a))
|
|
self.assertEquals(schema_a, schema_b)
|
|
|
|
def test_schema_error(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r"schemas with vararg \(...\) can't have default value args"
|
|
):
|
|
schema = parse_schema("any.foo(int arg1, int arg2=0, ...)")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|