[BC] Add check for core ATen opset schema BC (#137664)

Summary: Based on core ATen opset BC policy: https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772

Encorcing this policy in `check_forward_backward_compatibility.py`.
Basically the script will error out if any BC breaking schema changes
occurs to core ATen operators.

Test Plan:

Run `python test/forward_backward_compatibility/dump_all_function_schemas.py --filename nightly_schemas.txt`

Manually added a argument to `nightly_schemas.txt`, `convolution`
schema, see the following error:

```
[WARNING 2024-10-09 15:54:36,224 check_forward_backward_compatibility.py:329] Can NOT find backward compatible schemas after changes for schema aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, SymInt new_arg) -> Tensor from the following candidates:
[
        aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
	aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
]. Please contact PyTorch team to confirm if this BC breaking change is safe or not.
...
[WARNING 2024-10-09 15:54:36,224 check_forward_backward_compatibility.py:342] The PR is introducing backward incompatible changes to core ATen operators. Please contact PyTorch team to confirm whether this change is wanted or not.

Broken ops: [
	aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, SymInt new_arg) -> Tensor
]
```
Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137664
Approved by: https://github.com/albanD
This commit is contained in:
Mengwei Liu 2024-10-15 10:03:47 -07:00 committed by PyTorch MergeBot
parent 21a9c06ca9
commit 7365a57dc0

View File

@ -1,14 +1,17 @@
import argparse
import datetime
import logging
import re
import sys
import warnings
from collections import defaultdict
import torch
from torch._C import parse_schema
from torch._C import parse_schema, Tag
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
# How to run this test locally:
# 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly)
# one with your local changes (venv_yours).
@ -22,7 +25,10 @@ from torch._C import parse_schema
# 5. Run this test with
# `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt`
# The date specifies how long the allowlist exclusion should apply to.
# The date specifies how long the allowlist exclusion should apply to. Note that core ATen opset
# (https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir) is guaranteed to be BC, based on this policy
# (https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772) and hence the
# allowlist does not apply (or the date is always arbitrarily far for core ATen ops).
#
# - If we NEVER give BC guarantee for an operator, you can put the
# date arbitrarily far in the future.
@ -228,6 +234,14 @@ def process_version_map(version_map):
return output
def is_core_aten_op(schema) -> bool:
# Check if the schema is a core ATen op
if "::" not in schema.name:
return False
_, _, tags = torch._C._get_operation_overload(schema.name, schema.overload_name)
return Tag.core in tags
def check_bc(existing_schemas):
new_schema_dict = load_schemas_to_dict()
version_map = process_version_map(torch._C._get_operator_version_map())
@ -235,12 +249,23 @@ def check_bc(existing_schemas):
broken_ops = []
for existing_schema in existing_schemas:
if allow_listed(existing_schema):
print("schema: ", str(existing_schema), " found on allowlist, skipping")
continue
if not is_core_aten_op(existing_schema):
logging.info("schema: %s found on allowlist, skipping", existing_schema)
continue
else:
logging.info(
"schema: %s found on allowlist, but is a core ATen op, checking BC",
existing_schema,
)
if has_valid_upgraders(existing_schema, version_map):
print("schema: ", str(existing_schema), " has valid upgrader, skipping")
continue
print("processing existing schema: ", str(existing_schema))
if not is_core_aten_op(existing_schema):
logging.info("schema: %s has valid upgrader, skipping", existing_schema)
continue
else:
logging.info(
"schema: %s has a valid upgrader, but is a core ATen op, checking BC"
)
logging.debug("processing existing schema: %s", existing_schema)
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
found = False
for matching_new_schema in matching_new_schemas:
@ -248,24 +273,24 @@ def check_bc(existing_schemas):
found = True
break
if not found:
print(
logging.warning(
"Can NOT find backward compatible schemas after changes "
"for schema {} from the following candidates:\n[\n{}\n]".format(
str(existing_schema),
"\n\t".join(str(s) for s in matching_new_schemas),
)
"for schema %s from the following candidates:\n[\n%s\n]",
str(existing_schema),
"\n\t".join(str(s) for s in matching_new_schemas),
)
# TODO Print out more details about why candidates don't match.
broken_ops.append(str(existing_schema))
is_bc = False
if is_bc:
print("Found backward compatible schemas for all existing schemas")
logging.info("Found backward compatible schemas for all existing schemas")
else:
print(
logging.warning(
"The PR is introducing backward incompatible changes to the "
"operator library. Please contact PyTorch team to confirm "
"whether this change is wanted or not. \n\nBroken ops: "
"[\n\t{}\n]".format("\n\t".join(broken_ops))
"[\n\t%s\n]",
"\n\t".join(broken_ops),
)
return is_bc
@ -276,9 +301,9 @@ def check_fc(existing_schemas):
broken_ops = []
for existing_schema in existing_schemas:
if allow_listed(existing_schema):
print("schema: ", str(existing_schema), " found on allowlist, skipping")
logging.info("schema: %s found on allowlist, skipping", existing_schema)
continue
print("processing existing schema: ", str(existing_schema))
logging.info("processing existing schema: %s", existing_schema)
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
found = False
possible_failure_reasons = []
@ -292,29 +317,28 @@ def check_fc(existing_schemas):
if reason != "":
possible_failure_reasons.append(reason)
if not found:
print(
logging.warning(
"Can NOT find forward compatible schemas after changes "
"for schema {} from the following candidates:\n[\n{}\n]".format(
str(existing_schema),
"\n\t".join(str(s) for s in matching_new_schemas),
)
"for schema %s from the following candidates:\n[\n\t%s\n]",
str(existing_schema),
"\n\t".join(str(s) for s in matching_new_schemas),
)
print(
logging.warning(
"Refer to following reasons for failure "
"to find FC schema:\n[\n{}\n]".format(
"\n\t".join(str(r) for r in possible_failure_reasons)
)
"to find FC schema:\n[\n%s\n]",
"\n\t".join(str(r) for r in possible_failure_reasons),
)
broken_ops.append(str(existing_schema))
is_fc = False
if is_fc:
print("Found forward compatible schemas for all existing schemas")
logging.info("Found forward compatible schemas for all existing schemas")
else:
warnings.warn(
logging.warning(
"The PR is introducing a potentially forward incompatible changes to the "
"operator library. Please contact PyTorch team to confirm "
"whether this change is wanted or not. \n\nBroken ops: "
"[\n\t{}\n]".format("\n\t".join(broken_ops))
"[\n\t%s\n]",
"\n\t".join(broken_ops),
)
@ -336,7 +360,7 @@ if __name__ == "__main__":
break
if dont_parse(line.strip()):
print("Not parsing schema line: ", line.strip())
logging.info("Not parsing schema line: %s", line.strip())
continue
s = parse_schema(line.strip())
slist.append(s)