diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 530e90768f9..7fe56facf5b 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -1,14 +1,17 @@ import argparse import datetime +import logging import re import sys -import warnings from collections import defaultdict import torch -from torch._C import parse_schema +from torch._C import parse_schema, Tag +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + # How to run this test locally: # 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly) # one with your local changes (venv_yours). @@ -22,7 +25,10 @@ from torch._C import parse_schema # 5. Run this test with # `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt` -# The date specifies how long the allowlist exclusion should apply to. +# The date specifies how long the allowlist exclusion should apply to. Note that core ATen opset +# (https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir) is guaranteed to be BC, based on this policy +# (https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772) and hence the +# allowlist does not apply (or the date is always arbitrarily far for core ATen ops). # # - If we NEVER give BC guarantee for an operator, you can put the # date arbitrarily far in the future. @@ -228,6 +234,14 @@ def process_version_map(version_map): return output +def is_core_aten_op(schema) -> bool: + # Check if the schema is a core ATen op + if "::" not in schema.name: + return False + _, _, tags = torch._C._get_operation_overload(schema.name, schema.overload_name) + return Tag.core in tags + + def check_bc(existing_schemas): new_schema_dict = load_schemas_to_dict() version_map = process_version_map(torch._C._get_operator_version_map()) @@ -235,12 +249,23 @@ def check_bc(existing_schemas): broken_ops = [] for existing_schema in existing_schemas: if allow_listed(existing_schema): - print("schema: ", str(existing_schema), " found on allowlist, skipping") - continue + if not is_core_aten_op(existing_schema): + logging.info("schema: %s found on allowlist, skipping", existing_schema) + continue + else: + logging.info( + "schema: %s found on allowlist, but is a core ATen op, checking BC", + existing_schema, + ) if has_valid_upgraders(existing_schema, version_map): - print("schema: ", str(existing_schema), " has valid upgrader, skipping") - continue - print("processing existing schema: ", str(existing_schema)) + if not is_core_aten_op(existing_schema): + logging.info("schema: %s has valid upgrader, skipping", existing_schema) + continue + else: + logging.info( + "schema: %s has a valid upgrader, but is a core ATen op, checking BC" + ) + logging.debug("processing existing schema: %s", existing_schema) matching_new_schemas = new_schema_dict.get(existing_schema.name, []) found = False for matching_new_schema in matching_new_schemas: @@ -248,24 +273,24 @@ def check_bc(existing_schemas): found = True break if not found: - print( + logging.warning( "Can NOT find backward compatible schemas after changes " - "for schema {} from the following candidates:\n[\n{}\n]".format( - str(existing_schema), - "\n\t".join(str(s) for s in matching_new_schemas), - ) + "for schema %s from the following candidates:\n[\n%s\n]", + str(existing_schema), + "\n\t".join(str(s) for s in matching_new_schemas), ) # TODO Print out more details about why candidates don't match. broken_ops.append(str(existing_schema)) is_bc = False if is_bc: - print("Found backward compatible schemas for all existing schemas") + logging.info("Found backward compatible schemas for all existing schemas") else: - print( + logging.warning( "The PR is introducing backward incompatible changes to the " "operator library. Please contact PyTorch team to confirm " "whether this change is wanted or not. \n\nBroken ops: " - "[\n\t{}\n]".format("\n\t".join(broken_ops)) + "[\n\t%s\n]", + "\n\t".join(broken_ops), ) return is_bc @@ -276,9 +301,9 @@ def check_fc(existing_schemas): broken_ops = [] for existing_schema in existing_schemas: if allow_listed(existing_schema): - print("schema: ", str(existing_schema), " found on allowlist, skipping") + logging.info("schema: %s found on allowlist, skipping", existing_schema) continue - print("processing existing schema: ", str(existing_schema)) + logging.info("processing existing schema: %s", existing_schema) matching_new_schemas = new_schema_dict.get(existing_schema.name, []) found = False possible_failure_reasons = [] @@ -292,29 +317,28 @@ def check_fc(existing_schemas): if reason != "": possible_failure_reasons.append(reason) if not found: - print( + logging.warning( "Can NOT find forward compatible schemas after changes " - "for schema {} from the following candidates:\n[\n{}\n]".format( - str(existing_schema), - "\n\t".join(str(s) for s in matching_new_schemas), - ) + "for schema %s from the following candidates:\n[\n\t%s\n]", + str(existing_schema), + "\n\t".join(str(s) for s in matching_new_schemas), ) - print( + logging.warning( "Refer to following reasons for failure " - "to find FC schema:\n[\n{}\n]".format( - "\n\t".join(str(r) for r in possible_failure_reasons) - ) + "to find FC schema:\n[\n%s\n]", + "\n\t".join(str(r) for r in possible_failure_reasons), ) broken_ops.append(str(existing_schema)) is_fc = False if is_fc: - print("Found forward compatible schemas for all existing schemas") + logging.info("Found forward compatible schemas for all existing schemas") else: - warnings.warn( + logging.warning( "The PR is introducing a potentially forward incompatible changes to the " "operator library. Please contact PyTorch team to confirm " "whether this change is wanted or not. \n\nBroken ops: " - "[\n\t{}\n]".format("\n\t".join(broken_ops)) + "[\n\t%s\n]", + "\n\t".join(broken_ops), ) @@ -336,7 +360,7 @@ if __name__ == "__main__": break if dont_parse(line.strip()): - print("Not parsing schema line: ", line.strip()) + logging.info("Not parsing schema line: %s", line.strip()) continue s = parse_schema(line.strip()) slist.append(s)