Add forward compatibility tests in CI (#64139)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64139

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30626912

Pulled By: tugsbayasgalan

fbshipit-source-id: 781a88386701b42e2e86daaca0a779d1fc1c4df3
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2022-01-05 23:38:45 -08:00 committed by Facebook GitHub Bot
parent 402f2934bf
commit 8bdbe94344
9 changed files with 272 additions and 10 deletions

2
.gitignore vendored
View File

@ -51,7 +51,7 @@ test/custom_operator/model.pt
test/jit_hooks/*.pt
test/data/legacy_modules.t7
test/data/*.pt
test/backward_compatibility/nightly_schemas.txt
test/forward_backward_compatibility/nightly_schemas.txt
dropout_model.pt
test/generated_type_hints_smoketest.py
test/htmlcov

View File

@ -436,9 +436,9 @@ test_xla() {
# Do NOT run this test before any other tests, like test_python_shard, etc.
# Because this function uninstalls the torch built from branch, and install
# nightly version.
test_backward_compatibility() {
test_forward_backward_compatibility() {
set -x
pushd test/backward_compatibility
pushd test/forward_backward_compatibility
python -m venv venv
# shellcheck disable=SC1091
. venv/bin/activate
@ -448,7 +448,7 @@ test_backward_compatibility() {
deactivate
rm -r venv
pip show torch
python check_backward_compatibility.py --existing-schemas nightly_schemas.txt
python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt
popd
set +x
assert_git_not_dirty
@ -529,7 +529,7 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze
fi
if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then
test_backward_compatibility
test_forward_backward_compatibility
# Do NOT add tests after bc check tests, see its comment.
elif [[ "${TEST_CONFIG}" == *xla* ]]; then
install_torchvision

View File

@ -1316,7 +1316,7 @@ This choice depends on several factors; here is the decision tree as of
- pytorch_linux_xenial_py3_6_gcc5_4_build
- pytorch_cpp_doc_build
- pytorch_doc_test
- pytorch_linux_backward_compatibility_check_test
- pytorch_linux_forward_backward_compatibility_check_test
- pytorch_linux_xenial_py3_6_gcc5_4_jit_legacy_test
- pytorch_linux_xenial_py3_6_gcc5_4_test
- pytorch_python_doc_build

View File

@ -141,6 +141,15 @@ struct Argument {
const Argument& old,
std::ostream* why_not=nullptr) const;
// Checks whether this (new) Argument is forward compatible with the old one.
// Returns true only when all of the following hold:
// 1) both arguments have the same name, the same N() and equal alias info
// 2) this arg is not kwarg-only unless the old arg was kwarg-only as well
// 3) this arg's type is a subtype of the old arg's type (as decided by
//    isSubtypeOfExt)
// 4) default values match: either neither arg has a default value, or both
//    provide the same one
// If why_not is non-null it is forwarded to the subtype check; the other
// checks fail without writing an explanation.
bool isForwardCompatibleWith(
const Argument& old,
std::ostream* why_not = nullptr) const;
private:
std::string name_;
TypePtr type_;
@ -238,6 +247,28 @@ struct FunctionSchema {
const FunctionSchema& old,
std::ostream* why_not = nullptr) const;
// Checks whether this (new) schema is forward compatible with the old one.
// On failure an explanation may be written to why_not; a mismatch in the
// function structure is reported with an empty explanation.
// The following conditions must be true:
// [Function structure] The new schema's name, overload-name, varargs, and
// return arity are the same.
// [Output Narrowing] The new schema's output type must be the same class
// or inherit from the old schema's output type.
// NOTE(review): the inline implementation only compares the number of
// returns, not their types - confirm whether this condition is enforced.
// [Arg Compatibility] Every argument in the old schema has a corresponding
// argument in the new schema that:
// * is at the same position.
// * has the same name.
// * is either positional, or kwarg and the old argument was kwarg.
// * has the same type, or a type that is a subtype of the old
// argument's type.
// * has the same default value (or no default value if the old argument
// had none).
// Both schemas must also have the same number of out arguments.
// [Default Values] Every new argument must have a default value.
// Each default value type should NOT be a container type.
// [Positioning] All new default arguments MUST go after either old
// default arguments or the end of positional arguments
// and right BEFORE all out arguments
bool isForwardCompatibleWith(
const FunctionSchema& old,
std::ostringstream& why_not) const;
private:
OperatorName name_;
std::vector<Argument> arguments_;

View File

@ -1,4 +1,5 @@
#pragma once
#include <iostream>
// note: windows build doesn't find symbols in operator files unless
// this is a header file
@ -86,6 +87,34 @@ inline bool Argument::isBackwardCompatibleWith(
return true;
}
inline bool Argument::isForwardCompatibleWith(
const Argument& old,
std::ostream* why_not) const {
const Argument* lhs = this;
const Argument* rhs = &old;
if (!(lhs->name() == rhs->name()
&& lhs->N() == rhs->N()
&& (lhs->alias_info() == rhs->alias_info()
|| (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
&& *lhs->alias_info() == *rhs->alias_info())))) {
return false;
}
if (lhs->kwarg_only() && !rhs->kwarg_only()) {
return false;
}
if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
return false;
}
if (rhs->default_value().has_value() &&
lhs->default_value() != rhs->default_value()) {
return false;
}
if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
return false;
}
return true;
}
inline std::string FunctionSchema::formatTypeMismatchMsg(
const Argument& expected,
const std::string& actual_type,
@ -145,7 +174,7 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
}
}
// // Validate that all new arguments provided has a default value
// Validate that all new arguments provided has a default value
for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
if (!arguments().at(i).default_value()) {
if (why_not) {
@ -171,6 +200,86 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
return true;
}
inline bool FunctionSchema::isForwardCompatibleWith(
const FunctionSchema& old,
std::ostringstream& why_not) const {
if (!(name() == old.name() &&
overload_name() == old.overload_name()
// we are conservative on is_vararg and is_varret,
// since they are only used by internal operators
&& is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
returns().size() == old.returns().size())) {
return false;
}
// we want to test both out and default args seperately
size_t old_out_start_idx = findFirstOutArg(old.arguments());
size_t new_out_start_idx = findFirstOutArg(arguments());
if (old.arguments().size() - old_out_start_idx !=
arguments().size() - new_out_start_idx) {
if (why_not) {
why_not << "Function schema should have the "
<< "same number of out arguments";
}
return false;
}
// make sure among the default args, they are forward compatible
for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not
<< "'" << arguments().at(i).name() << "'"
<< " is not forward compatible with the older version of the schema";
}
return false;
}
}
// Validate that all new arguments provided has a default value
for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
if (!arguments().at(i).default_value()) {
if (why_not) {
why_not
<< "Function schema is not forward compatible since the new argument '"
<< arguments().at(i).name() << "' of type "
<< arguments().at(i).type()->str()
<< " did not provide a default value.";
}
return false;
}
auto default_val = arguments().at(i).default_value().value();
if (default_val.isList() || default_val.isGenericDict()) {
if (why_not) {
why_not
<< "Function schema is not forward compatible since the new argument '"
<< arguments().at(i).name() << "' of type "
<< arguments().at(i).type()->str() << " has a container type "
<< "as its default value.";
}
return false;
}
}
// now compare the out args
for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
if (!arguments()
.at(i - old_out_start_idx + new_out_start_idx)
.isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not << "Out argument '"
<< "'" << arguments().at(i).name()
<< " is not FC with the older version of the schema";
}
return false;
}
}
return true;
}
inline void FunctionSchema::checkArg(
const IValue& value,
const Argument& argument,

View File

@ -2,6 +2,7 @@ import argparse
import datetime
import re
import sys
import warnings
from collections import defaultdict
import torch
@ -149,14 +150,16 @@ def dont_parse(schema_line):
return True
return False
def check_bc(existing_schemas):
def load_schemas_to_dict():
    """Group every schema known to the current torch build by schema name.

    Collects the schemas returned by ``torch._C._jit_get_all_schemas()``
    followed by those from ``torch._C._jit_get_custom_class_schemas()``.

    Returns:
        A ``defaultdict(list)`` mapping ``schema.name`` to the list of schemas
        carrying that name, in the order they were reported.
    """
    all_schemas = torch._C._jit_get_all_schemas() + torch._C._jit_get_custom_class_schemas()
    schemas_by_name = defaultdict(list)
    for schema in all_schemas:
        schemas_by_name[schema.name].append(schema)
    return schemas_by_name
def check_bc(existing_schemas):
new_schema_dict = load_schemas_to_dict()
is_bc = True
broken_ops = []
for existing_schema in existing_schemas:
@ -192,6 +195,51 @@ def check_bc(existing_schemas):
)
return is_bc
def check_fc(existing_schemas):
    """Check that every existing schema still has a forward compatible match.

    For each schema in ``existing_schemas`` that is not on the allowlist, the
    schemas of the current build with the same name are tried until one
    reports forward compatibility via ``check_forward_compatible_with``.

    Forward compatibility breakages are only reported through printed
    diagnostics and a warning; this function never raises for them.

    Args:
        existing_schemas: iterable of previously published schemas.

    Returns:
        True if a forward compatible schema was found for every checked
        schema, False otherwise.
    """
    new_schema_dict = load_schemas_to_dict()
    is_fc = True
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        possible_failure_reasons = []
        for matching_new_schema in matching_new_schemas:
            is_compatible, reason = matching_new_schema.check_forward_compatible_with(existing_schema)
            if is_compatible:
                found = True
                break
            # Candidates rejected without an explanation produce an empty
            # reason; keep only the informative ones.
            if reason != "":
                possible_failure_reasons.append(reason)
        if not found:
            print(
                "Can NOT find forward compatible schemas after changes "
                "for schema {} from the following candidates:\n[\n{}\n]".format(
                    str(existing_schema),
                    "\n\t".join(str(s) for s in matching_new_schemas),
                )
            )
            print(
                "Refer to following reasons for failure "
                "to find FC schema:\n[\n{}\n]".format(
                    "\n\t".join(str(r) for r in possible_failure_reasons)
                )
            )
            broken_ops.append(str(existing_schema))
            is_fc = False
    if is_fc:
        print("Found forward compatible schemas for all existing schemas")
    else:
        warnings.warn(
            "The PR is introducing a potentially forward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )
    # Hand the verdict back to the caller, mirroring check_bc.
    return is_fc
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process some integers.")
@ -216,5 +264,9 @@ if __name__ == "__main__":
s = parse_schema(line.strip())
slist.append(s)
# TODO in case there is FC breaking changes,
# we just warn for now until there is a policy.
check_fc(slist)
if not check_bc(slist):
sys.exit(1)

View File

@ -111,7 +111,70 @@ class TestFunctionSchema(TestCase):
def test_string_optional_parameter_default_value(self):
schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
schema_b = parse_schema(str(schema_a))
self.assertEquals(schema_a, schema_b)
self.assertEqual(schema_a, schema_b)
def test_forward_compatible_arguments_without_out(self):
old_schema = parse_schema('any(Tensor self, int a, int b=1) -> Tensor')
# deleting default arg is FC compatible
new_schema = parse_schema('any(Tensor self, int a) -> Tensor')
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
self.assertTrue(is_fc)
# adding default arg is FC compatible
new_schema = parse_schema('any(Tensor self, int a, int b=1, int c=1) -> Tensor')
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
self.assertTrue(is_fc)
# adding default arg with container type is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "Function schema is not forward compatible since the new argument"
" \'c\' of type int[] has a container type as its default value.")
# updating the default value of a default arg is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a, int b=4) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema")
# updating the arg name of a default arg is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a, int c=1) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema")
# not adding default arg in the end is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a, int c=1, int b=1) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema")
# making default arg into positional arg is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a, int b) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema")
# making positional arg into default arg is NOT FC compatible
new_schema = parse_schema('any(Tensor self, int a=1, int b=1) -> Tensor')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'a\' is not forward compatible with the older version of the schema")
def test_forward_compatible_arguments_real_use_case(self):
# this change introduced forward incompatibility in the past
old_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)')
new_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)')
is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "\'start\' is not forward compatible with the older version of the schema")
def test_forward_compatible_arguments_with_out(self):
old_schema = parse_schema('any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)')
new_schema = parse_schema('any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)')
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
self.assertTrue(is_fc)
new_schema = parse_schema('any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)')
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
self.assertTrue(is_fc)
new_schema = parse_schema('any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) out) -> Tensor(a!)')
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
self.assertFalse(is_fc)
self.assertEqual(reason, "Function schema should have the same number of out arguments")
def test_schema_error(self):
with self.assertRaisesRegex(RuntimeError, r"schemas with vararg \(...\) can't have default value args"):

View File

@ -1390,6 +1390,13 @@ void initJITBindings(PyObject* module) {
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
return self.isBackwardCompatibleWith(old_schema);
})
// Python binding "check_forward_compatible_with": returns a pair of
// (result of isForwardCompatibleWith, text written to the reason stream).
.def(
"check_forward_compatible_with",
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
// Receives whatever explanation the compatibility check writes.
std::ostringstream out;
auto result = self.isForwardCompatibleWith(old_schema, out);
return std::make_pair(result, out.str());
})
.def(
"__eq__",
[](const FunctionSchema& self, const FunctionSchema& other) {