mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add forward compatability tests in CI (#64139)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64139 Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D30626912 Pulled By: tugsbayasgalan fbshipit-source-id: 781a88386701b42e2e86daaca0a779d1fc1c4df3
This commit is contained in:
parent
402f2934bf
commit
8bdbe94344
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -51,7 +51,7 @@ test/custom_operator/model.pt
|
|||
test/jit_hooks/*.pt
|
||||
test/data/legacy_modules.t7
|
||||
test/data/*.pt
|
||||
test/backward_compatibility/nightly_schemas.txt
|
||||
test/forward_backward_compatibility/nightly_schemas.txt
|
||||
dropout_model.pt
|
||||
test/generated_type_hints_smoketest.py
|
||||
test/htmlcov
|
||||
|
|
|
|||
|
|
@ -436,9 +436,9 @@ test_xla() {
|
|||
# Do NOT run this test before any other tests, like test_python_shard, etc.
|
||||
# Because this function uninstalls the torch built from branch, and install
|
||||
# nightly version.
|
||||
test_backward_compatibility() {
|
||||
test_forward_backward_compatibility() {
|
||||
set -x
|
||||
pushd test/backward_compatibility
|
||||
pushd test/forward_backward_compatibility
|
||||
python -m venv venv
|
||||
# shellcheck disable=SC1091
|
||||
. venv/bin/activate
|
||||
|
|
@ -448,7 +448,7 @@ test_backward_compatibility() {
|
|||
deactivate
|
||||
rm -r venv
|
||||
pip show torch
|
||||
python check_backward_compatibility.py --existing-schemas nightly_schemas.txt
|
||||
python check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt
|
||||
popd
|
||||
set +x
|
||||
assert_git_not_dirty
|
||||
|
|
@ -529,7 +529,7 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze
|
|||
fi
|
||||
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then
|
||||
test_backward_compatibility
|
||||
test_forward_backward_compatibility
|
||||
# Do NOT add tests after bc check tests, see its comment.
|
||||
elif [[ "${TEST_CONFIG}" == *xla* ]]; then
|
||||
install_torchvision
|
||||
|
|
|
|||
|
|
@ -1316,7 +1316,7 @@ This choice depends on several factors; here is the decision tree as of
|
|||
- pytorch_linux_xenial_py3_6_gcc5_4_build
|
||||
- pytorch_cpp_doc_build
|
||||
- pytorch_doc_test
|
||||
- pytorch_linux_backward_compatibility_check_test
|
||||
- pytorch_linux_forward_backward_compatibility_check_test
|
||||
- pytorch_linux_xenial_py3_6_gcc5_4_jit_legacy_test
|
||||
- pytorch_linux_xenial_py3_6_gcc5_4_test
|
||||
- pytorch_python_doc_build
|
||||
|
|
|
|||
|
|
@ -141,6 +141,15 @@ struct Argument {
|
|||
const Argument& old,
|
||||
std::ostream* why_not=nullptr) const;
|
||||
|
||||
// this function checks whether this Argument is forward compatible with
|
||||
// the old one. we consider the following cases are forward compatible:
|
||||
// 1) two arguments are equal
|
||||
// 2) this arg's type should be subtype of old
|
||||
// 3) this arg must provide the same default value if old arg has one,
|
||||
bool isForwardCompatibleWith(
|
||||
const Argument& old,
|
||||
std::ostream* why_not = nullptr) const;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
TypePtr type_;
|
||||
|
|
@ -238,6 +247,28 @@ struct FunctionSchema {
|
|||
const FunctionSchema& old,
|
||||
std::ostream* why_not = nullptr) const;
|
||||
|
||||
// Checks whether this schema is forward compatible with the old one.
|
||||
// The following conditions must be true:
|
||||
// [Function structure] The new schema's name, overload-name, varargs, and
|
||||
// return arity are the same.
|
||||
// [Output Narrowing] The new schema's output type must be the same class
|
||||
// or inherit from the old schema's output type.
|
||||
// [Arg Compatibility] Every argument in the old schema has a corresponding
|
||||
// argument in the new schema that:
|
||||
// * is at the same position.
|
||||
// * has the same name.
|
||||
// * is either positional, or kwarg and the old argument was kwarg.
|
||||
// * has the same type, or the old argument's type inherits from the
|
||||
// new argument's type.
|
||||
// [Default Values] Every new argument must have a default value.
|
||||
// Each default value type should NOT be a container type.
|
||||
// [Positioning] All defaults arguments MUST go after either old
|
||||
// default arguments or the end of positional arguments
|
||||
// and right BEFORE all out arguments
|
||||
bool isForwardCompatibleWith(
|
||||
const FunctionSchema& old,
|
||||
std::ostringstream& why_not) const;
|
||||
|
||||
private:
|
||||
OperatorName name_;
|
||||
std::vector<Argument> arguments_;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#pragma once
|
||||
#include <iostream>
|
||||
|
||||
// note: windows build doesn't find symbols in operator files unless
|
||||
// this is a header file
|
||||
|
|
@ -86,6 +87,34 @@ inline bool Argument::isBackwardCompatibleWith(
|
|||
return true;
|
||||
}
|
||||
|
||||
inline bool Argument::isForwardCompatibleWith(
|
||||
const Argument& old,
|
||||
std::ostream* why_not) const {
|
||||
const Argument* lhs = this;
|
||||
const Argument* rhs = &old;
|
||||
if (!(lhs->name() == rhs->name()
|
||||
&& lhs->N() == rhs->N()
|
||||
&& (lhs->alias_info() == rhs->alias_info()
|
||||
|| (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
|
||||
&& *lhs->alias_info() == *rhs->alias_info())))) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->kwarg_only() && !rhs->kwarg_only()) {
|
||||
return false;
|
||||
}
|
||||
if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
|
||||
return false;
|
||||
}
|
||||
if (rhs->default_value().has_value() &&
|
||||
lhs->default_value() != rhs->default_value()) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline std::string FunctionSchema::formatTypeMismatchMsg(
|
||||
const Argument& expected,
|
||||
const std::string& actual_type,
|
||||
|
|
@ -145,7 +174,7 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
|
|||
}
|
||||
}
|
||||
|
||||
// // Validate that all new arguments provided has a default value
|
||||
// Validate that all new arguments provided has a default value
|
||||
for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
|
||||
if (!arguments().at(i).default_value()) {
|
||||
if (why_not) {
|
||||
|
|
@ -171,6 +200,86 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
|
|||
return true;
|
||||
}
|
||||
|
||||
inline bool FunctionSchema::isForwardCompatibleWith(
|
||||
const FunctionSchema& old,
|
||||
std::ostringstream& why_not) const {
|
||||
if (!(name() == old.name() &&
|
||||
overload_name() == old.overload_name()
|
||||
// we are conservative on is_vararg and is_varret,
|
||||
// since they are only used by internal operators
|
||||
&& is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
|
||||
returns().size() == old.returns().size())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// we want to test both out and default args seperately
|
||||
size_t old_out_start_idx = findFirstOutArg(old.arguments());
|
||||
size_t new_out_start_idx = findFirstOutArg(arguments());
|
||||
|
||||
if (old.arguments().size() - old_out_start_idx !=
|
||||
arguments().size() - new_out_start_idx) {
|
||||
if (why_not) {
|
||||
why_not << "Function schema should have the "
|
||||
<< "same number of out arguments";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// make sure among the default args, they are forward compatible
|
||||
for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
|
||||
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not
|
||||
<< "'" << arguments().at(i).name() << "'"
|
||||
<< " is not forward compatible with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that all new arguments provided has a default value
|
||||
for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
|
||||
if (!arguments().at(i).default_value()) {
|
||||
if (why_not) {
|
||||
why_not
|
||||
<< "Function schema is not forward compatible since the new argument '"
|
||||
<< arguments().at(i).name() << "' of type "
|
||||
<< arguments().at(i).type()->str()
|
||||
<< " did not provide a default value.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
auto default_val = arguments().at(i).default_value().value();
|
||||
if (default_val.isList() || default_val.isGenericDict()) {
|
||||
if (why_not) {
|
||||
why_not
|
||||
<< "Function schema is not forward compatible since the new argument '"
|
||||
<< arguments().at(i).name() << "' of type "
|
||||
<< arguments().at(i).type()->str() << " has a container type "
|
||||
<< "as its default value.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// now compare the out args
|
||||
for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
|
||||
if (!arguments()
|
||||
.at(i - old_out_start_idx + new_out_start_idx)
|
||||
.isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not << "Out argument '"
|
||||
<< "'" << arguments().at(i).name()
|
||||
<< " is not FC with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void FunctionSchema::checkArg(
|
||||
const IValue& value,
|
||||
const Argument& argument,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import argparse
|
|||
import datetime
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
|
@ -149,14 +150,16 @@ def dont_parse(schema_line):
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_bc(existing_schemas):
|
||||
def load_schemas_to_dict():
|
||||
new_schemas = torch._C._jit_get_all_schemas()
|
||||
new_schemas += torch._C._jit_get_custom_class_schemas()
|
||||
new_schema_dict = defaultdict(list)
|
||||
for s in new_schemas:
|
||||
new_schema_dict[s.name].append(s)
|
||||
return new_schema_dict
|
||||
|
||||
def check_bc(existing_schemas):
|
||||
new_schema_dict = load_schemas_to_dict()
|
||||
is_bc = True
|
||||
broken_ops = []
|
||||
for existing_schema in existing_schemas:
|
||||
|
|
@ -192,6 +195,51 @@ def check_bc(existing_schemas):
|
|||
)
|
||||
return is_bc
|
||||
|
||||
def check_fc(existing_schemas):
|
||||
new_schema_dict = load_schemas_to_dict()
|
||||
is_fc = True
|
||||
broken_ops = []
|
||||
for existing_schema in existing_schemas:
|
||||
if allow_listed(existing_schema):
|
||||
print("schema: ", str(existing_schema), " found on allowlist, skipping")
|
||||
continue
|
||||
print("processing existing schema: ", str(existing_schema))
|
||||
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
|
||||
found = False
|
||||
possible_failure_reasons = []
|
||||
for matching_new_schema in matching_new_schemas:
|
||||
is_compatible, reason = matching_new_schema.check_forward_compatible_with(existing_schema)
|
||||
if is_compatible:
|
||||
found = True
|
||||
break
|
||||
if reason != "":
|
||||
possible_failure_reasons.append(reason)
|
||||
if not found:
|
||||
print(
|
||||
"Can NOT find forward compatible schemas after changes "
|
||||
"for schema {} from the following candidates:\n[\n{}\n]".format(
|
||||
str(existing_schema),
|
||||
"\n\t".join(str(s) for s in matching_new_schemas),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Refer to following reasons for failure "
|
||||
"to find FC schema:\n[\n{}\n]".format(
|
||||
"\n\t".join(str(r) for r in possible_failure_reasons)
|
||||
)
|
||||
)
|
||||
broken_ops.append(str(existing_schema))
|
||||
is_fc = False
|
||||
if is_fc:
|
||||
print("Found forward compatible schemas for all existing schemas")
|
||||
else:
|
||||
warnings.warn(
|
||||
"The PR is introducing a potentially forward incompatible changes to the "
|
||||
"operator library. Please contact PyTorch team to confirm "
|
||||
"whether this change is wanted or not. \n\nBroken ops: "
|
||||
"[\n\t{}\n]".format("\n\t".join(broken_ops))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Process some integers.")
|
||||
|
|
@ -216,5 +264,9 @@ if __name__ == "__main__":
|
|||
s = parse_schema(line.strip())
|
||||
slist.append(s)
|
||||
|
||||
# TODO in case there is FC breaking changes,
|
||||
# we just warn for now until there is a policy.
|
||||
check_fc(slist)
|
||||
|
||||
if not check_bc(slist):
|
||||
sys.exit(1)
|
||||
|
|
@ -111,7 +111,70 @@ class TestFunctionSchema(TestCase):
|
|||
def test_string_optional_parameter_default_value(self):
|
||||
schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
|
||||
schema_b = parse_schema(str(schema_a))
|
||||
self.assertEquals(schema_a, schema_b)
|
||||
self.assertEqual(schema_a, schema_b)
|
||||
|
||||
def test_forward_compatible_arguments_without_out(self):
|
||||
old_schema = parse_schema('any(Tensor self, int a, int b=1) -> Tensor')
|
||||
# deleting default arg is FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a) -> Tensor')
|
||||
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertTrue(is_fc)
|
||||
# adding default arg is FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int b=1, int c=1) -> Tensor')
|
||||
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertTrue(is_fc)
|
||||
# adding default arg with container type is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int b=1, int[2] c=1) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "Function schema is not forward compatible since the new argument"
|
||||
" \'c\' of type int[] has a container type as its default value.")
|
||||
# updating the default value of a default arg is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int b=4) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema")
|
||||
# updating the arg name of a default arg is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int c=1) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema")
|
||||
# not adding default arg in the end is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int c=1, int b=1) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'c\' is not forward compatible with the older version of the schema")
|
||||
# making default arg into positional arg is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a, int b) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'b\' is not forward compatible with the older version of the schema")
|
||||
# making positional arg into default arg is NOT FC compatible
|
||||
new_schema = parse_schema('any(Tensor self, int a=1, int b=1) -> Tensor')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'a\' is not forward compatible with the older version of the schema")
|
||||
|
||||
def test_forward_compatible_arguments_real_use_case(self):
|
||||
# this change introduced forward incompatibility in the past
|
||||
old_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int start=0, int end=0, int step=1) -> Tensor(a)')
|
||||
new_slice_schema = parse_schema('slice(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)')
|
||||
is_fc, reason = new_slice_schema.check_forward_compatible_with(old_slice_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "\'start\' is not forward compatible with the older version of the schema")
|
||||
|
||||
def test_forward_compatible_arguments_with_out(self):
|
||||
old_schema = parse_schema('any(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)')
|
||||
new_schema = parse_schema('any(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)')
|
||||
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertTrue(is_fc)
|
||||
new_schema = parse_schema('any(Tensor self, *, int a, int b=1, int c=1, Tensor(a!) out) -> Tensor(a!)')
|
||||
is_fc, _ = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertTrue(is_fc)
|
||||
new_schema = parse_schema('any(Tensor self, *, int a, Tensor(d!) d, int b=1, Tensor(a!) out) -> Tensor(a!)')
|
||||
is_fc, reason = new_schema.check_forward_compatible_with(old_schema)
|
||||
self.assertFalse(is_fc)
|
||||
self.assertEqual(reason, "Function schema should have the same number of out arguments")
|
||||
|
||||
def test_schema_error(self):
|
||||
with self.assertRaisesRegex(RuntimeError, r"schemas with vararg \(...\) can't have default value args"):
|
||||
|
|
|
|||
|
|
@ -1390,6 +1390,13 @@ void initJITBindings(PyObject* module) {
|
|||
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
|
||||
return self.isBackwardCompatibleWith(old_schema);
|
||||
})
|
||||
.def(
|
||||
"check_forward_compatible_with",
|
||||
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
|
||||
std::ostringstream out;
|
||||
auto result = self.isForwardCompatibleWith(old_schema, out);
|
||||
return std::make_pair(result, out.str());
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const FunctionSchema& self, const FunctionSchema& other) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user