mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BC] Add check for core ATen opset schema BC (#137664)
Summary: Based on core ATen opset BC policy: https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772 Encorcing this policy in `check_forward_backward_compatibility.py`. Basically the script will error out if any BC breaking schema changes occurs to core ATen operators. Test Plan: Run `python test/forward_backward_compatibility/dump_all_function_schemas.py --filename nightly_schemas.txt` Manually added a argument to `nightly_schemas.txt`, `convolution` schema, see the following error: ``` [WARNING 2024-10-09 15:54:36,224 check_forward_backward_compatibility.py:329] Can NOT find backward compatible schemas after changes for schema aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, SymInt new_arg) -> Tensor from the following candidates: [ aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) ]. Please contact PyTorch team to confirm if this BC breaking change is safe or not. ... [WARNING 2024-10-09 15:54:36,224 check_forward_backward_compatibility.py:342] The PR is introducing backward incompatible changes to core ATen operators. Please contact PyTorch team to confirm whether this change is wanted or not. Broken ops: [ aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, SymInt new_arg) -> Tensor ] ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/137664 Approved by: https://github.com/albanD
This commit is contained in:
parent
21a9c06ca9
commit
7365a57dc0
|
|
@ -1,14 +1,17 @@
|
|||
import argparse
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch._C import parse_schema
|
||||
from torch._C import parse_schema, Tag
|
||||
|
||||
|
||||
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
|
||||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
|
||||
# How to run this test locally:
|
||||
# 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly)
|
||||
# one with your local changes (venv_yours).
|
||||
|
|
@ -22,7 +25,10 @@ from torch._C import parse_schema
|
|||
# 5. Run this test with
|
||||
# `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt`
|
||||
|
||||
# The date specifies how long the allowlist exclusion should apply to.
|
||||
# The date specifies how long the allowlist exclusion should apply to. Note that core ATen opset
|
||||
# (https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir) is guaranteed to be BC, based on this policy
|
||||
# (https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772) and hence the
|
||||
# allowlist does not apply (or the date is always arbitrarily far for core ATen ops).
|
||||
#
|
||||
# - If we NEVER give BC guarantee for an operator, you can put the
|
||||
# date arbitrarily far in the future.
|
||||
|
|
@ -228,6 +234,14 @@ def process_version_map(version_map):
|
|||
return output
|
||||
|
||||
|
||||
def is_core_aten_op(schema) -> bool:
|
||||
# Check if the schema is a core ATen op
|
||||
if "::" not in schema.name:
|
||||
return False
|
||||
_, _, tags = torch._C._get_operation_overload(schema.name, schema.overload_name)
|
||||
return Tag.core in tags
|
||||
|
||||
|
||||
def check_bc(existing_schemas):
|
||||
new_schema_dict = load_schemas_to_dict()
|
||||
version_map = process_version_map(torch._C._get_operator_version_map())
|
||||
|
|
@ -235,12 +249,23 @@ def check_bc(existing_schemas):
|
|||
broken_ops = []
|
||||
for existing_schema in existing_schemas:
|
||||
if allow_listed(existing_schema):
|
||||
print("schema: ", str(existing_schema), " found on allowlist, skipping")
|
||||
continue
|
||||
if not is_core_aten_op(existing_schema):
|
||||
logging.info("schema: %s found on allowlist, skipping", existing_schema)
|
||||
continue
|
||||
else:
|
||||
logging.info(
|
||||
"schema: %s found on allowlist, but is a core ATen op, checking BC",
|
||||
existing_schema,
|
||||
)
|
||||
if has_valid_upgraders(existing_schema, version_map):
|
||||
print("schema: ", str(existing_schema), " has valid upgrader, skipping")
|
||||
continue
|
||||
print("processing existing schema: ", str(existing_schema))
|
||||
if not is_core_aten_op(existing_schema):
|
||||
logging.info("schema: %s has valid upgrader, skipping", existing_schema)
|
||||
continue
|
||||
else:
|
||||
logging.info(
|
||||
"schema: %s has a valid upgrader, but is a core ATen op, checking BC"
|
||||
)
|
||||
logging.debug("processing existing schema: %s", existing_schema)
|
||||
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
|
||||
found = False
|
||||
for matching_new_schema in matching_new_schemas:
|
||||
|
|
@ -248,24 +273,24 @@ def check_bc(existing_schemas):
|
|||
found = True
|
||||
break
|
||||
if not found:
|
||||
print(
|
||||
logging.warning(
|
||||
"Can NOT find backward compatible schemas after changes "
|
||||
"for schema {} from the following candidates:\n[\n{}\n]".format(
|
||||
str(existing_schema),
|
||||
"\n\t".join(str(s) for s in matching_new_schemas),
|
||||
)
|
||||
"for schema %s from the following candidates:\n[\n%s\n]",
|
||||
str(existing_schema),
|
||||
"\n\t".join(str(s) for s in matching_new_schemas),
|
||||
)
|
||||
# TODO Print out more details about why candidates don't match.
|
||||
broken_ops.append(str(existing_schema))
|
||||
is_bc = False
|
||||
if is_bc:
|
||||
print("Found backward compatible schemas for all existing schemas")
|
||||
logging.info("Found backward compatible schemas for all existing schemas")
|
||||
else:
|
||||
print(
|
||||
logging.warning(
|
||||
"The PR is introducing backward incompatible changes to the "
|
||||
"operator library. Please contact PyTorch team to confirm "
|
||||
"whether this change is wanted or not. \n\nBroken ops: "
|
||||
"[\n\t{}\n]".format("\n\t".join(broken_ops))
|
||||
"[\n\t%s\n]",
|
||||
"\n\t".join(broken_ops),
|
||||
)
|
||||
return is_bc
|
||||
|
||||
|
|
@ -276,9 +301,9 @@ def check_fc(existing_schemas):
|
|||
broken_ops = []
|
||||
for existing_schema in existing_schemas:
|
||||
if allow_listed(existing_schema):
|
||||
print("schema: ", str(existing_schema), " found on allowlist, skipping")
|
||||
logging.info("schema: %s found on allowlist, skipping", existing_schema)
|
||||
continue
|
||||
print("processing existing schema: ", str(existing_schema))
|
||||
logging.info("processing existing schema: %s", existing_schema)
|
||||
matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
|
||||
found = False
|
||||
possible_failure_reasons = []
|
||||
|
|
@ -292,29 +317,28 @@ def check_fc(existing_schemas):
|
|||
if reason != "":
|
||||
possible_failure_reasons.append(reason)
|
||||
if not found:
|
||||
print(
|
||||
logging.warning(
|
||||
"Can NOT find forward compatible schemas after changes "
|
||||
"for schema {} from the following candidates:\n[\n{}\n]".format(
|
||||
str(existing_schema),
|
||||
"\n\t".join(str(s) for s in matching_new_schemas),
|
||||
)
|
||||
"for schema %s from the following candidates:\n[\n\t%s\n]",
|
||||
str(existing_schema),
|
||||
"\n\t".join(str(s) for s in matching_new_schemas),
|
||||
)
|
||||
print(
|
||||
logging.warning(
|
||||
"Refer to following reasons for failure "
|
||||
"to find FC schema:\n[\n{}\n]".format(
|
||||
"\n\t".join(str(r) for r in possible_failure_reasons)
|
||||
)
|
||||
"to find FC schema:\n[\n%s\n]",
|
||||
"\n\t".join(str(r) for r in possible_failure_reasons),
|
||||
)
|
||||
broken_ops.append(str(existing_schema))
|
||||
is_fc = False
|
||||
if is_fc:
|
||||
print("Found forward compatible schemas for all existing schemas")
|
||||
logging.info("Found forward compatible schemas for all existing schemas")
|
||||
else:
|
||||
warnings.warn(
|
||||
logging.warning(
|
||||
"The PR is introducing a potentially forward incompatible changes to the "
|
||||
"operator library. Please contact PyTorch team to confirm "
|
||||
"whether this change is wanted or not. \n\nBroken ops: "
|
||||
"[\n\t{}\n]".format("\n\t".join(broken_ops))
|
||||
"[\n\t%s\n]",
|
||||
"\n\t".join(broken_ops),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -336,7 +360,7 @@ if __name__ == "__main__":
|
|||
break
|
||||
|
||||
if dont_parse(line.strip()):
|
||||
print("Not parsing schema line: ", line.strip())
|
||||
logging.info("Not parsing schema line: %s", line.strip())
|
||||
continue
|
||||
s = parse_schema(line.strip())
|
||||
slist.append(s)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user