[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan 2024-06-29 12:48:06 +08:00 committed by PyTorch MergeBot
parent 58f346c874
commit 8a67daf283
123 changed files with 1274 additions and 1053 deletions

View File

@ -1,16 +1,19 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import json import json
import os import os
import re import re
from collections import defaultdict from collections import defaultdict
from difflib import SequenceMatcher from difflib import SequenceMatcher
from typing import Any, Dict, List, Set, Tuple from typing import Any
import requests import requests
from setuptools import distutils # type: ignore[import] from setuptools import distutils # type: ignore[import]
ALL_SKIPPED_THRESHOLD = 100 ALL_SKIPPED_THRESHOLD = 100
SIMILARITY_THRESHOLD = 0.75 SIMILARITY_THRESHOLD = 0.75
FAILURE_CHAIN_THRESHOLD = 2 FAILURE_CHAIN_THRESHOLD = 2
@ -65,14 +68,14 @@ DISABLED_ALERTS = [
class JobStatus: class JobStatus:
job_name: str = "" job_name: str = ""
jobs: List[Any] = [] jobs: list[Any] = []
current_status: Any = None current_status: Any = None
job_statuses: List[Any] = [] job_statuses: list[Any] = []
filtered_statuses: List[Any] = [] filtered_statuses: list[Any] = []
failure_chain: List[Any] = [] failure_chain: list[Any] = []
flaky_jobs: List[Any] = [] flaky_jobs: list[Any] = []
def __init__(self, job_name: str, job_statuses: List[Any]): def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
self.job_name = job_name self.job_name = job_name
self.job_statuses = job_statuses self.job_statuses = job_statuses
@ -93,7 +96,7 @@ class JobStatus:
return status return status
return None return None
def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]: def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]:
""" """
Returns list of jobs grouped by failureCaptures from the input list Returns list of jobs grouped by failureCaptures from the input list
""" """
@ -120,7 +123,7 @@ class JobStatus:
return failures return failures
# A flaky job is if it's the only job that has that failureCapture and is not the most recent job # A flaky job is if it's the only job that has that failureCapture and is not the most recent job
def get_flaky_jobs(self) -> List[Any]: def get_flaky_jobs(self) -> list[Any]:
unique_failures = self.get_unique_failures(self.filtered_statuses) unique_failures = self.get_unique_failures(self.filtered_statuses)
flaky_jobs = [] flaky_jobs = []
for failure in unique_failures: for failure in unique_failures:
@ -134,7 +137,7 @@ class JobStatus:
# The most recent failure chain is an array of jobs that have the same-ish failures. # The most recent failure chain is an array of jobs that have the same-ish failures.
# A success in the middle of the chain will terminate the chain. # A success in the middle of the chain will terminate the chain.
def get_most_recent_failure_chain(self) -> List[Any]: def get_most_recent_failure_chain(self) -> list[Any]:
failures = [] failures = []
found_most_recent_failure = False found_most_recent_failure = False
@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any:
# Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD # Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD
def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]: def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
jobData = defaultdict(list) jobData = defaultdict(list)
for sha in shaGrid: for sha in shaGrid:
for ind, job in enumerate(sha["jobs"]): for ind, job in enumerate(sha["jobs"]):
@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool:
return conclusion in (NEUTRAL, SKIPPED) or conclusion is None return conclusion in (NEUTRAL, SKIPPED) or conclusion is None
def get_failed_jobs(job_data: List[Any]) -> List[Any]: def get_failed_jobs(job_data: list[Any]) -> list[Any]:
return [job for job in job_data if job["conclusion"] == "failure"] return [job for job in job_data if job["conclusion"] == "failure"]
def classify_jobs( def classify_jobs(
all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str] all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str]
) -> Tuple[List[JobStatus], List[Any]]: ) -> tuple[list[JobStatus], list[Any]]:
""" """
Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs. Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs.
Classifies jobs into jobs to alert on and flaky jobs. Classifies jobs into jobs to alert on and flaky jobs.
@ -212,7 +215,7 @@ def classify_jobs(
:return: :return:
""" """
job_data = map_job_data(all_job_names, sha_grid) job_data = map_job_data(all_job_names, sha_grid)
job_statuses: List[JobStatus] = [] job_statuses: list[JobStatus] = []
for job in job_data: for job in job_data:
job_statuses.append(JobStatus(job, job_data[job])) job_statuses.append(JobStatus(job, job_data[job]))
@ -230,7 +233,7 @@ def classify_jobs(
# filter job names that don't match the regex # filter job names that don't match the regex
def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]: def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]:
if job_name_regex: if job_name_regex:
return [ return [
job_name for job_name in job_names if re.match(job_name_regex, job_name) job_name for job_name in job_names if re.match(job_name_regex, job_name)
@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
def get_recurrently_failing_jobs_alerts( def get_recurrently_failing_jobs_alerts(
repo: str, branch: str, job_name_regex: str repo: str, branch: str, job_name_regex: str
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch) job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch)
filtered_job_names = set(filter_job_names(job_names, job_name_regex)) filtered_job_names = set(filter_job_names(job_names, job_name_regex))

View File

@ -14,18 +14,17 @@ generated. In the full build system, OUTPUT_DIR is
torch/testing/_internal/generated torch/testing/_internal/generated
""" """
from __future__ import annotations
import argparse import argparse
import os import os
import textwrap import textwrap
from collections import defaultdict from collections import defaultdict
from typing import Any, Sequence, TYPE_CHECKING
from typing import Any, Dict, List, Sequence
import torchgen.api.python as python import torchgen.api.python as python
from torchgen.context import with_native_function from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml from torchgen.gen import parse_native_yaml
from torchgen.model import Argument, BaseOperatorName, NativeFunction
from torchgen.utils import FileManager from torchgen.utils import FileManager
from .gen_python_functions import ( from .gen_python_functions import (
@ -39,6 +38,10 @@ from .gen_python_functions import (
) )
if TYPE_CHECKING:
from torchgen.model import Argument, BaseOperatorName, NativeFunction
def gen_annotated( def gen_annotated(
native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None: ) -> None:
@ -53,9 +56,9 @@ def gen_annotated(
(is_py_fft_function, "torch._C._fft"), (is_py_fft_function, "torch._C._fft"),
(is_py_variable_method, "torch.Tensor"), (is_py_variable_method, "torch.Tensor"),
) )
annotated_args: List[str] = [] annotated_args: list[str] = []
for pred, namespace in mappings: for pred, namespace in mappings:
groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
for f in native_functions: for f in native_functions:
if not should_generate_py_binding(f) or not pred(f): if not should_generate_py_binding(f) or not pred(f):
continue continue
@ -77,7 +80,7 @@ def gen_annotated(
@with_native_function @with_native_function
def gen_annotated_args(f: NativeFunction) -> str: def gen_annotated_args(f: NativeFunction) -> str:
def _get_kwargs_func_exclusion_list() -> List[str]: def _get_kwargs_func_exclusion_list() -> list[str]:
# functions that currently don't work with kwargs in test_overrides.py # functions that currently don't work with kwargs in test_overrides.py
return [ return [
"diagonal", "diagonal",
@ -87,12 +90,12 @@ def gen_annotated_args(f: NativeFunction) -> str:
] ]
def _add_out_arg( def _add_out_arg(
out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
) -> None: ) -> None:
for arg in args: for arg in args:
if arg.default is not None: if arg.default is not None:
continue continue
out_arg: Dict[str, Any] = {} out_arg: dict[str, Any] = {}
out_arg["is_kwarg_only"] = str(is_kwarg_only) out_arg["is_kwarg_only"] = str(is_kwarg_only)
out_arg["name"] = arg.name out_arg["name"] = arg.name
out_arg["simple_type"] = python.argument_type_str( out_arg["simple_type"] = python.argument_type_str(
@ -103,7 +106,7 @@ def gen_annotated_args(f: NativeFunction) -> str:
out_arg["size"] = size_t out_arg["size"] = size_t
out_args.append(out_arg) out_args.append(out_arg)
out_args: List[Dict[str, Any]] = [] out_args: list[dict[str, Any]] = []
_add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False) _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list(): if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
_add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True) _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)

View File

@ -22,9 +22,10 @@ torch/csrc/autograd/generated/
# gen_python_functions.py: generates Python bindings to THPVariable # gen_python_functions.py: generates Python bindings to THPVariable
# #
from __future__ import annotations
import argparse import argparse
import os import os
from typing import List
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.autograd import ( from torchgen.api.autograd import (
@ -69,7 +70,7 @@ def gen_autograd(
), ),
key=lambda f: cpp.name(f.func), key=lambda f: cpp.name(f.func),
) )
fns_with_diff_infos: List[ fns_with_diff_infos: list[
NativeFunctionWithDifferentiabilityInfo NativeFunctionWithDifferentiabilityInfo
] = match_differentiability_info(fns, differentiability_infos) ] = match_differentiability_info(fns, differentiability_infos)

View File

@ -4,7 +4,10 @@
# Functions.h/cpp: subclasses of autograd::Node # Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes # python_functions.h/cpp: Python bindings for the above classes
# #
from typing import Dict, List, Sequence, Tuple
from __future__ import annotations
from typing import Sequence
from torchgen.api.autograd import ( from torchgen.api.autograd import (
Derivative, Derivative,
@ -43,6 +46,7 @@ from torchgen.utils import FileManager
from .gen_inplace_or_view_type import VIEW_FUNCTIONS from .gen_inplace_or_view_type import VIEW_FUNCTIONS
FUNCTION_DECLARATION = CodeTemplate( FUNCTION_DECLARATION = CodeTemplate(
"""\ """\
#ifdef _WIN32 #ifdef _WIN32
@ -443,8 +447,8 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def get_infos_with_derivatives_list( def get_infos_with_derivatives_list(
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]: ) -> list[DifferentiabilityInfo]:
diff_info_list = [ diff_info_list = [
info info
for diffinfo_dict in differentiability_infos.values() for diffinfo_dict in differentiability_infos.values()
@ -456,7 +460,7 @@ def get_infos_with_derivatives_list(
def gen_autograd_functions_lib( def gen_autograd_functions_lib(
out: str, out: str,
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str, template_path: str,
) -> None: ) -> None:
"""Functions.h and Functions.cpp body """Functions.h and Functions.cpp body
@ -490,7 +494,7 @@ def gen_autograd_functions_lib(
def gen_autograd_functions_python( def gen_autograd_functions_python(
out: str, out: str,
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str, template_path: str,
) -> None: ) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
@ -536,17 +540,17 @@ def gen_autograd_functions_python(
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
saved_variables: List[str] = [] saved_variables: list[str] = []
release_variables: List[str] = [] release_variables: list[str] = []
saved_list_sizes: List[str] = [] saved_list_sizes: list[str] = []
unpack: List[str] = [] unpack: list[str] = []
asserts: List[str] = [] asserts: list[str] = []
compute_index_ranges: List[str] = [] compute_index_ranges: list[str] = []
getter_definitions: List[str] = [] getter_definitions: list[str] = []
py_getsetdef_structs: List[str] = [] py_getsetdef_structs: list[str] = []
compiled_args: List[str] = [] compiled_args: list[str] = []
apply_with_saved_before: List[str] = [] apply_with_saved_before: list[str] = []
apply_with_saved_after: List[str] = [] apply_with_saved_after: list[str] = []
for arg in info.args_with_derivatives: for arg in info.args_with_derivatives:
if arg.type in TENSOR_LIST_LIKE_CTYPES: if arg.type in TENSOR_LIST_LIKE_CTYPES:
@ -807,7 +811,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else: else:
will_release_variables = "" will_release_variables = ""
body: List[str] = [] body: list[str] = []
if uses_single_grad(info): if uses_single_grad(info):
body.append("const auto& grad = grads[0];") body.append("const auto& grad = grads[0];")
@ -821,7 +825,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
def emit_derivative( def emit_derivative(
derivative: Derivative, derivative: Derivative,
args_with_derivatives: Sequence[Binding], args_with_derivatives: Sequence[Binding],
) -> Tuple[bool, str]: ) -> tuple[bool, str]:
formula = derivative.formula formula = derivative.formula
var_names = derivative.var_names var_names = derivative.var_names
if len(var_names) == 1: if len(var_names) == 1:
@ -857,7 +861,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else: else:
grad_input_mask = "" grad_input_mask = ""
idx_ranges = ", ".join(f"{n}_ix" for n in var_names) idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
copy_ranges: List[str] = [] copy_ranges: list[str] = []
for i, n in enumerate(var_names): for i, n in enumerate(var_names):
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
return False, DERIVATIVE_MULTI.substitute( return False, DERIVATIVE_MULTI.substitute(

View File

@ -4,7 +4,7 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimick this codegen, so we should keep the two in sync. # The fallback is expected to mimick this codegen, so we should keep the two in sync.
from typing import Dict, List, Optional, Tuple from __future__ import annotations
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.autograd import ( from torchgen.api.autograd import (
@ -24,8 +24,7 @@ from torchgen.api.types import (
OptionalCType, OptionalCType,
symIntArrayRefT, symIntArrayRefT,
SymIntT, SymIntT,
# See Note [Nested Arg Types] tensorT, # See Note [Nested Arg Types]
tensorT,
) )
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function from torchgen.context import with_native_function
@ -46,6 +45,7 @@ from .gen_trace_type import (
type_wrapper_name, type_wrapper_name,
) )
# See NOTE [ Autograd View Variables ] in variable.h for details. # See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, # If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in # you **MUST** also update the public list of view ops accordingly in
@ -281,7 +281,7 @@ def inverse_view_name(f: NativeFunction) -> str:
return f"{copy_variant}{overload}_inverse" return f"{copy_variant}{overload}_inverse"
def extract_bindings(f: NativeFunction) -> List[Binding]: def extract_bindings(f: NativeFunction) -> list[Binding]:
return [ return [
r r
for a in f.func.schema_order_arguments() for a in f.func.schema_order_arguments()
@ -297,9 +297,9 @@ def extract_bindings(f: NativeFunction) -> List[Binding]:
@with_native_function @with_native_function
def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]: def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
body: List[str] = [] body: list[str] = []
unpacked_bindings: List[Binding] = [] unpacked_bindings: list[Binding] = []
for i, binding in enumerate(extract_bindings(f)): for i, binding in enumerate(extract_bindings(f)):
assert not isinstance(binding.argument, SelfArgument) assert not isinstance(binding.argument, SelfArgument)
@ -338,7 +338,7 @@ def get_base_name(f: NativeFunction) -> str:
return f.func.name.name.base # TODO: should be str(f.func.name.name)? return f.func.name.name.base # TODO: should be str(f.func.name.name)?
def get_view_info(f: NativeFunction) -> Optional[str]: def get_view_info(f: NativeFunction) -> str | None:
base_name = get_base_name(f) base_name = get_base_name(f)
view_info = VIEW_FUNCTIONS.get(base_name, None) view_info = VIEW_FUNCTIONS.get(base_name, None)
if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
@ -347,7 +347,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
def emit_view_func( def emit_view_func(
f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str: ) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported. """Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
@ -355,8 +355,8 @@ def emit_view_func(
# TODO: Clean this logic up if we get rid of reverse view funcs or reify them. # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
input_base = "input_base" input_base = "input_base"
replay_view_func = "" replay_view_func = ""
updated_args: List[str] = [] updated_args: list[str] = []
known_view_arg_simple_types: List[CType] = [ known_view_arg_simple_types: list[CType] = [
BaseCType(longT), BaseCType(longT),
OptionalCType(BaseCType(longT)), OptionalCType(BaseCType(longT)),
BaseCType(SymIntT), BaseCType(SymIntT),
@ -448,7 +448,7 @@ def emit_view_func(
def emit_view_body( def emit_view_body(
fn: NativeFunctionWithDifferentiabilityInfo, var: str fn: NativeFunctionWithDifferentiabilityInfo, var: str
) -> Tuple[str, str]: ) -> tuple[str, str]:
# See NOTE [ Autograd View Variables ] in variable.h for details. # See NOTE [ Autograd View Variables ] in variable.h for details.
f = fn.func f = fn.func
base_name = get_base_name(f) base_name = get_base_name(f)
@ -523,9 +523,9 @@ def modifies_arguments(f: NativeFunction) -> bool:
@with_native_function_with_differentiability_info @with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]: def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
f = fn.func f = fn.func
inplace_view_body: List[str] = [] inplace_view_body: list[str] = []
dispatcher_sig = DispatcherSignature.from_schema(f.func) dispatcher_sig = DispatcherSignature.from_schema(f.func)
dispatcher_exprs = dispatcher_sig.exprs() dispatcher_exprs = dispatcher_sig.exprs()
@ -584,7 +584,7 @@ def gen_formals(f: NativeFunction) -> str:
@with_native_function_with_differentiability_info @with_native_function_with_differentiability_info
def inplace_or_view_method_definition( def inplace_or_view_method_definition(
fn: NativeFunctionWithDifferentiabilityInfo, fn: NativeFunctionWithDifferentiabilityInfo,
) -> Optional[str]: ) -> str | None:
f = fn.func f = fn.func
if get_view_info(f) is None and ( if get_view_info(f) is None and (
# For functions that modify their inputs but don't return them, # For functions that modify their inputs but don't return them,
@ -605,7 +605,7 @@ def inplace_or_view_method_definition(
@with_native_function_with_differentiability_info @with_native_function_with_differentiability_info
def inplace_or_view_method_registration( def inplace_or_view_method_registration(
fn: NativeFunctionWithDifferentiabilityInfo, fn: NativeFunctionWithDifferentiabilityInfo,
) -> Optional[str]: ) -> str | None:
f = fn.func f = fn.func
if get_view_info(f) is None and ( if get_view_info(f) is None and (
not modifies_arguments(f) or len(f.func.returns) == 0 not modifies_arguments(f) or len(f.func.returns) == 0
@ -626,7 +626,7 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
def gen_inplace_or_view_type_env( def gen_inplace_or_view_type_env(
fn: NativeFunctionWithDifferentiabilityInfo, fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
definition = inplace_or_view_method_definition(fn) definition = inplace_or_view_method_definition(fn)
registration = inplace_or_view_method_registration(fn) registration = inplace_or_view_method_registration(fn)
@ -649,7 +649,7 @@ def gen_inplace_or_view_type(
out: str, out: str,
native_yaml_path: str, native_yaml_path: str,
tags_yaml_path: str, tags_yaml_path: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str, template_path: str,
) -> None: ) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp # NOTE: see Note [Sharded File] at the top of the VariableType.cpp

View File

@ -31,11 +31,12 @@
# message, but use what's there # message, but use what's there
# #
from __future__ import annotations
import itertools import itertools
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Callable, Iterable, Sequence
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import yaml import yaml
@ -56,7 +57,6 @@ from torchgen.api.python import (
signature_from_schema, signature_from_schema,
structseq_fieldnames, structseq_fieldnames,
) )
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
@ -75,6 +75,7 @@ from torchgen.yaml_utils import YamlLoader
from .gen_inplace_or_view_type import is_tensor_list_type from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace from .gen_trace_type import should_trace
# #
# declarations blocklist # declarations blocklist
# We skip codegen for these functions, for various reasons. # We skip codegen for these functions, for various reasons.
@ -369,7 +370,7 @@ def gen(
valid_tags = parse_tags_yaml(tags_yaml_path) valid_tags = parse_tags_yaml(tags_yaml_path)
def gen_tags_enum() -> Dict[str, str]: def gen_tags_enum() -> dict[str, str]:
return { return {
"enum_of_valid_tags": ( "enum_of_valid_tags": (
"".join( "".join(
@ -384,9 +385,9 @@ def gen(
def group_filter_overloads( def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair], pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool], pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]: ) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
grouped: Dict[ grouped: dict[
BaseOperatorName, List[PythonSignatureNativeFunctionPair] BaseOperatorName, list[PythonSignatureNativeFunctionPair]
] = defaultdict(list) ] = defaultdict(list)
for pair in pairs: for pair in pairs:
if pred(pair.function): if pred(pair.function):
@ -398,17 +399,17 @@ def create_python_bindings(
fm: FileManager, fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair], pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool], pred: Callable[[NativeFunction], bool],
module: Optional[str], module: str | None,
filename: str, filename: str,
*, *,
method: bool, method: bool,
symint: bool = True, symint: bool = True,
) -> None: ) -> None:
"""Generates Python bindings to ATen functions""" """Generates Python bindings to ATen functions"""
py_methods: List[str] = [] py_methods: list[str] = []
ops_headers: List[str] = [] ops_headers: list[str] = []
py_method_defs: List[str] = [] py_method_defs: list[str] = []
py_forwards: List[str] = [] py_forwards: list[str] = []
grouped = group_filter_overloads(pairs, pred) grouped = group_filter_overloads(pairs, pred)
@ -445,8 +446,8 @@ def create_python_return_type_bindings(
Generate function to initialize and return named tuple for native functions Generate function to initialize and return named tuple for native functions
which returns named tuple and registration invocations in `python_return_types.cpp`. which returns named tuple and registration invocations in `python_return_types.cpp`.
""" """
py_return_types_definition: List[str] = [] py_return_types_definition: list[str] = []
py_return_types_registrations: List[str] = [] py_return_types_registrations: list[str] = []
grouped = group_filter_overloads(pairs, pred) grouped = group_filter_overloads(pairs, pred)
@ -484,7 +485,7 @@ def create_python_return_type_bindings_header(
Generate function to initialize and return named tuple for native functions Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`. which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
""" """
py_return_types_declarations: List[str] = [] py_return_types_declarations: list[str] = []
grouped = group_filter_overloads(pairs, pred) grouped = group_filter_overloads(pairs, pred)
@ -510,7 +511,7 @@ def create_python_bindings_sharded(
fm: FileManager, fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair], pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool], pred: Callable[[NativeFunction], bool],
module: Optional[str], module: str | None,
filename: str, filename: str,
*, *,
method: bool, method: bool,
@ -521,13 +522,13 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred) grouped = group_filter_overloads(pairs, pred)
def key_func( def key_func(
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> str: ) -> str:
return kv[0].base return kv[0].base
def env_func( def env_func(
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
name, fn_pairs = kv name, fn_pairs = kv
return { return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"], "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
@ -553,7 +554,7 @@ def create_python_bindings_sharded(
def load_signatures( def load_signatures(
native_functions: List[NativeFunction], native_functions: list[NativeFunction],
deprecated_yaml_path: str, deprecated_yaml_path: str,
*, *,
method: bool, method: bool,
@ -580,19 +581,19 @@ def load_deprecated_signatures(
*, *,
method: bool, method: bool,
pyi: bool, pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]: ) -> list[PythonSignatureNativeFunctionPair]:
# The deprecated.yaml doesn't have complete type information, we need # The deprecated.yaml doesn't have complete type information, we need
# find and leverage the original ATen signature (to which it delegates # find and leverage the original ATen signature (to which it delegates
# the call) to generate the full python signature. # the call) to generate the full python signature.
# We join the deprecated and the original signatures using type-only form. # We join the deprecated and the original signatures using type-only form.
# group the original ATen signatures by name # group the original ATen signatures by name
grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs: for pair in pairs:
grouped[pair.signature.name].append(pair) grouped[pair.signature.name].append(pair)
# find matching original signatures for each deprecated signature # find matching original signatures for each deprecated signature
results: List[PythonSignatureNativeFunctionPair] = [] results: list[PythonSignatureNativeFunctionPair] = []
with open(deprecated_yaml_path) as f: with open(deprecated_yaml_path) as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader) deprecated_defs = yaml.load(f, Loader=YamlLoader)
@ -701,15 +702,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:
def emit_structseq_call( def emit_structseq_call(
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]: ) -> tuple[list[str], dict[str, str]]:
""" """
Generate block of named tuple type def inits, and add typeref snippets Generate block of named tuple type def inits, and add typeref snippets
to declarations that use them to declarations that use them
""" """
typenames: Dict[ typenames: dict[
str, str str, str
] = {} # map from unique name + field name lists to typedef name ] = {} # map from unique name + field name lists to typedef name
typedefs: List[str] = [] # typedef declarations and init code typedefs: list[str] = [] # typedef declarations and init code
for overload in overloads: for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns) fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -732,17 +733,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
def generate_return_type_definition_and_registrations( def generate_return_type_definition_and_registrations(
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
""" """
Generate block of function in `python_return_types.cpp` to initialize Generate block of function in `python_return_types.cpp` to initialize
and return named tuple for a native function which returns named tuple and return named tuple for a native function which returns named tuple
and registration invocations in same file. and registration invocations in same file.
""" """
typenames: Dict[ typenames: dict[
str, str str, str
] = {} # map from unique name + field name lists to typedef name ] = {} # map from unique name + field name lists to typedef name
definitions: List[str] = [] # function definition to register the typedef definitions: list[str] = [] # function definition to register the typedef
registrations: List[str] = [] # register call for the typedef registrations: list[str] = [] # register call for the typedef
for overload in overloads: for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns) fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -783,15 +784,15 @@ PyTypeObject* get_{name}_structseq() {{
def generate_return_type_declarations( def generate_return_type_declarations(
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> List[str]: ) -> list[str]:
""" """
Generate block of function declarations in `python_return_types.h` to initialize Generate block of function declarations in `python_return_types.h` to initialize
and return named tuple for a native function. and return named tuple for a native function.
""" """
typenames: Dict[ typenames: dict[
str, str str, str
] = {} # map from unique name + field name lists to typedef name ] = {} # map from unique name + field name lists to typedef name
declarations: List[str] = [] # function declaration to register the typedef declarations: list[str] = [] # function declaration to register the typedef
for overload in overloads: for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns) fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -891,7 +892,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)
def method_impl( def method_impl(
name: BaseOperatorName, name: BaseOperatorName,
module: Optional[str], module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
*, *,
method: bool, method: bool,
@ -918,8 +919,8 @@ def method_impl(
overloads, symint=symint overloads, symint=symint
) )
is_singleton = len(grouped_overloads) == 1 is_singleton = len(grouped_overloads) == 1
signatures: List[str] = [] signatures: list[str] = []
dispatch: List[str] = [] dispatch: list[str] = []
for overload_index, overload in enumerate(grouped_overloads): for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint) signature = overload.signature.signature_str(symint=symint)
signatures.append(f"{cpp_string(str(signature))},") signatures.append(f"{cpp_string(str(signature))},")
@ -959,7 +960,7 @@ def method_impl(
def gen_has_torch_function_check( def gen_has_torch_function_check(
name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str: ) -> str:
if noarg: if noarg:
if method: if method:
@ -1007,7 +1008,7 @@ if (_r.isNone(${out_idx})) {
def emit_dispatch_case( def emit_dispatch_case(
overload: PythonSignatureGroup, overload: PythonSignatureGroup,
structseq_typenames: Dict[str, str], structseq_typenames: dict[str, str],
*, *,
symint: bool = True, symint: bool = True,
) -> str: ) -> str:
@ -1050,7 +1051,7 @@ def forward_decls(
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
*, *,
method: bool, method: bool,
) -> Tuple[str, ...]: ) -> tuple[str, ...]:
if method: if method:
return () return ()
@ -1078,7 +1079,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
def method_def( def method_def(
name: BaseOperatorName, name: BaseOperatorName,
module: Optional[str], module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair], overloads: Sequence[PythonSignatureNativeFunctionPair],
*, *,
method: bool, method: bool,
@ -1114,8 +1115,8 @@ def method_def(
def group_overloads( def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]: ) -> Sequence[PythonSignatureGroup]:
bases: Dict[str, PythonSignatureNativeFunctionPair] = {} bases: dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}
# first group by signature ignoring out arguments # first group by signature ignoring out arguments
for overload in overloads: for overload in overloads:
@ -1137,7 +1138,7 @@ def group_overloads(
for sig, out in outplaces.items(): for sig, out in outplaces.items():
if sig not in bases: if sig not in bases:
candidates: List[str] = [] candidates: list[str] = []
for overload in overloads: for overload in overloads:
if ( if (
str(overload.function.func.name.name) str(overload.function.func.name.name)
@ -1268,7 +1269,7 @@ def sort_overloads(
) )
# Construct the relation graph # Construct the relation graph
larger_than: Dict[int, Set[int]] = defaultdict(set) larger_than: dict[int, set[int]] = defaultdict(set)
for i1, overload1 in enumerate(grouped_overloads): for i1, overload1 in enumerate(grouped_overloads):
for i2, overload2 in enumerate(grouped_overloads): for i2, overload2 in enumerate(grouped_overloads):
if is_smaller(overload1.signature, overload2.signature): if is_smaller(overload1.signature, overload2.signature):
@ -1279,7 +1280,7 @@ def sort_overloads(
# Use a topological sort to sort overloads according to the partial order. # Use a topological sort to sort overloads according to the partial order.
N = len(grouped_overloads) N = len(grouped_overloads)
sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N))) sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))
for idx in range(N): for idx in range(N):
# The size of sorted_ids will grow to N eventually. # The size of sorted_ids will grow to N eventually.
@ -1304,7 +1305,7 @@ def sort_overloads(
def emit_single_dispatch( def emit_single_dispatch(
ps: PythonSignature, ps: PythonSignature,
f: NativeFunction, f: NativeFunction,
structseq_typenames: Dict[str, str], structseq_typenames: dict[str, str],
*, *,
symint: bool = True, symint: bool = True,
) -> str: ) -> str:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import itertools import itertools
from typing import Dict, List, Sequence, Union from typing import Sequence
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
@ -8,6 +10,7 @@ from torchgen.context import with_native_function
from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
from torchgen.utils import FileManager from torchgen.utils import FileManager
# Note [Manual Backend kernels] # Note [Manual Backend kernels]
# For these ops, we want to manually register to dispatch key Backend and # For these ops, we want to manually register to dispatch key Backend and
# skip codegen-ed registeration to all keys before Backend. # skip codegen-ed registeration to all keys before Backend.
@ -136,9 +139,7 @@ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${inpu
def format_trace_inputs(f: NativeFunction) -> str: def format_trace_inputs(f: NativeFunction) -> str:
def dispatch_trace_input( def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
arg: Union[Argument, TensorOptionsArguments]
) -> Sequence[str]:
if isinstance(arg, TensorOptionsArguments): if isinstance(arg, TensorOptionsArguments):
name = "options" name = "options"
return [ return [
@ -156,7 +157,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
else: else:
return [ADD_TRACE_INPUT.substitute(name=name, input=name)] return [ADD_TRACE_INPUT.substitute(name=name, input=name)]
args: List[Union[Argument, TensorOptionsArguments]] = list( args: list[Argument | TensorOptionsArguments] = list(
f.func.schema_order_arguments() f.func.schema_order_arguments()
) )
@ -399,8 +400,8 @@ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args
) )
def emit_trace_body(f: NativeFunction) -> List[str]: def emit_trace_body(f: NativeFunction) -> list[str]:
trace_body: List[str] = [] trace_body: list[str] = []
trace_body.append(format_prerecord_trace(f)) trace_body.append(format_prerecord_trace(f))
@ -503,7 +504,7 @@ def method_registration(f: NativeFunction) -> str:
) )
def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]: def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
return { return {
"ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"], "ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"],
"trace_method_definitions": [method_definition(fn)], "trace_method_definitions": [method_definition(fn)],
@ -512,7 +513,7 @@ def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
def gen_trace_type( def gen_trace_type(
out: str, native_functions: List[NativeFunction], template_path: str out: str, native_functions: list[NativeFunction], template_path: str
) -> None: ) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files. # template regarding sharding of the generated files.

View File

@ -2,18 +2,19 @@
# #
# This writes one file: variable_factories.h # This writes one file: variable_factories.h
from __future__ import annotations
import re import re
from typing import List, Optional
import torchgen.api.python as python import torchgen.api.python as python
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup from torchgen.api.types import CppSignatureGroup
from torchgen.context import with_native_function from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml from torchgen.gen import parse_native_yaml
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
from torchgen.utils import FileManager, mapMaybe from torchgen.utils import FileManager, mapMaybe
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>") OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
@ -69,7 +70,7 @@ def is_factory_function(f: NativeFunction) -> bool:
@with_native_function @with_native_function
def process_function(f: NativeFunction) -> Optional[str]: def process_function(f: NativeFunction) -> str | None:
name = cpp.name(f.func) name = cpp.name(f.func)
has_tensor_options = python.has_tensor_options(f) has_tensor_options = python.has_tensor_options(f)
is_factory = has_tensor_options or name.endswith("_like") is_factory = has_tensor_options or name.endswith("_like")
@ -83,8 +84,8 @@ def process_function(f: NativeFunction) -> Optional[str]:
sigs.append(cpp_sigs.symint_signature) sigs.append(cpp_sigs.symint_signature)
r = "" r = ""
for sig in sigs: for sig in sigs:
formals: List[str] = [] formals: list[str] = []
exprs: List[str] = [] exprs: list[str] = []
requires_grad = "false" requires_grad = "false"
for arg in sig.arguments(): for arg in sig.arguments():
qualified_type = fully_qualified_type(arg.type) qualified_type = fully_qualified_type(arg.type)

View File

@ -25,8 +25,11 @@
# which will in turn dispatch back to VariableType for its # which will in turn dispatch back to VariableType for its
# differentiable subcomponents. # differentiable subcomponents.
# #
from __future__ import annotations
import re import re
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union from typing import Callable, Sequence
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.autograd import ( from torchgen.api.autograd import (
@ -38,7 +41,6 @@ from torchgen.api.autograd import (
NativeFunctionWithDifferentiabilityInfo, NativeFunctionWithDifferentiabilityInfo,
SavedAttribute, SavedAttribute,
) )
from torchgen.api.types import ( from torchgen.api.types import (
ArrayRefCType, ArrayRefCType,
BaseCppType, BaseCppType,
@ -103,6 +105,7 @@ from .gen_trace_type import (
type_wrapper_name, type_wrapper_name,
) )
# We don't set or modify grad_fn on these methods. Generally, they return # We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will # tensors that have requires_grad=False. In-place functions listed here will
# not examine or modify requires_grad or grad_fn. # not examine or modify requires_grad or grad_fn.
@ -837,9 +840,9 @@ def gen_variable_type(
out: str, out: str,
native_yaml_path: str, native_yaml_path: str,
tags_yaml_path: str, tags_yaml_path: str,
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo], fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str, template_path: str,
used_keys: Set[str], used_keys: set[str],
) -> None: ) -> None:
"""VariableType.h and VariableType.cpp body """VariableType.h and VariableType.cpp body
@ -858,8 +861,8 @@ def gen_variable_type(
# helper that generates a TORCH_LIBRARY_IMPL macro for each # helper that generates a TORCH_LIBRARY_IMPL macro for each
# dispatch key that appears in derivatives.yaml # dispatch key that appears in derivatives.yaml
def wrapper_registrations(used_keys: Set[str]) -> str: def wrapper_registrations(used_keys: set[str]) -> str:
library_impl_macro_list: List[str] = [] library_impl_macro_list: list[str] = []
for key in sorted(used_keys): for key in sorted(used_keys):
dispatch_key = key dispatch_key = key
if key == "Default": if key == "Default":
@ -926,7 +929,7 @@ def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
def gen_variable_type_func( def gen_variable_type_func(
fn: NativeFunctionWithDifferentiabilityInfo, fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
f = fn.func f = fn.func
result = {} result = {}
with native_function_manager(f): with native_function_manager(f):
@ -1034,7 +1037,7 @@ _foreach_ops_with_different_arity = {
@with_native_function_with_differentiability_info_and_key @with_native_function_with_differentiability_info_and_key
def emit_body( def emit_body(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[str]: ) -> list[str]:
assert dispatch_strategy(fn) == "use_derived" assert dispatch_strategy(fn) == "use_derived"
f = fn.func f = fn.func
info = fn.info[key] if fn.info else None info = fn.info[key] if fn.info else None
@ -1050,8 +1053,8 @@ def emit_body(
is_foreach = name.startswith("_foreach") is_foreach = name.startswith("_foreach")
is_inplace_foreach = is_foreach and inplace is_inplace_foreach = is_foreach and inplace
if is_inplace_foreach: if is_inplace_foreach:
inplace_foreacharg2refarg: Dict[Argument, Argument] = {} inplace_foreacharg2refarg: dict[Argument, Argument] = {}
refargname2inplace_foreacharg: Dict[str, Argument] = {} refargname2inplace_foreacharg: dict[str, Argument] = {}
base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name) base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
if info is None: if info is None:
assert ( assert (
@ -1077,8 +1080,8 @@ def emit_body(
refargname2inplace_foreacharg[ref_arg.name] = foreach_arg refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
def gen_differentiable_input( def gen_differentiable_input(
arg: Union[Argument, SelfArgument, TensorOptionsArguments] arg: Argument | SelfArgument | TensorOptionsArguments,
) -> Optional[DifferentiableInput]: ) -> DifferentiableInput | None:
if isinstance(arg, TensorOptionsArguments): if isinstance(arg, TensorOptionsArguments):
return None return None
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
@ -1097,7 +1100,7 @@ def emit_body(
) )
@with_native_function @with_native_function
def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]: def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
arguments = list(f.func.arguments.non_out) arguments = list(f.func.arguments.non_out)
if is_inplace_foreach and info is not None: if is_inplace_foreach and info is not None:
for i, arg in enumerate(f.func.arguments.flat_non_out): for i, arg in enumerate(f.func.arguments.flat_non_out):
@ -1115,8 +1118,8 @@ def emit_body(
return list(mapMaybe(gen_differentiable_input, arguments)) return list(mapMaybe(gen_differentiable_input, arguments))
def find_args_with_derivatives( def find_args_with_derivatives(
differentiable_inputs: List[DifferentiableInput], differentiable_inputs: list[DifferentiableInput],
) -> List[DifferentiableInput]: ) -> list[DifferentiableInput]:
"""Find arguments that have derivative definitions""" """Find arguments that have derivative definitions"""
if info is None or not info.has_derivatives: if info is None or not info.has_derivatives:
return differentiable_inputs return differentiable_inputs
@ -1178,8 +1181,8 @@ def emit_body(
and (not returns_void) and (not returns_void)
) )
def emit_save_inputs() -> List[str]: def emit_save_inputs() -> list[str]:
setup: List[str] = [] setup: list[str] = []
if info is None or not info.has_derivatives: if info is None or not info.has_derivatives:
return setup return setup
@ -1189,7 +1192,7 @@ def emit_body(
# We don't want to save tensors if we know that they will never be used # We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements # when computing the derivative, so we add guards to those statements
def guard_for(arg: SavedAttribute) -> Optional[str]: def guard_for(arg: SavedAttribute) -> str | None:
assert info is not None assert info is not None
# It's hard to determine the edge offset if we have TensorLists # It's hard to determine the edge offset if we have TensorLists
@ -1276,8 +1279,8 @@ def emit_body(
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();") setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
return setup return setup
def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]: def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
body: List[str] = [] body: list[str] = []
if is_out_fn: if is_out_fn:
# For out functions, ensure that no input or output requires grad # For out functions, ensure that no input or output requires grad
body.append(DECLARE_GRAD_FN.substitute(op="Node")) body.append(DECLARE_GRAD_FN.substitute(op="Node"))
@ -1343,8 +1346,8 @@ def emit_body(
body.append(SETUP_DERIVATIVE.substitute(setup=setup)) body.append(SETUP_DERIVATIVE.substitute(setup=setup))
return body return body
def emit_check_if_in_complex_autograd_allowlist() -> List[str]: def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
body: List[str] = [] body: list[str] = []
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
return body return body
for arg in differentiable_outputs: for arg in differentiable_outputs:
@ -1355,11 +1358,11 @@ def emit_body(
return body return body
def emit_check_no_requires_grad( def emit_check_no_requires_grad(
tensor_args: List[DifferentiableInput], tensor_args: list[DifferentiableInput],
args_with_derivatives: List[DifferentiableInput], args_with_derivatives: list[DifferentiableInput],
) -> List[str]: ) -> list[str]:
"""Checks that arguments without derivatives don't require grad""" """Checks that arguments without derivatives don't require grad"""
body: List[str] = [] body: list[str] = []
for arg in tensor_args: for arg in tensor_args:
if arg in args_with_derivatives: if arg in args_with_derivatives:
continue continue
@ -1373,8 +1376,8 @@ def emit_body(
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");') body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
return body return body
def emit_original_self_definition() -> List[str]: def emit_original_self_definition() -> list[str]:
body: List[str] = [] body: list[str] = []
if inplace: if inplace:
if is_inplace_foreach: if is_inplace_foreach:
body.append( body.append(
@ -1412,17 +1415,17 @@ def emit_body(
def save_variables( def save_variables(
saved_variables: Sequence[SavedAttribute], saved_variables: Sequence[SavedAttribute],
is_output: bool, is_output: bool,
guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None, guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
) -> Sequence[str]: ) -> Sequence[str]:
# assign the saved variables to the generated grad_fn # assign the saved variables to the generated grad_fn
stmts: List[str] = [] stmts: list[str] = []
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)): for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
name = ( name = (
arg.nctype.name.name arg.nctype.name.name
if isinstance(arg.nctype.name, SpecialArgName) if isinstance(arg.nctype.name, SpecialArgName)
else arg.nctype.name else arg.nctype.name
) )
foreacharg: Optional[Argument] = None foreacharg: Argument | None = None
is_foreacharg_list_type: bool = False is_foreacharg_list_type: bool = False
type = arg.nctype.type type = arg.nctype.type
expr = arg.expr expr = arg.expr
@ -1539,10 +1542,10 @@ def emit_body(
return call return call
def wrap_output( def wrap_output(
f: NativeFunction, unpacked_bindings: List[Binding], var: str f: NativeFunction, unpacked_bindings: list[Binding], var: str
) -> str: ) -> str:
call = "" call = ""
rhs_value: Optional[str] = None rhs_value: str | None = None
if not any(r.type.is_tensor_like() for r in f.func.returns): if not any(r.type.is_tensor_like() for r in f.func.returns):
rhs_value = var rhs_value = var
else: else:
@ -1554,11 +1557,11 @@ def emit_body(
return call return call
def check_tensorimpl_and_storage( def check_tensorimpl_and_storage(
call: str, unpacked_bindings: List[Binding] call: str, unpacked_bindings: list[Binding]
) -> str: ) -> str:
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
stmts_before_call: List[str] = [] stmts_before_call: list[str] = []
stmts_after_call: List[str] = [] stmts_after_call: list[str] = []
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
return call return call
@ -1665,7 +1668,7 @@ def emit_body(
return call return call
def emit_call( def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
) -> str: ) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
@ -1764,7 +1767,7 @@ def emit_body(
) )
return "" return ""
def emit_any_requires_grad() -> List[str]: def emit_any_requires_grad() -> list[str]:
extra_condition = "" extra_condition = ""
if info and info.output_differentiability_conditions: if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1 assert len(info.output_differentiability_conditions) == 1
@ -1782,14 +1785,14 @@ def emit_body(
) )
] ]
def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str: def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
if len(var_names) == 1: if len(var_names) == 1:
return f"_any_has_forward_grad_{var_names[0]}" return f"_any_has_forward_grad_{var_names[0]}"
else: else:
return f'_any_has_forward_grad_{"_".join(var_names)}' return f'_any_has_forward_grad_{"_".join(var_names)}'
def emit_any_has_forward_grad() -> List[str]: def emit_any_has_forward_grad() -> list[str]:
content: List[str] = [] content: list[str] = []
if not is_foreach: if not is_foreach:
for derivative in fw_derivatives: for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
@ -1844,7 +1847,7 @@ def emit_body(
content.append("}") content.append("}")
return content return content
def emit_check_inplace() -> List[str]: def emit_check_inplace() -> list[str]:
if not inplace: if not inplace:
return [] return []
return [ return [
@ -1852,9 +1855,9 @@ def emit_body(
for arg in differentiable_outputs for arg in differentiable_outputs
] ]
def emit_fw_derivatives() -> List[str]: def emit_fw_derivatives() -> list[str]:
content: List[str] = [] content: list[str] = []
fw_grad_setters: List[str] = [] fw_grad_setters: list[str] = []
for derivative in fw_derivatives: for derivative in fw_derivatives:
res = derivative.var_names res = derivative.var_names
if f.func.name.name.inplace: if f.func.name.name.inplace:
@ -2002,7 +2005,7 @@ def emit_body(
"(self.size(), c10::nullopt);" "(self.size(), c10::nullopt);"
) )
foreach_forward_grad_formula = derivative.formula foreach_forward_grad_formula = derivative.formula
_foreach_arg: Union[Argument, DifferentiableInput] _foreach_arg: Argument | DifferentiableInput
if inplace: if inplace:
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items(): for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here. # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
@ -2044,7 +2047,7 @@ def emit_body(
content.append("\n".join(fw_grad_setters)) content.append("\n".join(fw_grad_setters))
return content return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str: def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
# #
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
# #
@ -2053,7 +2056,7 @@ def emit_body(
# - Used in the out_fn case when we want to forbid fw derivatives # - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want # - Used in the case where the fw_derivative is not defined, but we want
# To check if there is a decomposition registered for jvp # To check if there is a decomposition registered for jvp
to_check: List[str] = [] to_check: list[str] = []
for inp in list( for inp in list(
mapMaybe( mapMaybe(
gen_differentiable_input, gen_differentiable_input,
@ -2126,7 +2129,7 @@ def emit_body(
else "" else ""
) )
body: List[str] = [] body: list[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f) unpack_args_stats, unpacked_bindings = unpack_args(f)
body.extend(unpack_args_stats) body.extend(unpack_args_stats)

View File

@ -4,10 +4,11 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync. # The fallback is expected to mimic this codegen, so we should keep the two in sync.
from typing import List, Tuple from __future__ import annotations
from typing import TYPE_CHECKING
import torchgen.api.dispatcher as dispatcher import torchgen.api.dispatcher as dispatcher
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate from torchgen.api.translate import translate
from torchgen.api.types import ( from torchgen.api.types import (
BaseCType, BaseCType,
@ -29,6 +30,11 @@ from .gen_inplace_or_view_type import (
use_derived, use_derived,
) )
if TYPE_CHECKING:
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
FUNCTION_DECLARATION = CodeTemplate( FUNCTION_DECLARATION = CodeTemplate(
"""\ """\
#define ${uppercase_op}_AVAILABLE #define ${uppercase_op}_AVAILABLE
@ -155,9 +161,9 @@ def returns_multi_tensor(fn: NativeFunction) -> bool:
# tuple: (list of getter logic strings, list of setter logic strings, string # tuple: (list of getter logic strings, list of setter logic strings, string
# with num items expression) # with num items expression)
def generate_state_getter_setter( def generate_state_getter_setter(
bindings: List[Binding], bindings: list[Binding],
state_vec_type: NamedCType, state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]: ) -> tuple[list[str], list[str], str]:
getter_logic = [] getter_logic = []
setter_logic = [] setter_logic = []
@ -302,7 +308,7 @@ def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
def gen_view_funcs( def gen_view_funcs(
out: str, out: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str, template_path: str,
) -> None: ) -> None:
# don't need the info parts, just the function # don't need the info parts, just the function

View File

@ -2,14 +2,16 @@
# #
# Each autograd function is represented by `DifferentiabilityInfo` containing # Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models. # a list of `Derivative`. See `torchgen.api.autograd` for the data models.
from __future__ import annotations
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple from typing import Any, Counter, Dict, Sequence, Set, Tuple
import yaml import yaml
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.autograd import ( from torchgen.api.autograd import (
Derivative, Derivative,
DifferentiabilityInfo, DifferentiabilityInfo,
@ -50,9 +52,10 @@ from torchgen.model import (
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader from torchgen.yaml_utils import YamlLoader
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]
_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {} _GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
@ -62,11 +65,11 @@ _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
# we generate them here instead of duplicating them in the yaml. # we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators] # See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives( def add_view_copy_derivatives(
infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
view_groups: List[NativeFunctionsViewGroup], view_groups: list[NativeFunctionsViewGroup],
) -> None: ) -> None:
# Get the map from each view op's name to its corresponding view group # Get the map from each view op's name to its corresponding view group
view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = { view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
g.view.func.name: g for g in view_groups g.view.func.name: g for g in view_groups
} }
@ -125,10 +128,10 @@ def load_derivatives(
# function schema is the complete declaration including mutability annotation / default value and etc. # function schema is the complete declaration including mutability annotation / default value and etc.
# signature is the canonical schema for a group of functions (in-place/out/functional variants) # signature is the canonical schema for a group of functions (in-place/out/functional variants)
# that are semantically related. # that are semantically related.
functions_by_signature: Dict[ functions_by_signature: dict[
FunctionSchema, List[NativeFunction] FunctionSchema, list[NativeFunction]
] = defaultdict(list) ] = defaultdict(list)
functions_by_schema: Dict[str, NativeFunction] = {} functions_by_schema: dict[str, NativeFunction] = {}
for function in native_functions: for function in native_functions:
functions_by_signature[function.func.signature()].append(function) functions_by_signature[function.func.signature()].append(function)
assert str(function.func) not in functions_by_schema assert str(function.func) not in functions_by_schema
@ -141,8 +144,8 @@ def load_derivatives(
# infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
# this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
# we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {} infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
used_dispatch_keys: Set[str] = set() used_dispatch_keys: set[str] = set()
for defn_dict in definitions: for defn_dict in definitions:
# Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
if "dispatch" not in defn_dict: if "dispatch" not in defn_dict:
@ -185,11 +188,11 @@ def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
def create_derivative( def create_derivative(
f: NativeFunction, f: NativeFunction,
formula: str, formula: str,
var_names: Tuple[str, ...], var_names: tuple[str, ...],
available_named_gradients: Sequence[str], available_named_gradients: Sequence[str],
) -> Derivative: ) -> Derivative:
original_formula = formula original_formula = formula
arguments: List[NamedCType] = [ arguments: list[NamedCType] = [
a.nctype.remove_const_ref() for a in cpp_arguments(f) a.nctype.remove_const_ref() for a in cpp_arguments(f)
] ]
@ -230,10 +233,10 @@ def create_derivative(
def create_forward_derivative( def create_forward_derivative(
f: NativeFunction, formula: str, names: Tuple[str, ...] f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative: ) -> ForwardDerivative:
var_names = names var_names = names
var_types: Optional[Tuple[Type, ...]] = None var_types: tuple[Type, ...] | None = None
for r in f.func.returns: for r in f.func.returns:
if r.name in var_names: if r.name in var_names:
if var_types is None: if var_types is None:
@ -269,12 +272,12 @@ def create_forward_derivative(
def postprocess_forward_derivatives( def postprocess_forward_derivatives(
f: NativeFunction, f: NativeFunction,
defn_name: str, defn_name: str,
all_arg_names: List[str], all_arg_names: list[str],
derivatives: List[Derivative], derivatives: list[Derivative],
forward_derivatives: List[ForwardDerivative], forward_derivatives: list[ForwardDerivative],
args_with_derivatives: Sequence[Binding], args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]: ) -> list[ForwardDerivative]:
def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]: def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
is_foreach = f.func.name.name.base.startswith("_foreach_") is_foreach = f.func.name.name.base.startswith("_foreach_")
required_inputs = set() required_inputs = set()
for arg in args_with_derivatives: for arg in args_with_derivatives:
@ -300,7 +303,7 @@ def postprocess_forward_derivatives(
return tuple(required_inputs) return tuple(required_inputs)
updated_derivatives: List[ForwardDerivative] = [] updated_derivatives: list[ForwardDerivative] = []
for defn in forward_derivatives: for defn in forward_derivatives:
formula = defn.formula formula = defn.formula
@ -430,7 +433,7 @@ def postprocess_forward_derivatives(
def is_forward_derivative_definition( def is_forward_derivative_definition(
all_arg_names: List[str], names: Tuple[str, ...] all_arg_names: list[str], names: tuple[str, ...]
) -> bool: ) -> bool:
for name in names: for name in names:
if name not in all_arg_names: if name not in all_arg_names:
@ -441,12 +444,12 @@ def is_forward_derivative_definition(
def create_differentiability_info( def create_differentiability_info(
defn_dict: Dict[Any, Any], defn_dict: dict[Any, Any],
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]], functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
functions_by_schema: Dict[str, NativeFunction], functions_by_schema: dict[str, NativeFunction],
op_counter: Counter[str], op_counter: Counter[str],
used_dispatch_keys: Set[str], used_dispatch_keys: set[str],
) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]: ) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
"""Processes a single entry `defn` in derivatives.yaml""" """Processes a single entry `defn` in derivatives.yaml"""
def canonical_function( def canonical_function(
@ -463,7 +466,7 @@ def create_differentiability_info(
assert name + "_" == cpp.name(functions[0].func) assert name + "_" == cpp.name(functions[0].func)
return functions[0] return functions[0]
def split_names(raw_names: str) -> Tuple[str, ...]: def split_names(raw_names: str) -> tuple[str, ...]:
"""Given "foo, bar", return ["foo", "bar"].""" """Given "foo, bar", return ["foo", "bar"]."""
return tuple(x.strip() for x in raw_names.split(",")) return tuple(x.strip() for x in raw_names.split(","))
@ -477,7 +480,7 @@ def create_differentiability_info(
uses_grad = False # true if any derivative uses "grad" uses_grad = False # true if any derivative uses "grad"
num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]"
uses_named_grads = False # true if any derivative uses "grad_{name}" uses_named_grads = False # true if any derivative uses "grad_{name}"
used_grads_indices: List[int] = [] # which indices of grads are used used_grads_indices: list[int] = [] # which indices of grads are used
for d in derivatives: for d in derivatives:
formula = d.formula formula = d.formula
uses_grad = uses_grad or bool( uses_grad = uses_grad or bool(
@ -521,7 +524,7 @@ def create_differentiability_info(
@with_native_function @with_native_function
def set_up_derivatives( def set_up_derivatives(
f: NativeFunction, f: NativeFunction,
) -> Tuple[ ) -> tuple[
Sequence[Derivative], Sequence[Derivative],
Sequence[ForwardDerivative], Sequence[ForwardDerivative],
Sequence[Binding], Sequence[Binding],
@ -529,10 +532,10 @@ def create_differentiability_info(
Sequence[str], Sequence[str],
]: ]:
# Set up the derivative information # Set up the derivative information
derivatives: List[Derivative] = [] derivatives: list[Derivative] = []
forward_derivatives: List[ForwardDerivative] = [] forward_derivatives: list[ForwardDerivative] = []
non_differentiable_arg_names: List[str] = [] non_differentiable_arg_names: list[str] = []
args_with_derivatives_set: Set[str] = set() args_with_derivatives_set: set[str] = set()
all_arg_names = [a.name for a in cpp_arguments(f)] all_arg_names = [a.name for a in cpp_arguments(f)]
all_ret_names = [ all_ret_names = [
@ -699,7 +702,7 @@ def create_differentiability_info(
available_named_gradients, available_named_gradients,
) = set_up_derivatives(canonical) ) = set_up_derivatives(canonical)
used_named_gradients: Set[str] = set() used_named_gradients: set[str] = set()
for d in derivatives: for d in derivatives:
used_named_gradients |= d.named_gradients used_named_gradients |= d.named_gradients
@ -738,7 +741,7 @@ def create_differentiability_info(
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
def used_gradient_indices(formula: str) -> List[int]: def used_gradient_indices(formula: str) -> list[int]:
"""Determine a list of gradient indices (the i in grads[i]) that """Determine a list of gradient indices (the i in grads[i]) that
are used by the formula. are used by the formula.
@ -750,9 +753,9 @@ def used_gradient_indices(formula: str) -> List[int]:
def saved_variables( def saved_variables(
formula: str, formula: str,
nctypes: List[NamedCType], nctypes: list[NamedCType],
var_names: Tuple[str, ...], var_names: tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]: ) -> tuple[str, tuple[SavedAttribute, ...]]:
def stride_expr(name: str) -> str: def stride_expr(name: str) -> str:
assert var_names == (name,), ( assert var_names == (name,), (
'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
@ -760,7 +763,7 @@ def saved_variables(
) )
return f'strides_or_error({name}, "{name}")' return f'strides_or_error({name}, "{name}")'
REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [ REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
# replace self.sym_sizes() with self_sym_sizes # replace self.sym_sizes() with self_sym_sizes
( (
r"{}.sym_sizes\(\)", r"{}.sym_sizes\(\)",
@ -914,7 +917,7 @@ def saved_variables(
] ]
# find which arguments need to be saved # find which arguments need to be saved
saved: List[SavedAttribute] = [] saved: list[SavedAttribute] = []
if ".sizes()" in formula or "->sizes()" in formula: if ".sizes()" in formula or "->sizes()" in formula:
raise RuntimeError( raise RuntimeError(
@ -941,7 +944,7 @@ def saved_variables(
# when the autograd Function is created to avoid saving variables # when the autograd Function is created to avoid saving variables
for regex, info in REPLACEMENTS: for regex, info in REPLACEMENTS:
def repl(m: Match[str]) -> str: def repl(m: re.Match[str]) -> str:
suffix: str = ( suffix: str = (
info["suffix"](m) if callable(info["suffix"]) else info["suffix"] info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
) )
@ -999,8 +1002,8 @@ def _create_op_prefix(name: str) -> str:
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
seen: Set[str] = set() seen: set[str] = set()
saved: List[SavedAttribute] = [] saved: list[SavedAttribute] = []
for var in vars: for var in vars:
name = ( name = (
var.nctype.name.name var.nctype.name.name

View File

@ -1,17 +1,17 @@
from __future__ import annotations
import os import os
import platform import platform
import shutil import shutil
from glob import glob from glob import glob
from typing import Dict, Optional
from setuptools import distutils # type: ignore[import] from setuptools import distutils # type: ignore[import]
from .setup_helpers.cmake import CMake, USE_NINJA from .setup_helpers.cmake import CMake, USE_NINJA
from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS
def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]:
vc_arch = "x64" if IS_64BIT else "x86" vc_arch = "x64" if IS_64BIT else "x86"
if platform.machine() == "ARM64": if platform.machine() == "ARM64":
@ -34,7 +34,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
"emulation is enabled!" "emulation is enabled!"
) )
vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch) vc_env: dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
# Keys in `_get_vc_env` are always lowercase. # Keys in `_get_vc_env` are always lowercase.
# We turn them into uppercase before overlaying vcvars # We turn them into uppercase before overlaying vcvars
# because OS environ keys are always uppercase on Windows. # because OS environ keys are always uppercase on Windows.
@ -47,7 +47,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
return vc_env return vc_env
def _create_build_env() -> Dict[str, str]: def _create_build_env() -> dict[str, str]:
# XXX - our cmake file sometimes looks at the system environment # XXX - our cmake file sometimes looks at the system environment
# and not cmake flags! # and not cmake flags!
# you should NEVER add something to this list. It is bad practice to # you should NEVER add something to this list. It is bad practice to
@ -72,8 +72,8 @@ def _create_build_env() -> Dict[str, str]:
def build_caffe2( def build_caffe2(
version: Optional[str], version: str | None,
cmake_python_library: Optional[str], cmake_python_library: str | None,
build_python: bool, build_python: bool,
rerun_cmake: bool, rerun_cmake: bool,
cmake_only: bool, cmake_only: bool,

View File

@ -5,10 +5,13 @@
# - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh # - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh
# - Copy libs from build/lib to torch/lib folder # - Copy libs from build/lib to torch/lib folder
from __future__ import annotations
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Tuple from typing import Any
PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent
TORCH_DIR = PYTORCH_ROOTDIR / "torch" TORCH_DIR = PYTORCH_ROOTDIR / "torch"
@ -17,7 +20,7 @@ BUILD_DIR = PYTORCH_ROOTDIR / "build"
BUILD_LIB_DIR = BUILD_DIR / "lib" BUILD_LIB_DIR = BUILD_DIR / "lib"
def check_output(args: List[str], cwd: Optional[str] = None) -> str: def check_output(args: list[str], cwd: str | None = None) -> str:
return subprocess.check_output(args, cwd=cwd).decode("utf-8") return subprocess.check_output(args, cwd=cwd).decode("utf-8")
@ -63,7 +66,7 @@ def is_devel_setup() -> bool:
return output.strip() == str(TORCH_DIR / "__init__.py") return output.strip() == str(TORCH_DIR / "__init__.py")
def create_build_plan() -> List[Tuple[str, str]]: def create_build_plan() -> list[tuple[str, str]]:
output = check_output( output = check_output(
["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR) ["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR)
) )

View File

@ -8,13 +8,15 @@ For custom build with static dispatch, the op dependency graph will be omitted,
and it will directly output root ops as the allowlist. and it will directly output root ops as the allowlist.
""" """
import argparse from __future__ import annotations
import argparse
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Set from typing import Dict, Set
import yaml import yaml
DepGraph = Dict[str, Set[str]] DepGraph = Dict[str, Set[str]]
@ -34,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
return dict(result) return dict(result)
def load_root_ops(fname: str) -> List[str]: def load_root_ops(fname: str) -> list[str]:
result = [] result = []
with open(fname) as stream: with open(fname) as stream:
for op in yaml.safe_load(stream): for op in yaml.safe_load(stream):
@ -44,9 +46,9 @@ def load_root_ops(fname: str) -> List[str]:
def gen_transitive_closure( def gen_transitive_closure(
dep_graph: DepGraph, dep_graph: DepGraph,
root_ops: List[str], root_ops: list[str],
train: bool = False, train: bool = False,
) -> List[str]: ) -> list[str]:
result = set(root_ops) result = set(root_ops)
queue = root_ops.copy() queue = root_ops.copy()
@ -73,7 +75,7 @@ def gen_transitive_closure(
return sorted(result) return sorted(result)
def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str: def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: list[str]) -> str:
return " ".join(gen_transitive_closure(dep_graph, root_ops)) return " ".join(gen_transitive_closure(dep_graph, root_ops))

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import json import json
import sys import sys
from typing import Any, Dict, List, Optional from typing import Any
import yaml import yaml
from gen_op_registration_allowlist import ( from gen_op_registration_allowlist import (
@ -17,6 +20,7 @@ from torchgen.selective_build.operator import (
) )
from torchgen.selective_build.selector import merge_kernel_metadata from torchgen.selective_build.selector import merge_kernel_metadata
# Generate YAML file containing the operators used for a specific PyTorch model. # Generate YAML file containing the operators used for a specific PyTorch model.
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# #
@ -84,17 +88,17 @@ from torchgen.selective_build.selector import merge_kernel_metadata
# #
def canonical_opnames(opnames: List[str]) -> List[str]: def canonical_opnames(opnames: list[str]) -> list[str]:
return [canonical_name(opname) for opname in opnames] return [canonical_name(opname) for opname in opnames]
def make_filter_from_options( def make_filter_from_options(
model_name: str, model_name: str,
model_versions: List[str], model_versions: list[str],
model_assets: Optional[List[str]], model_assets: list[str] | None,
model_backends: Optional[List[str]], model_backends: list[str] | None,
): ):
def is_model_included(model_info): def is_model_included(model_info) -> bool:
model = model_info["model"] model = model_info["model"]
if model["name"] != model_name: if model["name"] != model_name:
return False return False
@ -109,7 +113,7 @@ def make_filter_from_options(
# Returns if a the specified rule is a new or old style pt_operator_library # Returns if a the specified rule is a new or old style pt_operator_library
def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]): def is_new_style_rule(model_name: str, model_versions: list[str] | None):
return model_name is not None and model_versions is not None return model_name is not None and model_versions is not None
@ -117,13 +121,13 @@ def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
# appear in at least one model yaml. Throws if verification is failed, # appear in at least one model yaml. Throws if verification is failed,
# returns None on success # returns None on success
def verify_all_specified_present( def verify_all_specified_present(
model_assets: Optional[List[str]], model_assets: list[str] | None,
model_versions: List[str], model_versions: list[str],
selected_models_yaml: List[Dict[str, Any]], selected_models_yaml: list[dict[str, Any]],
rule_name: str, rule_name: str,
model_name: str, model_name: str,
new_style_rule: bool, new_style_rule: bool,
): ) -> None:
def find_missing_items(model_items, key, selected_models_yaml): def find_missing_items(model_items, key, selected_models_yaml):
missing_items = [] missing_items = []
if not new_style_rule or not model_items: if not new_style_rule or not model_items:
@ -179,10 +183,10 @@ def verify_all_specified_present(
# Uses the selected models configs and then combines them into one dictionary, # Uses the selected models configs and then combines them into one dictionary,
# formats them as a string, and places the string into output as a top level debug_info # formats them as a string, and places the string into output as a top level debug_info
def create_debug_info_from_selected_models( def create_debug_info_from_selected_models(
output: Dict[str, object], output: dict[str, object],
selected_models: List[dict], selected_models: list[dict],
new_style_rule: bool, new_style_rule: bool,
): ) -> None:
model_dict = { model_dict = {
"asset_info": {}, # maps asset name -> dict of asset metadata like hashes "asset_info": {}, # maps asset name -> dict of asset metadata like hashes
"is_new_style_rule": new_style_rule, "is_new_style_rule": new_style_rule,
@ -201,7 +205,7 @@ def create_debug_info_from_selected_models(
output["debug_info"] = [json.dumps(model_dict)] output["debug_info"] = [json.dumps(model_dict)]
def fill_output(output: Dict[str, object], options: object): def fill_output(output: dict[str, object], options: object) -> None:
"""Populate the output dict with the information required to serialize """Populate the output dict with the information required to serialize
the YAML file used for selective build. the YAML file used for selective build.
""" """
@ -458,7 +462,7 @@ def fill_output(output: Dict[str, object], options: object):
# END TRACING BASED BUILD OPS # END TRACING BASED BUILD OPS
# Merge dictionaries together to remove op duplication # Merge dictionaries together to remove op duplication
operators: Dict[str, SelectiveBuildOperator] = {} operators: dict[str, SelectiveBuildOperator] = {}
for ops_dict in bucketed_ops: for ops_dict in bucketed_ops:
operators = merge_operator_dicts(operators, ops_dict) operators = merge_operator_dicts(operators, ops_dict)

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import json import json
import os import os
import sys import sys
from functools import reduce from functools import reduce
from typing import Any, List, Set from typing import Any
import yaml import yaml
from tools.lite_interpreter.gen_selected_mobile_ops_header import ( from tools.lite_interpreter.gen_selected_mobile_ops_header import (
@ -17,11 +20,11 @@ from torchgen.selective_build.selector import (
) )
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]: def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
return set(selective_builder.operators.keys()) return set(selective_builder.operators.keys())
def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]: def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
ops = [] ops = []
for op_name, op in selective_builder.operators.items(): for op_name, op in selective_builder.operators.items():
if op.is_used_for_training: if op.is_used_for_training:
@ -44,7 +47,7 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
) )
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None: def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
supported_mobile_models_source = """/* supported_mobile_models_source = """/*
* Generated by gen_oplist.py * Generated by gen_oplist.py
*/ */
@ -87,7 +90,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
out_file.write(source.encode("utf-8")) out_file.write(source.encode("utf-8"))
def main(argv: List[Any]) -> None: def main(argv: list[Any]) -> None:
"""This binary generates 3 files: """This binary generates 3 files:
1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import argparse import argparse
import os import os
from typing import cast, List, Optional, Tuple from typing import cast
from ..util.setting import ( from ..util.setting import (
CompilerType, CompilerType,
@ -38,7 +40,7 @@ BLOCKED_PYTHON_TESTS = {
} }
def initialization() -> Tuple[Option, TestList, List[str]]: def initialization() -> tuple[Option, TestList, list[str]]:
# create folder if not exists # create folder if not exists
create_folders() create_folders()
# add arguments # add arguments
@ -77,7 +79,7 @@ def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
def parse_arguments( def parse_arguments(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser,
) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]: ) -> tuple[Option, list[str] | None, list[str] | None, bool | None]:
# parse args # parse args
args = parser.parse_args() args = parser.parse_args()
# get option # get option
@ -85,9 +87,7 @@ def parse_arguments(
return (options, args.interest_only, args.run_only, args.clean) return (options, args.interest_only, args.run_only, args.clean)
def get_test_list_by_type( def get_test_list_by_type(run_only: list[str] | None, test_type: TestType) -> TestList:
run_only: Optional[List[str]], test_type: TestType
) -> TestList:
test_list: TestList = [] test_list: TestList = []
binary_folder = get_oss_binary_folder(test_type) binary_folder = get_oss_binary_folder(test_type)
g = os.walk(binary_folder) g = os.walk(binary_folder)
@ -106,7 +106,7 @@ def get_test_list_by_type(
return test_list return test_list
def get_test_list(run_only: Optional[List[str]]) -> TestList: def get_test_list(run_only: list[str] | None) -> TestList:
test_list: TestList = [] test_list: TestList = []
# add c++ test list # add c++ test list
test_list.extend(get_test_list_by_type(run_only, TestType.CPP)) test_list.extend(get_test_list_by_type(run_only, TestType.CPP))
@ -122,7 +122,7 @@ def get_test_list(run_only: Optional[List[str]]) -> TestList:
return test_list return test_list
def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]: def empty_list_if_none(arg_interested_folder: list[str] | None) -> list[str]:
if arg_interested_folder is None: if arg_interested_folder is None:
return [] return []
# if this argument is specified, just return itself # if this argument is specified, just return itself
@ -134,7 +134,7 @@ def gcc_export_init() -> None:
create_folder(JSON_FOLDER_BASE_DIR) create_folder(JSON_FOLDER_BASE_DIR)
def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]: def get_python_run_only(args_run_only: list[str] | None) -> list[str]:
# if user specifies run-only option # if user specifies run-only option
if args_run_only: if args_run_only:
return args_run_only return args_run_only
@ -144,7 +144,7 @@ def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
return ["run_test.py"] return ["run_test.py"]
else: else:
# for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them # for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them
run_only: List[str] = [] run_only: list[str] = []
binary_folder = get_oss_binary_folder(TestType.PY) binary_folder = get_oss_binary_folder(TestType.PY)
g = os.walk(binary_folder) g = os.walk(binary_folder)
for _, _, file_list in g: for _, _, file_list in g:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import os import os
import subprocess import subprocess
from typing import List, Optional
from ..util.setting import CompilerType, TestType, TOOLS_FOLDER from ..util.setting import CompilerType, TestType, TOOLS_FOLDER
from ..util.utils import print_error, remove_file from ..util.utils import print_error, remove_file
@ -14,7 +15,7 @@ def get_oss_binary_folder(test_type: TestType) -> str:
) )
def get_oss_shared_library() -> List[str]: def get_oss_shared_library() -> list[str]:
lib_dir = os.path.join(get_pytorch_folder(), "build", "lib") lib_dir = os.path.join(get_pytorch_folder(), "build", "lib")
return [ return [
os.path.join(lib_dir, lib) os.path.join(lib_dir, lib)
@ -48,7 +49,7 @@ def get_pytorch_folder() -> str:
) )
def detect_compiler_type() -> Optional[CompilerType]: def detect_compiler_type() -> CompilerType | None:
# check if user specifies the compiler type # check if user specifies the compiler type
user_specify = os.environ.get("CXX", None) user_specify = os.environ.get("CXX", None)
if user_specify: if user_specify:
@ -76,7 +77,7 @@ def clean_up_gcda() -> None:
remove_file(item) remove_file(item)
def get_gcda_files() -> List[str]: def get_gcda_files() -> list[str]:
folder_has_gcda = os.path.join(get_pytorch_folder(), "build") folder_has_gcda = os.path.join(get_pytorch_folder(), "build")
if os.path.isdir(folder_has_gcda): if os.path.isdir(folder_has_gcda):
# TODO use glob # TODO use glob

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import os import os
import subprocess import subprocess
import time import time
from typing import List
from ..util.setting import ( from ..util.setting import (
JSON_FOLDER_BASE_DIR, JSON_FOLDER_BASE_DIR,
@ -25,7 +26,7 @@ from .utils import get_tool_path_by_platform, run_cpp_test
def create_corresponding_folder( def create_corresponding_folder(
cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str cur_path: str, prefix_cur_path: str, dir_list: list[str], new_base_folder: str
) -> None: ) -> None:
for dir_name in dir_list: for dir_name in dir_list:
relative_path = convert_to_relative_path( relative_path = convert_to_relative_path(
@ -70,7 +71,7 @@ def export_target(
merged_file: str, merged_file: str,
json_file: str, json_file: str,
binary_file: str, binary_file: str,
shared_library_list: List[str], shared_library_list: list[str],
platform_type: TestPlatform, platform_type: TestPlatform,
) -> None: ) -> None:
if binary_file is None: if binary_file is None:

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import os import os
import subprocess import subprocess
import time import time
from typing import Dict
# gcc is only used in oss # gcc is only used in oss
from ..oss.utils import get_gcda_files, run_oss_python_test from ..oss.utils import get_gcda_files, run_oss_python_test
@ -10,7 +11,7 @@ from ..util.utils import print_log, print_time
from .utils import run_cpp_test from .utils import run_cpp_test
def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str: def update_gzip_dict(gzip_dict: dict[str, int], file_name: str) -> str:
file_name = file_name.lower() file_name = file_name.lower()
gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1 gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1
num = gzip_dict[file_name] num = gzip_dict[file_name]
@ -34,7 +35,7 @@ def export() -> None:
# collect .gcda files # collect .gcda files
gcda_files = get_gcda_files() gcda_files = get_gcda_files()
# file name like utils.cpp may have same name in different folder # file name like utils.cpp may have same name in different folder
gzip_dict: Dict[str, int] = {} gzip_dict: dict[str, int] = {}
for gcda_item in gcda_files: for gcda_item in gcda_files:
# generate json.gz # generate json.gz
subprocess.check_call(["gcov", "-i", gcda_item]) subprocess.check_call(["gcov", "-i", gcda_item])

View File

@ -1,12 +1,14 @@
import typing as t from __future__ import annotations
from typing import Any, NamedTuple
class CoverageRecord(t.NamedTuple): class CoverageRecord(NamedTuple):
filepath: str filepath: str
covered_lines: t.List[int] covered_lines: list[int]
uncovered_lines: t.Optional[t.List[int]] = None uncovered_lines: list[int] | None = None
def to_dict(self) -> t.Dict[str, t.Any]: def to_dict(self) -> dict[str, Any]:
return { return {
"filepath": self.filepath, "filepath": self.filepath,
"covered_lines": self.covered_lines, "covered_lines": self.covered_lines,

View File

@ -1,4 +1,6 @@
from typing import Any, Dict, List, Set from __future__ import annotations
from typing import Any
from .coverage_record import CoverageRecord from .coverage_record import CoverageRecord
@ -10,7 +12,7 @@ class GcovCoverageParser:
of CoverageRecord(s). of CoverageRecord(s).
""" """
def __init__(self, llvm_coverage: Dict[str, Any]) -> None: def __init__(self, llvm_coverage: dict[str, Any]) -> None:
self._llvm_coverage = llvm_coverage self._llvm_coverage = llvm_coverage
@staticmethod @staticmethod
@ -24,17 +26,17 @@ class GcovCoverageParser:
return True return True
return False return False
def parse(self) -> List[CoverageRecord]: def parse(self) -> list[CoverageRecord]:
# The JSON format is described in the gcov source code # The JSON format is described in the gcov source code
# https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html # https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html
records: List[CoverageRecord] = [] records: list[CoverageRecord] = []
for file_info in self._llvm_coverage["files"]: for file_info in self._llvm_coverage["files"]:
filepath = file_info["file"] filepath = file_info["file"]
if self._skip_coverage(filepath): if self._skip_coverage(filepath):
continue continue
# parse json file # parse json file
covered_lines: Set[int] = set() covered_lines: set[int] = set()
uncovered_lines: Set[int] = set() uncovered_lines: set[int] = set()
for line in file_info["lines"]: for line in file_info["lines"]:
line_number = line["line_number"] line_number = line["line_number"]
count = line["count"] count = line["count"]

View File

@ -1,4 +1,6 @@
from typing import Any, Dict, List, Set, Tuple from __future__ import annotations
from typing import Any
from .coverage_record import CoverageRecord from .coverage_record import CoverageRecord
from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments
@ -12,7 +14,7 @@ class LlvmCoverageParser:
""" """
def __init__(self, llvm_coverage: Dict[str, Any]) -> None: def __init__(self, llvm_coverage: dict[str, Any]) -> None:
self._llvm_coverage = llvm_coverage self._llvm_coverage = llvm_coverage
@staticmethod @staticmethod
@ -28,13 +30,13 @@ class LlvmCoverageParser:
@staticmethod @staticmethod
def _collect_coverage( def _collect_coverage(
segments: List[LlvmCoverageSegment], segments: list[LlvmCoverageSegment],
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
""" """
Stateful parsing of coverage segments. Stateful parsing of coverage segments.
""" """
covered_lines: Set[int] = set() covered_lines: set[int] = set()
uncovered_lines: Set[int] = set() uncovered_lines: set[int] = set()
prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None) prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None)
for segment in segments: for segment in segments:
covered_range, uncovered_range = segment.get_coverage(prev_segment) covered_range, uncovered_range = segment.get_coverage(prev_segment)
@ -45,10 +47,10 @@ class LlvmCoverageParser:
uncovered_lines.difference_update(covered_lines) uncovered_lines.difference_update(covered_lines)
return sorted(covered_lines), sorted(uncovered_lines) return sorted(covered_lines), sorted(uncovered_lines)
def parse(self, repo_name: str) -> List[CoverageRecord]: def parse(self, repo_name: str) -> list[CoverageRecord]:
# The JSON format is described in the LLVM source code # The JSON format is described in the LLVM source code
# https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp # https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp
records: List[CoverageRecord] = [] records: list[CoverageRecord] = []
for export_unit in self._llvm_coverage["data"]: for export_unit in self._llvm_coverage["data"]:
for file_info in export_unit["files"]: for file_info in export_unit["files"]:
filepath = file_info["filename"] filepath = file_info["filename"]

View File

@ -1,4 +1,6 @@
from typing import List, NamedTuple, Optional, Tuple from __future__ import annotations
from typing import NamedTuple
class LlvmCoverageSegment(NamedTuple): class LlvmCoverageSegment(NamedTuple):
@ -7,7 +9,7 @@ class LlvmCoverageSegment(NamedTuple):
segment_count: int segment_count: int
has_count: int has_count: int
is_region_entry: int is_region_entry: int
is_gap_entry: Optional[int] is_gap_entry: int | None
@property @property
def has_coverage(self) -> bool: def has_coverage(self) -> bool:
@ -18,8 +20,8 @@ class LlvmCoverageSegment(NamedTuple):
return self.has_count > 0 return self.has_count > 0
def get_coverage( def get_coverage(
self, prev_segment: "LlvmCoverageSegment" self, prev_segment: LlvmCoverageSegment
) -> Tuple[List[int], List[int]]: ) -> tuple[list[int], list[int]]:
# Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py # Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py
if not prev_segment.is_executable: if not prev_segment.is_executable:
return [], [] return [], []
@ -32,12 +34,12 @@ class LlvmCoverageSegment(NamedTuple):
return (lines_range, []) if prev_segment.has_coverage else ([], lines_range) return (lines_range, []) if prev_segment.has_coverage else ([], lines_range)
def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]: def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]:
""" """
Creates LlvmCoverageSegment from a list of lists in llvm export json. Creates LlvmCoverageSegment from a list of lists in llvm export json.
each segment is represented by 5-element array. each segment is represented by 5-element array.
""" """
ret: List[LlvmCoverageSegment] = [] ret: list[LlvmCoverageSegment] = []
for raw_segment in raw_segments: for raw_segment in raw_segments:
assert ( assert (
len(raw_segment) == 5 or len(raw_segment) == 6 len(raw_segment) == 5 or len(raw_segment) == 6

View File

@ -1,10 +1,13 @@
from __future__ import annotations
import os import os
import subprocess import subprocess
from typing import Dict, IO, List, Set, Tuple from typing import IO, Tuple
from ..oss.utils import get_pytorch_folder from ..oss.utils import get_pytorch_folder
from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType
CoverageItem = Tuple[str, float, int, int] CoverageItem = Tuple[str, float, int, int]
@ -16,7 +19,7 @@ def key_by_name(x: CoverageItem) -> str:
return x[0] return x[0]
def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool: def is_intrested_file(file_path: str, interested_folders: list[str]) -> bool:
if "cuda" in file_path: if "cuda" in file_path:
return False return False
if "aten/gen_aten" in file_path or "aten/aten_" in file_path: if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
@ -27,7 +30,7 @@ def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
return False return False
def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool: def is_this_type_of_tests(target_name: str, test_set_by_type: set[str]) -> bool:
# tests are divided into three types: success / partial success / fail to collect coverage # tests are divided into three types: success / partial success / fail to collect coverage
for test in test_set_by_type: for test in test_set_by_type:
if target_name in test: if target_name in test:
@ -36,7 +39,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def print_test_by_type( def print_test_by_type(
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str] tests: TestList, test_set_by_type: set[str], type_name: str, summary_file: IO[str]
) -> None: ) -> None:
print("Tests " + type_name + " to collect coverage:", file=summary_file) print("Tests " + type_name + " to collect coverage:", file=summary_file)
for test in tests: for test in tests:
@ -48,8 +51,8 @@ def print_test_by_type(
def print_test_condition( def print_test_condition(
tests: TestList, tests: TestList,
tests_type: TestStatusType, tests_type: TestStatusType,
interested_folders: List[str], interested_folders: list[str],
coverage_only: List[str], coverage_only: list[str],
summary_file: IO[str], summary_file: IO[str],
summary_type: str, summary_type: str,
) -> None: ) -> None:
@ -77,10 +80,10 @@ def print_test_condition(
def line_oriented_report( def line_oriented_report(
tests: TestList, tests: TestList,
tests_type: TestStatusType, tests_type: TestStatusType,
interested_folders: List[str], interested_folders: list[str],
coverage_only: List[str], coverage_only: list[str],
covered_lines: Dict[str, Set[int]], covered_lines: dict[str, set[int]],
uncovered_lines: Dict[str, Set[int]], uncovered_lines: dict[str, set[int]],
) -> None: ) -> None:
with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file: with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file:
print_test_condition( print_test_condition(
@ -119,13 +122,13 @@ def print_file_summary(
def print_file_oriented_report( def print_file_oriented_report(
tests_type: TestStatusType, tests_type: TestStatusType,
coverage: List[CoverageItem], coverage: list[CoverageItem],
covered_summary: int, covered_summary: int,
total_summary: int, total_summary: int,
summary_file: IO[str], summary_file: IO[str],
tests: TestList, tests: TestList,
interested_folders: List[str], interested_folders: list[str],
coverage_only: List[str], coverage_only: list[str],
) -> None: ) -> None:
coverage_percentage = print_file_summary( coverage_percentage = print_file_summary(
covered_summary, total_summary, summary_file covered_summary, total_summary, summary_file
@ -155,10 +158,10 @@ def print_file_oriented_report(
def file_oriented_report( def file_oriented_report(
tests: TestList, tests: TestList,
tests_type: TestStatusType, tests_type: TestStatusType,
interested_folders: List[str], interested_folders: list[str],
coverage_only: List[str], coverage_only: list[str],
covered_lines: Dict[str, Set[int]], covered_lines: dict[str, set[int]],
uncovered_lines: Dict[str, Set[int]], uncovered_lines: dict[str, set[int]],
) -> None: ) -> None:
with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file: with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file:
covered_summary = 0 covered_summary = 0
@ -193,7 +196,7 @@ def file_oriented_report(
) )
def get_html_ignored_pattern() -> List[str]: def get_html_ignored_pattern() -> list[str]:
return ["/usr/*", "*anaconda3/*", "*third_party/*"] return ["/usr/*", "*anaconda3/*", "*third_party/*"]

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import json import json
import os import os
import time import time
from typing import Any, Dict, List, Set, Tuple from typing import Any, TYPE_CHECKING
from ..util.setting import ( from ..util.setting import (
CompilerType, CompilerType,
@ -16,7 +18,6 @@ from ..util.utils import (
print_time, print_time,
related_to_test_list, related_to_test_list,
) )
from .parser.coverage_record import CoverageRecord
from .parser.gcov_coverage_parser import GcovCoverageParser from .parser.gcov_coverage_parser import GcovCoverageParser
from .parser.llvm_coverage_parser import LlvmCoverageParser from .parser.llvm_coverage_parser import LlvmCoverageParser
from .print_report import ( from .print_report import (
@ -26,16 +27,20 @@ from .print_report import (
) )
if TYPE_CHECKING:
from .parser.coverage_record import CoverageRecord
# coverage_records: Dict[str, LineInfo] = {} # coverage_records: Dict[str, LineInfo] = {}
covered_lines: Dict[str, Set[int]] = {} covered_lines: dict[str, set[int]] = {}
uncovered_lines: Dict[str, Set[int]] = {} uncovered_lines: dict[str, set[int]] = {}
tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()} tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}
def transform_file_name( def transform_file_name(
file_path: str, interested_folders: List[str], platform: TestPlatform file_path: str, interested_folders: list[str], platform: TestPlatform
) -> str: ) -> str:
remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"} remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
for pattern in remove_patterns: for pattern in remove_patterns:
file_path = file_path.replace(pattern, "") file_path = file_path.replace(pattern, "")
# if user has specified interested folder # if user has specified interested folder
@ -54,7 +59,7 @@ def transform_file_name(
def is_intrested_file( def is_intrested_file(
file_path: str, interested_folders: List[str], platform: TestPlatform file_path: str, interested_folders: list[str], platform: TestPlatform
) -> bool: ) -> bool:
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"] ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
if any(pattern in file_path for pattern in ignored_patterns): if any(pattern in file_path for pattern in ignored_patterns):
@ -77,7 +82,7 @@ def is_intrested_file(
return True return True
def get_json_obj(json_file: str) -> Tuple[Any, int]: def get_json_obj(json_file: str) -> tuple[Any, int]:
""" """
Sometimes at the start of file llvm/gcov will complains "fail to find coverage data", Sometimes at the start of file llvm/gcov will complains "fail to find coverage data",
then we need to skip these lines then we need to skip these lines
@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]:
return None, 2 return None, 2
def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]: def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]:
print("start parse:", json_file) print("start parse:", json_file)
json_obj, read_status = get_json_obj(json_file) json_obj, read_status = get_json_obj(json_file)
if read_status == 0: if read_status == 0:
@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
cov_type = detect_compiler_type(platform) cov_type = detect_compiler_type(platform)
coverage_records: List[CoverageRecord] = [] coverage_records: list[CoverageRecord] = []
if cov_type == CompilerType.CLANG: if cov_type == CompilerType.CLANG:
coverage_records = LlvmCoverageParser(json_obj).parse("fbcode") coverage_records = LlvmCoverageParser(json_obj).parse("fbcode")
# print(coverage_records) # print(coverage_records)
@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
def parse_jsons( def parse_jsons(
test_list: TestList, interested_folders: List[str], platform: TestPlatform test_list: TestList, interested_folders: list[str], platform: TestPlatform
) -> None: ) -> None:
g = os.walk(JSON_FOLDER_BASE_DIR) g = os.walk(JSON_FOLDER_BASE_DIR)
@ -152,8 +157,8 @@ def parse_jsons(
def update_coverage( def update_coverage(
coverage_records: List[CoverageRecord], coverage_records: list[CoverageRecord],
interested_folders: List[str], interested_folders: list[str],
platform: TestPlatform, platform: TestPlatform,
) -> None: ) -> None:
for item in coverage_records: for item in coverage_records:
@ -187,8 +192,8 @@ def update_set() -> None:
def summarize_jsons( def summarize_jsons(
test_list: TestList, test_list: TestList,
interested_folders: List[str], interested_folders: list[str],
coverage_only: List[str], coverage_only: list[str],
platform: TestPlatform, platform: TestPlatform,
) -> None: ) -> None:
start_time = time.time() start_time = time.time()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
from enum import Enum from enum import Enum
from typing import Dict, List, Set from typing import Dict, List, Set

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import os import os
import shutil import shutil
import sys import sys
import time import time
from typing import Any, NoReturn, Optional from typing import Any, NoReturn
from .setting import ( from .setting import (
CompilerType, CompilerType,
@ -113,7 +115,7 @@ def get_test_name_from_whole_path(path: str) -> str:
return path[start + 1 : end] return path[start + 1 : end]
def check_compiler_type(cov_type: Optional[CompilerType]) -> None: def check_compiler_type(cov_type: CompilerType | None) -> None:
if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]: if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]:
return return
raise Exception( # noqa: TRY002 raise Exception( # noqa: TRY002

View File

@ -1,5 +1,6 @@
import setuptools # type: ignore[import] import setuptools # type: ignore[import]
with open("README.md", encoding="utf-8") as fh: with open("README.md", encoding="utf-8") as fh:
long_description = fh.read() long_description = fh.read()

View File

@ -22,6 +22,7 @@ from typing import Any
from coverage import CoverageData, CoveragePlugin # type: ignore[import] from coverage import CoverageData, CoveragePlugin # type: ignore[import]
# All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
# `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link: # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
# https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine # https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine

View File

@ -5,6 +5,7 @@ import sys
from urllib.error import URLError from urllib.error import URLError
from urllib.request import urlretrieve from urllib.request import urlretrieve
MIRRORS = [ MIRRORS = [
"http://yann.lecun.com/exdb/mnist/", "http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/",

View File

@ -5,6 +5,7 @@ import sys
import traceback import traceback
import warnings import warnings
MIN_CUDA_VERSION = "11.6" MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4" MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8) MIN_PYTHON_VERSION = (3, 8)
@ -141,7 +142,7 @@ def check_rocm():
return rocm_ver if torch.version.hip else "None" return rocm_ver if torch.version.hip else "None"
def check_dynamo(backend, device, err_msg): def check_dynamo(backend, device, err_msg) -> None:
import torch import torch
if device == "cuda" and not torch.cuda.is_available(): if device == "cuda" and not torch.cuda.is_available():
@ -203,7 +204,7 @@ _SANITY_CHECK_ARGS = (
) )
def main(): def main() -> None:
python_ver = check_python() python_ver = check_python()
torch_ver = check_torch() torch_ver = check_torch()
cuda_ver = check_cuda() cuda_ver = check_cuda()

View File

@ -1,14 +1,17 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict
from typing_extensions import TypedDict # Python 3.11+ from typing_extensions import TypedDict # Python 3.11+
import yaml import yaml
Step = Dict[str, Any] Step = Dict[str, Any]
@ -17,7 +20,7 @@ class Script(TypedDict):
script: str script: str
def extract(step: Step) -> Optional[Script]: def extract(step: Step) -> Script | None:
run = step.get("run") run = step.get("run")
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import array import array
import codecs import codecs
@ -15,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess import subprocess
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any
import yaml import yaml
from yaml.constructor import ConstructorError from yaml.constructor import ConstructorError
@ -29,14 +31,14 @@ except ImportError:
CPP_H_NAME = "spv.h" CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp" CPP_SRC_NAME = "spv.cpp"
DEFAULT_ENV: Dict[str, Any] = { DEFAULT_ENV: dict[str, Any] = {
"PRECISION": "highp", "PRECISION": "highp",
"FLOAT_IMAGE_FORMAT": "rgba16f", "FLOAT_IMAGE_FORMAT": "rgba16f",
"INT_IMAGE_FORMAT": "rgba32i", "INT_IMAGE_FORMAT": "rgba32i",
"UINT_IMAGE_FORMAT": "rgba32ui", "UINT_IMAGE_FORMAT": "rgba32ui",
} }
TYPES_ENV: Dict[str, Any] = { TYPES_ENV: dict[str, Any] = {
"IMAGE_FORMAT": { "IMAGE_FORMAT": {
"float": "rgba32f", "float": "rgba32f",
"half": "rgba16f", "half": "rgba16f",
@ -91,7 +93,7 @@ TYPES_ENV: Dict[str, Any] = {
}, },
} }
FUNCS_ENV: Dict[str, Any] = { FUNCS_ENV: dict[str, Any] = {
"GET_POS": { "GET_POS": {
3: lambda pos: pos, 3: lambda pos: pos,
2: lambda pos: f"{pos}.xy", 2: lambda pos: f"{pos}.xy",
@ -169,7 +171,7 @@ def escape(line: str) -> str:
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py # https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess( def preprocess(
input_text: str, variables: Dict[str, Any], input_path: str = "codegen" input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str: ) -> str:
input_lines = input_text.splitlines() input_lines = input_text.splitlines()
python_lines = [] python_lines = []
@ -243,9 +245,9 @@ def preprocess(
class SPVGenerator: class SPVGenerator:
def __init__( def __init__(
self, self,
src_dir_paths: Union[str, List[str]], src_dir_paths: str | list[str],
env: Dict[Any, Any], env: dict[Any, Any],
glslc_path: Optional[str], glslc_path: str | None,
) -> None: ) -> None:
if isinstance(src_dir_paths, str): if isinstance(src_dir_paths, str):
self.src_dir_paths = [src_dir_paths] self.src_dir_paths = [src_dir_paths]
@ -255,18 +257,18 @@ class SPVGenerator:
self.env = env self.env = env
self.glslc_path = glslc_path self.glslc_path = glslc_path
self.glsl_src_files: Dict[str, str] = {} self.glsl_src_files: dict[str, str] = {}
self.template_yaml_files: List[str] = [] self.template_yaml_files: list[str] = []
self.addSrcAndYamlFiles(self.src_dir_paths) self.addSrcAndYamlFiles(self.src_dir_paths)
self.shader_template_params: Dict[Any, Any] = {} self.shader_template_params: dict[Any, Any] = {}
for yaml_file in self.template_yaml_files: for yaml_file in self.template_yaml_files:
self.parseTemplateYaml(yaml_file) self.parseTemplateYaml(yaml_file)
self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {} self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
self.constructOutputMap() self.constructOutputMap()
def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None: def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
for src_path in src_dir_paths: for src_path in src_dir_paths:
# Collect glsl source files # Collect glsl source files
glsl_files = glob.glob( glsl_files = glob.glob(
@ -285,9 +287,9 @@ class SPVGenerator:
def generateVariantCombinations( def generateVariantCombinations(
self, self,
iterated_params: Dict[str, Any], iterated_params: dict[str, Any],
exclude_params: Optional[Set[str]] = None, exclude_params: set[str] | None = None,
) -> List[Any]: ) -> list[Any]:
if exclude_params is None: if exclude_params is None:
exclude_params = set() exclude_params = set()
all_iterated_params = [] all_iterated_params = []
@ -362,8 +364,8 @@ class SPVGenerator:
) )
def create_shader_params( def create_shader_params(
self, variant_params: Optional[Dict[str, Any]] = None self, variant_params: dict[str, Any] | None = None
) -> Dict[str, str]: ) -> dict[str, str]:
if variant_params is None: if variant_params is None:
variant_params = {} variant_params = {}
shader_params = copy.deepcopy(self.env) shader_params = copy.deepcopy(self.env)
@ -409,7 +411,7 @@ class SPVGenerator:
self.create_shader_params(), self.create_shader_params(),
) )
def generateSPV(self, output_dir: str) -> Dict[str, str]: def generateSPV(self, output_dir: str) -> dict[str, str]:
output_file_map = {} output_file_map = {}
for shader_name in self.output_shader_map: for shader_name in self.output_shader_map:
source_glsl = self.output_shader_map[shader_name][0] source_glsl = self.output_shader_map[shader_name][0]
@ -457,11 +459,11 @@ class SPVGenerator:
@dataclass @dataclass
class ShaderInfo: class ShaderInfo:
tile_size: List[int] tile_size: list[int]
layouts: List[str] layouts: list[str]
weight_storage_type: str = "" weight_storage_type: str = ""
bias_storage_type: str = "" bias_storage_type: str = ""
register_for: Optional[Tuple[str, List[str]]] = None register_for: tuple[str, list[str]] | None = None
def getName(filePath: str) -> str: def getName(filePath: str) -> str:
@ -478,7 +480,7 @@ def isTileSizeLine(lineStr: str) -> bool:
return re.search(tile_size_id, lineStr) is not None return re.search(tile_size_id, lineStr) is not None
def findTileSizes(lineStr: str) -> List[int]: def findTileSizes(lineStr: str) -> list[int]:
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
matches = re.search(tile_size_id, lineStr) matches = re.search(tile_size_id, lineStr)
if matches is None: if matches is None:
@ -520,7 +522,7 @@ def isRegisterForLine(lineStr: str) -> bool:
return re.search(register_for_id, lineStr) is not None return re.search(register_for_id, lineStr) is not None
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
register_for_pattern = r"'([A-Za-z0-9_]+)'" register_for_pattern = r"'([A-Za-z0-9_]+)'"
matches = re.findall(register_for_pattern, lineStr) matches = re.findall(register_for_pattern, lineStr)
if matches is None: if matches is None:
@ -609,7 +611,7 @@ static const api::ShaderRegisterInit register_shaders(&register_fn);
""" """
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]: def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
with open(spvPath, "rb") as fr: with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read()) next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin) sizeBytes = 4 * len(next_bin)
@ -665,7 +667,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
def genCppFiles( def genCppFiles(
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None: ) -> None:
spv_bin_strs = [] spv_bin_strs = []
register_shader_info_strs = [] register_shader_info_strs = []
@ -705,7 +707,7 @@ def genCppFiles(
########## ##########
def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]: def parse_arg_env(items: dict[Any, Any]) -> dict[Any, Any]:
d = {} d = {}
if items: if items:
for item in items: for item in items:
@ -716,7 +718,7 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
return d return d
def main(argv: List[str]) -> int: def main(argv: list[str]) -> int:
parser = argparse.ArgumentParser(description="") parser = argparse.ArgumentParser(description="")
parser.add_argument( parser.add_argument(
"-i", "-i",

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import argparse import argparse
import os import os
import re import re
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Optional, Union
from setuptools import distutils # type: ignore[import] from setuptools import distutils # type: ignore[import]
@ -12,7 +13,7 @@ UNKNOWN = "Unknown"
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/") RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")
def get_sha(pytorch_root: Union[str, Path]) -> str: def get_sha(pytorch_root: str | Path) -> str:
try: try:
rev = None rev = None
if os.path.exists(os.path.join(pytorch_root, ".git")): if os.path.exists(os.path.join(pytorch_root, ".git")):
@ -30,7 +31,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN return UNKNOWN
def get_tag(pytorch_root: Union[str, Path]) -> str: def get_tag(pytorch_root: str | Path) -> str:
try: try:
tag = subprocess.run( tag = subprocess.run(
["git", "describe", "--tags", "--exact"], ["git", "describe", "--tags", "--exact"],
@ -46,8 +47,8 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN return UNKNOWN
def get_torch_version(sha: Optional[str] = None) -> str: def get_torch_version(sha: str | None = None) -> str:
pytorch_root = Path(__file__).parent.parent pytorch_root = Path(__file__).absolute().parent.parent
version = open(pytorch_root / "version.txt").read().strip() version = open(pytorch_root / "version.txt").read().strip()
if os.getenv("PYTORCH_BUILD_VERSION"): if os.getenv("PYTORCH_BUILD_VERSION"):

View File

@ -1,10 +1,10 @@
"""GitHub Utilities""" """GitHub Utilities"""
from __future__ import annotations
import json import json
import os import os
from typing import Any, Callable, cast, Dict
from typing import Any, Callable, cast, Dict, Optional, Tuple
from urllib.error import HTTPError from urllib.error import HTTPError
from urllib.parse import quote from urllib.parse import quote
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
@ -13,11 +13,11 @@ from urllib.request import Request, urlopen
def gh_fetch_url_and_headers( def gh_fetch_url_and_headers(
url: str, url: str,
*, *,
headers: Optional[Dict[str, str]] = None, headers: dict[str, str] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
method: Optional[str] = None, method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(), reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Tuple[Any, Any]: ) -> tuple[Any, Any]:
if headers is None: if headers is None:
headers = {} headers = {}
token = os.environ.get("GITHUB_TOKEN") token = os.environ.get("GITHUB_TOKEN")
@ -44,9 +44,9 @@ def gh_fetch_url_and_headers(
def gh_fetch_url( def gh_fetch_url(
url: str, url: str,
*, *,
headers: Optional[Dict[str, str]] = None, headers: dict[str, str] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
method: Optional[str] = None, method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(), reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any: ) -> Any:
return gh_fetch_url_and_headers( return gh_fetch_url_and_headers(
@ -56,8 +56,8 @@ def gh_fetch_url(
def _gh_fetch_json_any( def _gh_fetch_json_any(
url: str, url: str,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
) -> Any: ) -> Any:
headers = {"Accept": "application/vnd.github.v3+json"} headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0: if params is not None and len(params) > 0:
@ -69,13 +69,13 @@ def _gh_fetch_json_any(
def gh_fetch_json_dict( def gh_fetch_json_dict(
url: str, url: str,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data)) return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
def gh_fetch_commit(org: str, repo: str, sha: str) -> Dict[str, Any]: def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:
return gh_fetch_json_dict( return gh_fetch_json_dict(
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}" f"https://api.github.com/repos/{org}/{repo}/commits/{sha}"
) )

View File

@ -1,6 +1,7 @@
import re import re
import sys import sys
QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"') QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"')
ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>") ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>")

View File

@ -1,10 +1,13 @@
# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp. # Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
from __future__ import annotations
import argparse import argparse
import os import os
import pathlib
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Literal, Sequence, Union from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING
import yaml import yaml
@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp. # Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeUnboxingFunctions: class ComputeUnboxingFunctions:
@ -156,7 +162,7 @@ def gen_unboxing(
cpu_fm: FileManager, cpu_fm: FileManager,
selector: SelectiveBuilder, selector: SelectiveBuilder,
) -> None: ) -> None:
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
return fn.root_name return fn.root_name
selected_op_num: int = len(selector.operators) selected_op_num: int = len(selector.operators)
@ -195,7 +201,7 @@ def gen_unboxing(
) )
def main(args: List[str]) -> None: def main(args: list[str]) -> None:
parser = argparse.ArgumentParser(description="Generate unboxing source files") parser = argparse.ArgumentParser(description="Generate unboxing source files")
parser.add_argument( parser.add_argument(
"-s", "-s",
@ -272,7 +278,7 @@ def main(args: List[str]) -> None:
gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector) gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)
if options.output_dependencies: if options.output_dependencies:
depfile_path = pathlib.Path(options.output_dependencies).resolve() depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name depfile_name = depfile_path.name
depfile_stem = depfile_path.stem depfile_stem = depfile_path.stem

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -8,7 +10,7 @@ import subprocess
import sys import sys
import time import time
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional, Pattern from typing import NamedTuple
LINTER_CODE = "ACTIONLINT" LINTER_CODE = "ACTIONLINT"
@ -22,18 +24,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
RESULTS_RE: Pattern[str] = re.compile( RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -47,8 +49,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command( def run_command(
args: List[str], args: list[str],
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -64,7 +66,7 @@ def run_command(
def check_file( def check_file(
binary: str, binary: str,
file: str, file: str,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
proc = run_command( proc = run_command(
[ [

View File

@ -5,6 +5,9 @@ archive is downloaded from some sites like GitHub because it can change. Specifi
GitHub gives no guarantee to keep the same value forever. Check for more details at GitHub gives no guarantee to keep the same value forever. Check for more details at
https://github.com/community/community/discussions/46034. https://github.com/community/community/discussions/46034.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import re import re
@ -13,7 +16,7 @@ import subprocess
import sys import sys
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional, Set from typing import NamedTuple
from urllib.parse import urlparse from urllib.parse import urlparse
@ -30,18 +33,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def is_required_checksum(urls: List[Optional[str]]) -> bool: def is_required_checksum(urls: list[str | None]) -> bool:
if not urls: if not urls:
return False return False
@ -58,7 +61,7 @@ def is_required_checksum(urls: List[Optional[str]]) -> bool:
def get_disallowed_checksums( def get_disallowed_checksums(
binary: str, binary: str,
) -> Set[str]: ) -> set[str]:
""" """
Return the set of disallowed checksums from all http_archive rules Return the set of disallowed checksums from all http_archive rules
""" """
@ -96,8 +99,8 @@ def get_disallowed_checksums(
def check_bazel( def check_bazel(
filename: str, filename: str,
disallowed_checksums: Set[str], disallowed_checksums: set[str],
) -> List[LintMessage]: ) -> list[LintMessage]:
original = "" original = ""
replacement = "" replacement = ""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -7,7 +9,7 @@ import subprocess
import sys import sys
import time import time
from enum import Enum from enum import Enum
from typing import Any, BinaryIO, List, NamedTuple, Optional from typing import Any, BinaryIO, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -41,11 +43,11 @@ def as_posix(name: str) -> str:
def _run_command( def _run_command(
args: List[str], args: list[str],
*, *,
stdin: BinaryIO, stdin: BinaryIO,
timeout: int, timeout: int,
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -63,12 +65,12 @@ def _run_command(
def run_command( def run_command(
args: List[str], args: list[str],
*, *,
stdin: BinaryIO, stdin: BinaryIO,
retries: int, retries: int,
timeout: int, timeout: int,
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
remaining_retries = retries remaining_retries = retries
while True: while True:
try: try:
@ -90,7 +92,7 @@ def check_file(
filename: str, filename: str,
retries: int, retries: int,
timeout: int, timeout: int,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
with open(filename, "rb") as f: with open(filename, "rb") as f:
original = f.read() original = f.read()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -8,7 +10,7 @@ import sys
import time import time
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, List, NamedTuple, Optional from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -42,10 +44,10 @@ def as_posix(name: str) -> str:
def _run_command( def _run_command(
args: List[str], args: list[str],
*, *,
timeout: int, timeout: int,
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -62,11 +64,11 @@ def _run_command(
def run_command( def run_command(
args: List[str], args: list[str],
*, *,
retries: int, retries: int,
timeout: int, timeout: int,
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
remaining_retries = retries remaining_retries = retries
while True: while True:
try: try:
@ -89,7 +91,7 @@ def check_file(
binary: str, binary: str,
retries: int, retries: int,
timeout: int, timeout: int,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
with open(filename, "rb") as f: with open(filename, "rb") as f:
original = f.read() original = f.read()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -11,7 +13,7 @@ import time
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from sysconfig import get_paths as gp from sysconfig import get_paths as gp
from typing import Any, List, NamedTuple, Optional, Pattern from typing import Any, NamedTuple
# PyTorch directory root # PyTorch directory root
@ -49,15 +51,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -65,7 +67,7 @@ def as_posix(name: str) -> str:
# c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move] # c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move]
RESULTS_RE: Pattern[str] = re.compile( RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -80,8 +82,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command( def run_command(
args: List[str], args: list[str],
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -103,7 +105,7 @@ severities = {
} }
def clang_search_dirs() -> List[str]: def clang_search_dirs() -> list[str]:
# Compilers are ordered based on fallback preference # Compilers are ordered based on fallback preference
# We pick the first one that is available on the system # We pick the first one that is available on the system
compilers = ["clang", "gcc", "cpp", "cc"] compilers = ["clang", "gcc", "cpp", "cc"]
@ -152,7 +154,7 @@ def check_file(
filename: str, filename: str,
binary: str, binary: str,
build_dir: Path, build_dir: Path,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
proc = run_command( proc = run_command(
[binary, f"-p={build_dir}", *include_args, filename], [binary, f"-p={build_dir}", *include_args, filename],

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -7,7 +9,7 @@ import re
import subprocess import subprocess
import time import time
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional, Pattern from typing import NamedTuple
LINTER_CODE = "CMAKE" LINTER_CODE = "CMAKE"
@ -21,19 +23,19 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
# CMakeLists.txt:901: Lines should be <= 80 characters long [linelength] # CMakeLists.txt:901: Lines should be <= 80 characters long [linelength]
RESULTS_RE: Pattern[str] = re.compile( RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -46,8 +48,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command( def run_command(
args: List[str], args: list[str],
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -63,7 +65,7 @@ def run_command(
def check_file( def check_file(
filename: str, filename: str,
config: str, config: str,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
proc = run_command( proc = run_command(
["cmakelint", f"--config={config}", filename], ["cmakelint", f"--config={config}", filename],

View File

@ -2,13 +2,15 @@
CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
import sys import sys
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional from typing import NamedTuple
CONSTEXPR = "constexpr char" CONSTEXPR = "constexpr char"
CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char" CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char"
@ -21,18 +23,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def check_file(filename: str) -> Optional[LintMessage]: def check_file(filename: str) -> LintMessage | None:
logging.debug("Checking file %s", filename) logging.debug("Checking file %s", filename)
with open(filename) as f: with open(filename) as f:

View File

@ -1,14 +1,17 @@
""" """
EXEC: Ensure that source files are not executable. EXEC: Ensure that source files are not executable.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
import os import os
import sys import sys
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional from typing import NamedTuple
LINTER_CODE = "EXEC" LINTER_CODE = "EXEC"
@ -21,18 +24,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def check_file(filename: str) -> Optional[LintMessage]: def check_file(filename: str) -> LintMessage | None:
is_executable = os.access(filename, os.X_OK) is_executable = os.access(filename, os.X_OK)
if is_executable: if is_executable:
return LintMessage( return LintMessage(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
@ -7,7 +9,7 @@ import subprocess
import sys import sys
import time import time
from enum import Enum from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
# fmt: off # fmt: off
# https://www.flake8rules.com/ # https://www.flake8rules.com/
DOCUMENTED_IN_FLAKE8RULES: Set[str] = { DOCUMENTED_IN_FLAKE8RULES: set[str] = {
"E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117", "E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117",
"E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129", "E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129",
"E131", "E133", "E131", "E133",
@ -78,14 +80,14 @@ DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
} }
# https://pypi.org/project/flake8-comprehensions/#rules # https://pypi.org/project/flake8-comprehensions/#rules
DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = { DOCUMENTED_IN_FLAKE8COMPREHENSIONS: set[str] = {
"C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409", "C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409",
"C410", "C410",
"C411", "C412", "C413", "C414", "C415", "C416", "C411", "C412", "C413", "C414", "C415", "C416",
} }
# https://github.com/PyCQA/flake8-bugbear#list-of-warnings # https://github.com/PyCQA/flake8-bugbear#list-of-warnings
DOCUMENTED_IN_BUGBEAR: Set[str] = { DOCUMENTED_IN_BUGBEAR: set[str] = {
"B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010", "B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010",
"B011", "B012", "B013", "B014", "B015", "B011", "B012", "B013", "B014", "B015",
"B301", "B302", "B303", "B304", "B305", "B306", "B301", "B302", "B303", "B304", "B305", "B306",
@ -98,7 +100,7 @@ DOCUMENTED_IN_BUGBEAR: Set[str] = {
# stdin:3:6: T484 Name 'foo' is not defined # stdin:3:6: T484 Name 'foo' is not defined
# stdin:3:-100: W605 invalid escape sequence '\/' # stdin:3:-100: W605 invalid escape sequence '\/'
# stdin:3:1: E302 expected 2 blank lines, found 1 # stdin:3:1: E302 expected 2 blank lines, found 1
RESULTS_RE: Pattern[str] = re.compile( RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -134,10 +136,10 @@ def _test_results_re() -> None:
def _run_command( def _run_command(
args: List[str], args: list[str],
*, *,
extra_env: Optional[Dict[str, str]], extra_env: dict[str, str] | None,
) -> "subprocess.CompletedProcess[str]": ) -> subprocess.CompletedProcess[str]:
logging.debug( logging.debug(
"$ %s", "$ %s",
" ".join( " ".join(
@ -158,11 +160,11 @@ def _run_command(
def run_command( def run_command(
args: List[str], args: list[str],
*, *,
extra_env: Optional[Dict[str, str]], extra_env: dict[str, str] | None,
retries: int, retries: int,
) -> "subprocess.CompletedProcess[str]": ) -> subprocess.CompletedProcess[str]:
remaining_retries = retries remaining_retries = retries
while True: while True:
try: try:
@ -243,11 +245,11 @@ def get_issue_documentation_url(code: str) -> str:
def check_files( def check_files(
filenames: List[str], filenames: list[str],
flake8_plugins_path: Optional[str], flake8_plugins_path: str | None,
severities: Dict[str, LintSeverity], severities: dict[str, LintSeverity],
retries: int, retries: int,
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
proc = run_command( proc = run_command(
[sys.executable, "-mflake8", "--exit-zero"] + filenames, [sys.executable, "-mflake8", "--exit-zero"] + filenames,
@ -351,7 +353,7 @@ def main() -> None:
else os.path.realpath(args.flake8_plugins_path) else os.path.realpath(args.flake8_plugins_path)
) )
severities: Dict[str, LintSeverity] = {} severities: dict[str, LintSeverity] = {}
if args.severity: if args.severity:
for severity in args.severity: for severity in args.severity:
parts = severity.split(":", 1) parts = severity.split(":", 1)

View File

@ -2,6 +2,8 @@
Generic linter that greps for a pattern and optionally suggests replacements. Generic linter that greps for a pattern and optionally suggests replacements.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
@ -10,7 +12,7 @@ import subprocess
import sys import sys
import time import time
from enum import Enum from enum import Enum
from typing import Any, List, NamedTuple, Optional from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -44,8 +46,8 @@ def as_posix(name: str) -> str:
def run_command( def run_command(
args: List[str], args: list[str],
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -65,7 +67,7 @@ def lint_file(
linter_name: str, linter_name: str,
error_name: str, error_name: str,
error_description: str, error_description: str,
) -> Optional[LintMessage]: ) -> LintMessage | None:
# matching_line looks like: # matching_line looks like:
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz # tools/linter/clangtidy_linter.py:13:import foo.bar.baz
split = matching_line.split(":") split = matching_line.split(":")

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import json import json
import subprocess import subprocess
import sys import sys
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional, Tuple from typing import NamedTuple
LINTER_CODE = "LINTRUNNER_VERSION" LINTER_CODE = "LINTRUNNER_VERSION"
@ -16,18 +18,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def toVersionString(version_tuple: Tuple[int, int, int]) -> str: def toVersionString(version_tuple: tuple[int, int, int]) -> str:
return ".".join(str(x) for x in version_tuple) return ".".join(str(x) for x in version_tuple)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
@ -8,7 +10,7 @@ import sys
import time import time
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Pattern from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
# tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment] # tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment]
RESULTS_RE: Pattern[str] = re.compile( RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -56,7 +58,7 @@ RESULTS_RE: Pattern[str] = re.compile(
) )
# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR # torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: Pattern[str] = re.compile( INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
r"""(?mx) r"""(?mx)
^ ^
(?P<file>.*?): (?P<file>.*?):
@ -69,11 +71,11 @@ INTERNAL_ERROR_RE: Pattern[str] = re.compile(
def run_command( def run_command(
args: List[str], args: list[str],
*, *,
extra_env: Optional[Dict[str, str]], extra_env: dict[str, str] | None,
retries: int, retries: int,
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -94,7 +96,7 @@ severities = {
} }
def check_mypy_installed(code: str) -> List[LintMessage]: def check_mypy_installed(code: str) -> list[LintMessage]:
cmd = [sys.executable, "-mmypy", "-V"] cmd = [sys.executable, "-mmypy", "-V"]
try: try:
subprocess.run(cmd, check=True, capture_output=True) subprocess.run(cmd, check=True, capture_output=True)
@ -117,11 +119,11 @@ def check_mypy_installed(code: str) -> List[LintMessage]:
def check_files( def check_files(
filenames: List[str], filenames: list[str],
config: str, config: str,
retries: int, retries: int,
code: str, code: str,
) -> List[LintMessage]: ) -> list[LintMessage]:
# dmypy has a bug where it won't pick up changes if you pass it absolute # dmypy has a bug where it won't pick up changes if you pass it absolute
# file names, see https://github.com/python/mypy/issues/16768 # file names, see https://github.com/python/mypy/issues/16768
filenames = [os.path.relpath(f) for f in filenames] filenames = [os.path.relpath(f) for f in filenames]
@ -224,7 +226,7 @@ def main() -> None:
# Use a dictionary here to preserve order. mypy cares about order, # Use a dictionary here to preserve order. mypy cares about order,
# tragically, e.g. https://github.com/python/mypy/issues/2015 # tragically, e.g. https://github.com/python/mypy/issues/2015
filenames: Dict[str, bool] = {} filenames: dict[str, bool] = {}
# If a stub file exists, have mypy check it instead of the original file, in # If a stub file exists, have mypy check it instead of the original file, in
# accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files) # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)

View File

@ -14,12 +14,14 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
the YAML, not to be prescriptive about it. the YAML, not to be prescriptive about it.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import sys import sys
from enum import Enum from enum import Enum
from io import StringIO from io import StringIO
from typing import NamedTuple, Optional from typing import NamedTuple
import ruamel.yaml # type: ignore[import] import ruamel.yaml # type: ignore[import]
@ -32,15 +34,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,13 +1,16 @@
""" """
NEWLINE: Checks files to make sure there are no trailing newlines. NEWLINE: Checks files to make sure there are no trailing newlines.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
import sys import sys
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional from typing import NamedTuple
NEWLINE = 10 # ASCII "\n" NEWLINE = 10 # ASCII "\n"
CARRIAGE_RETURN = 13 # ASCII "\r" CARRIAGE_RETURN = 13 # ASCII "\r"
@ -22,18 +25,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def check_file(filename: str) -> Optional[LintMessage]: def check_file(filename: str) -> LintMessage | None:
logging.debug("Checking file %s", filename) logging.debug("Checking file %s", filename)
with open(filename, "rb") as f: with open(filename, "rb") as f:
@ -85,7 +88,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.", description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
) )
has_changes = False has_changes = False
original_lines: Optional[List[bytes]] = None original_lines: list[bytes] | None = None
for idx, line in enumerate(lines): for idx, line in enumerate(lines):
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN: if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
if not has_changes: if not has_changes:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -5,7 +7,7 @@ import logging
import os import os
import sys import sys
from enum import Enum from enum import Enum
from typing import Any, List, NamedTuple, Optional from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"
@ -23,18 +25,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def check_file(filename: str) -> List[LintMessage]: def check_file(filename: str) -> list[LintMessage]:
with open(filename, "rb") as f: with open(filename, "rb") as f:
original = f.read().decode("utf-8") original = f.read().decode("utf-8")
replacement = "" replacement = ""

View File

@ -1,6 +1,9 @@
""" """
Initializer script that installs stuff to pip. Initializer script that installs stuff to pip.
""" """
from __future__ import annotations
import argparse import argparse
import logging import logging
import os import os
@ -9,10 +12,8 @@ import subprocess
import sys import sys
import time import time
from typing import List
def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]:
def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]":
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:

View File

@ -14,6 +14,7 @@ import sys
import time import time
from typing import Any, BinaryIO from typing import Any, BinaryIO
LINTER_CODE = "RUFF" LINTER_CODE = "RUFF"
IS_WINDOWS: bool = os.name == "nt" IS_WINDOWS: bool = os.name == "nt"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import json import json
import logging import logging
@ -6,7 +8,7 @@ import subprocess
import sys import sys
import time import time
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional from typing import NamedTuple
LINTER_CODE = "SHELLCHECK" LINTER_CODE = "SHELLCHECK"
@ -20,20 +22,20 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def run_command( def run_command(
args: List[str], args: list[str],
) -> "subprocess.CompletedProcess[bytes]": ) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args)) logging.debug("$ %s", " ".join(args))
start_time = time.monotonic() start_time = time.monotonic()
try: try:
@ -47,8 +49,8 @@ def run_command(
def check_files( def check_files(
files: List[str], files: list[str],
) -> List[LintMessage]: ) -> list[LintMessage]:
try: try:
proc = run_command( proc = run_command(
["shellcheck", "--external-sources", "--format=json1"] + files ["shellcheck", "--external-sources", "--format=json1"] + files

View File

@ -6,15 +6,19 @@ calls run_tests to ensure that the test will be run in OSS CI.
Takes ~2 minuters to run without the multiprocessing, probably overkill. Takes ~2 minuters to run without the multiprocessing, probably overkill.
""" """
from __future__ import annotations
import argparse import argparse
import json import json
import multiprocessing as mp import multiprocessing as mp
from enum import Enum from enum import Enum
from typing import List, NamedTuple, Optional from typing import NamedTuple
import libcst as cst import libcst as cst
import libcst.matchers as m import libcst.matchers as m
LINTER_CODE = "TEST_HAS_MAIN" LINTER_CODE = "TEST_HAS_MAIN"
@ -62,18 +66,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def check_file(filename: str) -> List[LintMessage]: def check_file(filename: str) -> list[LintMessage]:
lint_messages = [] lint_messages = []
with open(filename) as f: with open(filename) as f:

View File

@ -8,10 +8,13 @@ has valid ownership information in a comment header. Valid means:
- Each owner label actually exists in PyTorch - Each owner label actually exists in PyTorch
- Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS - Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS
""" """
from __future__ import annotations
import argparse import argparse
import json import json
from enum import Enum from enum import Enum
from typing import Any, List, NamedTuple, Optional from typing import Any, NamedTuple
from urllib.request import urlopen from urllib.request import urlopen
@ -26,15 +29,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
# Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions # Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions
@ -58,8 +61,8 @@ GLOB_EXCEPTIONS = ["**/test/run_test.py"]
def check_labels( def check_labels(
labels: List[str], filename: str, line_number: int labels: list[str], filename: str, line_number: int
) -> List[LintMessage]: ) -> list[LintMessage]:
lint_messages = [] lint_messages = []
for label in labels: for label in labels:
if label not in PYTORCH_LABELS: if label not in PYTORCH_LABELS:
@ -104,7 +107,7 @@ def check_labels(
return lint_messages return lint_messages
def check_file(filename: str) -> List[LintMessage]: def check_file(filename: str) -> list[LintMessage]:
lint_messages = [] lint_messages = []
has_ownership_info = False has_ownership_info = False

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import concurrent.futures import concurrent.futures
import json import json
@ -6,7 +8,7 @@ import os
import sys import sys
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, List, NamedTuple, Optional from typing import Any, NamedTuple
from ufmt.core import ufmt_string from ufmt.core import ufmt_string
from ufmt.util import make_black_config from ufmt.util import make_black_config
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def as_posix(name: str) -> str: def as_posix(name: str) -> str:
@ -59,7 +61,7 @@ def format_error_message(filename: str, err: Exception) -> LintMessage:
def check_file( def check_file(
filename: str, filename: str,
) -> List[LintMessage]: ) -> list[LintMessage]:
with open(filename, "rb") as f: with open(filename, "rb") as f:
original = f.read().decode("utf-8") original = f.read().decode("utf-8")

View File

@ -2,16 +2,20 @@
Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`. Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`.
""" """
from __future__ import annotations
import argparse import argparse
import itertools import itertools
import json import json
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, NamedTuple, Optional from typing import Any, Iterable, NamedTuple
from yaml import dump, load from yaml import dump, load
# Safely load fast C Yaml loader/dumper if they are available # Safely load fast C Yaml loader/dumper if they are available
try: try:
from yaml import CSafeLoader as Loader from yaml import CSafeLoader as Loader
@ -27,15 +31,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple): class LintMessage(NamedTuple):
path: Optional[str] path: str | None
line: Optional[int] line: int | None
char: Optional[int] char: int | None
code: str code: str
severity: LintSeverity severity: LintSeverity
name: str name: str
original: Optional[str] original: str | None
replacement: Optional[str] replacement: str | None
description: Optional[str] description: str | None
def glob_yamls(path: Path) -> Iterable[Path]: def glob_yamls(path: Path) -> Iterable[Path]:
@ -51,7 +55,7 @@ def is_workflow(yaml: Any) -> bool:
return yaml.get("jobs") is not None return yaml.get("jobs") is not None
def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None: def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
job_id = next(iter(job.keys())) job_id = next(iter(job.keys()))
with open(path) as f: with open(path) as f:
lines = f.readlines() lines = f.readlines()

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import os import os
import subprocess import subprocess
import sys import sys
from typing import List
def run_cmd(cmd: List[str]) -> None: def run_cmd(cmd: list[str]) -> None:
print(f"Running: {cmd}") print(f"Running: {cmd}")
result = subprocess.run( result = subprocess.run(
cmd, cmd,

View File

@ -1,13 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import os import os
from typing import Set
import yaml import yaml
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder
# Safely load fast C Yaml loader/dumper if they are available # Safely load fast C Yaml loader/dumper if they are available
try: try:
from yaml import CSafeLoader as Loader from yaml import CSafeLoader as Loader
@ -46,7 +49,7 @@ selected_mobile_ops_preamble = """#pragma once
""" """
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]: def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
ops = [] ops = []
for op_name, op in selective_builder.operators.items(): for op_name, op in selective_builder.operators.items():
if op.is_root_operator: if op.is_root_operator:
@ -125,7 +128,7 @@ def write_selected_mobile_ops(
# 2. All kernel dtypes # 2. All kernel dtypes
def write_selected_mobile_ops_with_all_dtypes( def write_selected_mobile_ops_with_all_dtypes(
output_file_path: str, output_file_path: str,
root_ops: Set[str], root_ops: set[str],
) -> None: ) -> None:
with open(output_file_path, "wb") as out_file: with open(output_file_path, "wb") as out_file:
body_parts = [selected_mobile_ops_preamble] body_parts = [selected_mobile_ops_preamble]

View File

@ -1,5 +1,6 @@
import lldb # type: ignore[import] import lldb # type: ignore[import]
# load into lldb instance with: # load into lldb instance with:
# command script import tools/lldb/deploy_debugger.py # command script import tools/lldb/deploy_debugger.py

View File

@ -24,6 +24,9 @@ well. This can be done with
Pulling will reinstalle the conda dependencies as well as the nightly binaries into Pulling will reinstalle the conda dependencies as well as the nightly binaries into
the repo directory. the repo directory.
""" """
from __future__ import annotations
import contextlib import contextlib
import datetime import datetime
import functools import functools
@ -40,23 +43,10 @@ import time
import uuid import uuid
from argparse import ArgumentParser from argparse import ArgumentParser
from ast import literal_eval from ast import literal_eval
from typing import ( from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
)
LOGGER: Optional[logging.Logger] = None
LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2" URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss" DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile("([0-9a-fA-F]{40})") SHA1_RE = re.compile("([0-9a-fA-F]{40})")
@ -68,9 +58,9 @@ SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphin
class Formatter(logging.Formatter): class Formatter(logging.Formatter):
redactions: Dict[str, str] redactions: dict[str, str]
def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
super().__init__(fmt, datefmt) super().__init__(fmt, datefmt)
self.redactions = {} self.redactions = {}
@ -192,7 +182,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
sys.exit(1) sys.exit(1)
def check_in_repo() -> Optional[str]: def check_in_repo() -> str | None:
"""Ensures that we are in the PyTorch repo.""" """Ensures that we are in the PyTorch repo."""
if not os.path.isfile("setup.py"): if not os.path.isfile("setup.py"):
return "Not in root-level PyTorch repo, no setup.py found" return "Not in root-level PyTorch repo, no setup.py found"
@ -203,7 +193,7 @@ def check_in_repo() -> Optional[str]:
return None return None
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]: def check_branch(subcommand: str, branch: str | None) -> str | None:
"""Checks that the branch name can be checked out.""" """Checks that the branch name can be checked out."""
if subcommand != "checkout": if subcommand != "checkout":
return None return None
@ -259,7 +249,7 @@ def timed(prefix: str) -> Callable[[F], F]:
def _make_channel_args( def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",), channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False, override_channels: bool = False,
) -> List[str]: ) -> list[str]:
args = [] args = []
for channel in channels: for channel in channels:
args.append("--channel") args.append("--channel")
@ -271,11 +261,11 @@ def _make_channel_args(
@timed("Solving conda environment") @timed("Solving conda environment")
def conda_solve( def conda_solve(
name: Optional[str] = None, name: str | None = None,
prefix: Optional[str] = None, prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",), channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False, override_channels: bool = False,
) -> Tuple[List[str], str, str, bool, List[str]]: ) -> tuple[list[str], str, str, bool, list[str]]:
"""Performs the conda solve and splits the deps from the package.""" """Performs the conda solve and splits the deps from the package."""
# compute what environment to use # compute what environment to use
if prefix is not None: if prefix is not None:
@ -329,7 +319,7 @@ def conda_solve(
@timed("Installing dependencies") @timed("Installing dependencies")
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None: def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
"""Install dependencies to deps environment""" """Install dependencies to deps environment"""
if not existing_env: if not existing_env:
# first remove previous pytorch-deps env # first remove previous pytorch-deps env
@ -342,7 +332,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No
@timed("Installing pytorch nightly binaries") @timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]": def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
"""Install pytorch into a temporary directory""" """Install pytorch into a temporary directory"""
pytdir = tempfile.TemporaryDirectory() pytdir = tempfile.TemporaryDirectory()
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url] cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
@ -421,33 +411,33 @@ def pull_nightly_version(spdir: str) -> None:
p = subprocess.run(cmd, check=True) p = subprocess.run(cmd, check=True)
def _get_listing_linux(source_dir: str) -> List[str]: def _get_listing_linux(source_dir: str) -> list[str]:
listing = glob.glob(os.path.join(source_dir, "*.so")) listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
return listing return listing
def _get_listing_osx(source_dir: str) -> List[str]: def _get_listing_osx(source_dir: str) -> list[str]:
# oddly, these are .so files even on Mac # oddly, these are .so files even on Mac
listing = glob.glob(os.path.join(source_dir, "*.so")) listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
return listing return listing
def _get_listing_win(source_dir: str) -> List[str]: def _get_listing_win(source_dir: str) -> list[str]:
listing = glob.glob(os.path.join(source_dir, "*.pyd")) listing = glob.glob(os.path.join(source_dir, "*.pyd"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
return listing return listing
def _glob_pyis(d: str) -> Set[str]: def _glob_pyis(d: str) -> set[str]:
search = os.path.join(d, "**", "*.pyi") search = os.path.join(d, "**", "*.pyi")
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)} pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
return pyis return pyis
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]: def _find_missing_pyi(source_dir: str, target_dir: str) -> list[str]:
source_pyis = _glob_pyis(source_dir) source_pyis = _glob_pyis(source_dir)
target_pyis = _glob_pyis(target_dir) target_pyis = _glob_pyis(target_dir)
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)] missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@ -455,7 +445,7 @@ def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
return missing_pyis return missing_pyis
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]: def _get_listing(source_dir: str, target_dir: str, platform: str) -> list[str]:
if platform.startswith("linux"): if platform.startswith("linux"):
listing = _get_listing_linux(source_dir) listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"): elif platform.startswith("osx"):
@ -510,12 +500,12 @@ def _move_single(
mover(src, trg) mover(src, trg)
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None: def _copy_files(listing: list[str], source_dir: str, target_dir: str) -> None:
for src in listing: for src in listing:
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying") _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None: def _link_files(listing: list[str], source_dir: str, target_dir: str) -> None:
for src in listing: for src in listing:
_move_single(src, source_dir, target_dir, os.link, "Linking") _move_single(src, source_dir, target_dir, os.link, "Linking")
@ -537,7 +527,7 @@ def move_nightly_files(spdir: str, platform: str) -> None:
_copy_files(listing, source_dir, target_dir) _copy_files(listing, source_dir, target_dir)
def _available_envs() -> Dict[str, str]: def _available_envs() -> dict[str, str]:
cmd = ["conda", "env", "list"] cmd = ["conda", "env", "list"]
p = subprocess.run( p = subprocess.run(
cmd, cmd,
@ -559,7 +549,7 @@ def _available_envs() -> Dict[str, str]:
@timed("Writing pytorch-nightly.pth") @timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: List[str], platform: str) -> None: def write_pth(env_opts: list[str], platform: str) -> None:
"""Writes Python path file for this dir.""" """Writes Python path file for this dir."""
env_type, env_dir = env_opts env_type, env_dir = env_opts
if env_type == "--name": if env_type == "--name":
@ -582,9 +572,9 @@ def install(
*, *,
logger: logging.Logger, logger: logging.Logger,
subcommand: str = "checkout", subcommand: str = "checkout",
branch: Optional[str] = None, branch: str | None = None,
name: Optional[str] = None, name: str | None = None,
prefix: Optional[str] = None, prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",), channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False, override_channels: bool = False,
) -> None: ) -> None:
@ -673,7 +663,7 @@ def make_parser() -> ArgumentParser:
return p return p
def main(args: Optional[Sequence[str]] = None) -> None: def main(args: Sequence[str] | None = None) -> None:
"""Main entry point""" """Main entry point"""
global LOGGER global LOGGER
p = make_parser() p = make_parser()

View File

@ -13,13 +13,15 @@ CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"
""" """
from __future__ import annotations
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional, TextIO from typing import TextIO
def resolve_include(path: Path, include_dirs: List[Path]) -> Path: def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
for include_path in include_dirs: for include_path in include_dirs:
abs_path = include_path / path abs_path = include_path / path
if abs_path.exists(): if abs_path.exists():
@ -36,7 +38,7 @@ Tried the following paths, but none existed:
) )
def repair_depfile(depfile: TextIO, include_dirs: List[Path]) -> None: def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
changes_made = False changes_made = False
out = "" out = ""
for line in depfile: for line in depfile:
@ -70,8 +72,8 @@ PRE_INCLUDE_ARGS = ["-include", "--pre-include"]
POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"] POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]
def extract_include_arg(include_dirs: List[Path], i: int, args: List[str]) -> None: def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
def extract_one(name: str, i: int, args: List[str]) -> Optional[str]: def extract_one(name: str, i: int, args: list[str]) -> str | None:
arg = args[i] arg = args[i]
if arg == name: if arg == name:
return args[i + 1] return args[i + 1]

View File

@ -24,6 +24,7 @@ import yaml
from torchgen import utils as torchgen_utils from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader from torchgen.yaml_utils import YamlLoader
_RULES_GENERATED_COMMENT = """\ _RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py. This file is generated by gen_diagnostics.py.

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse import argparse
import collections import collections
import importlib import importlib
import sys import sys
from pprint import pformat from pprint import pformat
from typing import Dict, List, Sequence from typing import Sequence
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from warnings import warn from warnings import warn
@ -220,7 +222,7 @@ to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname: str) -> List[str]: def sig_for_ops(opname: str) -> list[str]:
"""sig_for_ops(opname : str) -> List[str] """sig_for_ops(opname : str) -> List[str]
Returns signatures for operator special functions (__add__ etc.)""" Returns signatures for operator special functions (__add__ etc.)"""
@ -254,8 +256,8 @@ def sig_for_ops(opname: str) -> List[str]:
raise Exception("unknown op", opname) # noqa: TRY002 raise Exception("unknown op", opname) # noqa: TRY002
def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
type_hints: List[str] = [] type_hints: list[str] = []
# Some deprecated ops that are on the blocklist are still included in pyi # Some deprecated ops that are on the blocklist are still included in pyi
if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
@ -285,7 +287,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
return type_hints return type_hints
def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]: def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
flag_pos = arg_list.index("{return_indices}") flag_pos = arg_list.index("{return_indices}")
# If return_indices is positional arg, everything before should have no default # If return_indices is positional arg, everything before should have no default
arg_list_positional = ( arg_list_positional = (
@ -329,7 +331,7 @@ def gen_nn_functional(fm: FileManager) -> None:
) )
# TODO the list for `torch._C._nn` is nonexhaustive # TODO the list for `torch._C._nn` is nonexhaustive
unsorted_c_nn_function_hints: Dict[str, List[str]] = {} unsorted_c_nn_function_hints: dict[str, list[str]] = {}
for d in (2, 3): for d in (2, 3):
unsorted_c_nn_function_hints.update( unsorted_c_nn_function_hints.update(
@ -471,7 +473,7 @@ def gen_nn_functional(fm: FileManager) -> None:
} }
) )
c_nn_function_hints: List[str] = [] c_nn_function_hints: list[str] = []
for _, hints in sorted(unsorted_c_nn_function_hints.items()): for _, hints in sorted(unsorted_c_nn_function_hints.items()):
if len(hints) > 1: if len(hints) > 1:
hints = ["@overload\n" + h for h in hints] hints = ["@overload\n" + h for h in hints]
@ -528,7 +530,7 @@ def gen_nn_functional(fm: FileManager) -> None:
) )
# Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional` # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
unsorted_dispatched_hints: Dict[str, List[str]] = {} unsorted_dispatched_hints: dict[str, list[str]] = {}
for d in (1, 2, 3): for d in (1, 2, 3):
unsorted_dispatched_hints.update( unsorted_dispatched_hints.update(
@ -563,7 +565,7 @@ def gen_nn_functional(fm: FileManager) -> None:
# There's no fractional_max_pool1d # There's no fractional_max_pool1d
del unsorted_dispatched_hints["fractional_max_pool1d"] del unsorted_dispatched_hints["fractional_max_pool1d"]
dispatched_hints: List[str] = [] dispatched_hints: list[str] = []
for _, hints in sorted(unsorted_dispatched_hints.items()): for _, hints in sorted(unsorted_dispatched_hints.items()):
if len(hints) > 1: if len(hints) > 1:
hints = ["@overload\n" + h for h in hints] hints = ["@overload\n" + h for h in hints]
@ -594,7 +596,7 @@ We gather the docstrings for torch with the following steps:
""" """
def gather_docstrs() -> Dict[str, str]: def gather_docstrs() -> dict[str, str]:
docstrs = {} docstrs = {}
def mock_add_docstr(func: Mock, docstr: str) -> None: def mock_add_docstr(func: Mock, docstr: str) -> None:
@ -648,12 +650,12 @@ def gen_pyi(
# also needs to update the other file. # also needs to update the other file.
# Dictionary for NamedTuple definitions # Dictionary for NamedTuple definitions
structseqs: Dict[str, str] = {} structseqs: dict[str, str] = {}
# Generate type signatures for top-level functions # Generate type signatures for top-level functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
for n, n1, n2 in [ for n, n1, n2 in [
("csr", "crow", "col"), ("csr", "crow", "col"),
@ -1054,7 +1056,7 @@ def gen_pyi(
# Generate type signatures for Tensor methods # Generate type signatures for Tensor methods
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
unsorted_tensor_method_hints.update( unsorted_tensor_method_hints.update(
{ {
"size": [ "size": [

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import argparse import argparse
import os import os
from typing import Any, List, Union from typing import Any
try: try:
from junitparser import ( # type: ignore[import] from junitparser import ( # type: ignore[import]
@ -23,8 +26,8 @@ except ImportError:
print("rich not found, for color output use 'pip install rich'") print("rich not found, for color output use 'pip install rich'")
def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported] def parse_junit_reports(path_to_reports: str) -> list[TestCase]: # type: ignore[no-any-unimported]
def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported] def parse_file(path: str) -> list[TestCase]: # type: ignore[no-any-unimported]
try: try:
return convert_junit_to_testcases(JUnitXml.fromfile(path)) return convert_junit_to_testcases(JUnitXml.fromfile(path))
except Exception as err: except Exception as err:
@ -46,7 +49,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore
return ret_xml return ret_xml
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported] def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]: # type: ignore[no-any-unimported]
testcases = [] testcases = []
for item in xml: for item in xml:
if isinstance(item, TestSuite): if isinstance(item, TestSuite):
@ -56,7 +59,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
return testcases return testcases
def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported] def render_tests(testcases: list[TestCase]) -> None: # type: ignore[no-any-unimported]
num_passed = 0 num_passed = 0
num_skipped = 0 num_skipped = 0
num_failed = 0 num_failed = 0

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import os import os
import sys import sys
from typing import Optional
def which(thefile: str) -> Optional[str]: def which(thefile: str) -> str | None:
path = os.environ.get("PATH", os.defpath).split(os.pathsep) path = os.environ.get("PATH", os.defpath).split(os.pathsep)
for d in path: for d in path:
fname = os.path.join(d, thefile) fname = os.path.join(d, thefile)

View File

@ -1,5 +1,6 @@
"Manages CMake." "Manages CMake."
from __future__ import annotations
import multiprocessing import multiprocessing
import os import os
@ -8,7 +9,7 @@ import sys
import sysconfig import sysconfig
from distutils.version import LooseVersion from distutils.version import LooseVersion
from subprocess import CalledProcessError, check_call, check_output from subprocess import CalledProcessError, check_call, check_output
from typing import Any, cast, Dict, List, Optional from typing import Any, cast
from . import which from . import which
from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file
@ -77,7 +78,7 @@ class CMake:
return cmake_command return cmake_command
@staticmethod @staticmethod
def _get_version(cmd: Optional[str]) -> Any: def _get_version(cmd: str | None) -> Any:
"Returns cmake version." "Returns cmake version."
if cmd is None: if cmd is None:
@ -87,7 +88,7 @@ class CMake:
return LooseVersion(line.strip().split(" ")[2]) return LooseVersion(line.strip().split(" ")[2])
raise RuntimeError("no version found") raise RuntimeError("no version found")
def run(self, args: List[str], env: Dict[str, str]) -> None: def run(self, args: list[str], env: dict[str, str]) -> None:
"Executes cmake with arguments and an environment." "Executes cmake with arguments and an environment."
command = [self._cmake_command] + args command = [self._cmake_command] + args
@ -101,13 +102,13 @@ class CMake:
sys.exit(1) sys.exit(1)
@staticmethod @staticmethod
def defines(args: List[str], **kwargs: CMakeValue) -> None: def defines(args: list[str], **kwargs: CMakeValue) -> None:
"Adds definitions to a cmake argument list." "Adds definitions to a cmake argument list."
for key, value in sorted(kwargs.items()): for key, value in sorted(kwargs.items()):
if value is not None: if value is not None:
args.append(f"-D{key}={value}") args.append(f"-D{key}={value}")
def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]: def get_cmake_cache_variables(self) -> dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary. r"""Gets values in CMakeCache.txt into a dictionary.
Returns: Returns:
dict: A ``dict`` containing the value of cached CMake variables. dict: A ``dict`` containing the value of cached CMake variables.
@ -117,11 +118,11 @@ class CMake:
def generate( def generate(
self, self,
version: Optional[str], version: str | None,
cmake_python_library: Optional[str], cmake_python_library: str | None,
build_python: bool, build_python: bool,
build_test: bool, build_test: bool,
my_env: Dict[str, str], my_env: dict[str, str],
rerun: bool, rerun: bool,
) -> None: ) -> None:
"Runs cmake to generate native build files." "Runs cmake to generate native build files."
@ -181,7 +182,7 @@ class CMake:
_mkdir_p(self.build_dir) _mkdir_p(self.build_dir)
# Store build options that are directly stored in environment variables # Store build options that are directly stored in environment variables
build_options: Dict[str, CMakeValue] = {} build_options: dict[str, CMakeValue] = {}
# Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars. # Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars.
# This is a dict that maps environment variables to the corresponding variable name in CMake. # This is a dict that maps environment variables to the corresponding variable name in CMake.
@ -340,7 +341,7 @@ class CMake:
args.append(base_dir) args.append(base_dir)
self.run(args, env=my_env) self.run(args, env=my_env)
def build(self, my_env: Dict[str, str]) -> None: def build(self, my_env: dict[str, str]) -> None:
"Runs cmake to build binaries." "Runs cmake to build binaries."
from .env import build_type from .env import build_type

View File

@ -3,8 +3,10 @@ This is refactored from cmake.py to avoid circular imports issue with env.py,
which calls get_cmake_cache_variables_from_file which calls get_cmake_cache_variables_from_file
""" """
from __future__ import annotations
import re import re
from typing import Dict, IO, Optional, Union from typing import IO, Optional, Union
CMakeValue = Optional[Union[bool, str]] CMakeValue = Optional[Union[bool, str]]
@ -42,7 +44,7 @@ def convert_cmake_value_to_python_value(
def get_cmake_cache_variables_from_file( def get_cmake_cache_variables_from_file(
cmake_cache_file: IO[str], cmake_cache_file: IO[str],
) -> Dict[str, CMakeValue]: ) -> dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary. r"""Gets values in CMakeCache.txt into a dictionary.
Args: Args:

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import os import os
import platform import platform
import struct import struct
import sys import sys
from itertools import chain from itertools import chain
from typing import cast, Iterable, List, Optional from typing import cast, Iterable
IS_WINDOWS = platform.system() == "Windows" IS_WINDOWS = platform.system() == "Windows"
@ -30,11 +32,11 @@ def check_negative_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"] return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"]
def gather_paths(env_vars: Iterable[str]) -> List[str]: def gather_paths(env_vars: Iterable[str]) -> list[str]:
return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars))) return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars)))
def lib_paths_from_base(base_path: str) -> List[str]: def lib_paths_from_base(base_path: str) -> list[str]:
return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]] return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]]
@ -54,7 +56,7 @@ class BuildType:
""" """
def __init__(self, cmake_build_type_env: Optional[str] = None) -> None: def __init__(self, cmake_build_type_env: str | None = None) -> None:
if cmake_build_type_env is not None: if cmake_build_type_env is not None:
self.build_type_string = cmake_build_type_env self.build_type_string = cmake_build_type_env
return return

View File

@ -2,9 +2,12 @@
# and use the version numbers from there as substitutions for # and use the version numbers from there as substitutions for
# an expand_template action. Since there isn't, this silly script exists. # an expand_template action. Since there isn't, this silly script exists.
from __future__ import annotations
import argparse import argparse
import os import os
from typing import cast, Dict, Tuple from typing import cast, Tuple
Version = Tuple[int, int, int] Version = Tuple[int, int, int]
@ -30,7 +33,7 @@ def parse_version(version: str) -> Version:
return cast(Version, tuple([int(n) for n in version_number_str.split(".")])) return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
def apply_replacements(replacements: Dict[str, str], text: str) -> str: def apply_replacements(replacements: dict[str, str], text: str) -> str:
""" """
Applies the given replacements within the text. Applies the given replacements within the text.

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import argparse import argparse
import os import os
import pathlib import pathlib
import sys import sys
from typing import Any, cast, Optional from typing import Any, cast
import yaml import yaml
@ -18,10 +20,10 @@ TAGS_PATH = "aten/src/ATen/native/tags.yaml"
def generate_code( def generate_code(
gen_dir: pathlib.Path, gen_dir: pathlib.Path,
native_functions_path: Optional[str] = None, native_functions_path: str | None = None,
tags_path: Optional[str] = None, tags_path: str | None = None,
install_dir: Optional[str] = None, install_dir: str | None = None,
subset: Optional[str] = None, subset: str | None = None,
disable_autograd: bool = False, disable_autograd: bool = False,
force_schema_registration: bool = False, force_schema_registration: bool = False,
operator_selector: Any = None, operator_selector: Any = None,
@ -102,8 +104,8 @@ def get_selector_from_legacy_operator_selection_list(
def get_selector( def get_selector(
selected_op_list_path: Optional[str], selected_op_list_path: str | None,
operators_yaml_path: Optional[str], operators_yaml_path: str | None,
) -> Any: ) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier # cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import argparse import argparse
import json import json
import os import os
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple from typing import Any, Generator
from tools.stats.upload_stats_lib import ( from tools.stats.upload_stats_lib import (
download_s3_artifacts, download_s3_artifacts,
@ -14,13 +16,14 @@ from tools.stats.upload_stats_lib import (
) )
from tools.stats.upload_test_stats import process_xml_element from tools.stats.upload_test_stats import process_xml_element
TESTCASE_TAG = "testcase" TESTCASE_TAG = "testcase"
SEPARATOR = ";" SEPARATOR = ";"
def process_report( def process_report(
report: Path, report: Path,
) -> Dict[str, Dict[str, int]]: ) -> dict[str, dict[str, int]]:
""" """
Return a list of disabled tests that should be re-enabled and those that are still Return a list of disabled tests that should be re-enabled and those that are still
flaky (failed or skipped) flaky (failed or skipped)
@ -36,7 +39,7 @@ def process_report(
# * Skipped tests from unittest # * Skipped tests from unittest
# #
# We want to keep track of how many times the test fails (num_red) or passes (num_green) # We want to keep track of how many times the test fails (num_red) or passes (num_green)
all_tests: Dict[str, Dict[str, int]] = {} all_tests: dict[str, dict[str, int]] = {}
for test_case in root.iter(TESTCASE_TAG): for test_case in root.iter(TESTCASE_TAG):
parsed_test_case = process_xml_element(test_case) parsed_test_case = process_xml_element(test_case)
@ -116,7 +119,7 @@ def get_test_reports(
yield from Path(".").glob("**/*.xml") yield from Path(".").glob("**/*.xml")
def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]: def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
""" """
Follow flaky bot convention here, if that changes, this will also need to be updated Follow flaky bot convention here, if that changes, this will also need to be updated
""" """
@ -133,7 +136,7 @@ def prepare_record(
flaky: bool, flaky: bool,
num_red: int = 0, num_red: int = 0,
num_green: int = 0, num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]: ) -> tuple[Any, dict[str, Any]]:
""" """
Prepare the record to save onto S3 Prepare the record to save onto S3
""" """
@ -162,7 +165,7 @@ def prepare_record(
def save_results( def save_results(
workflow_id: int, workflow_id: int,
workflow_run_attempt: int, workflow_run_attempt: int,
all_tests: Dict[str, Dict[str, int]], all_tests: dict[str, dict[str, int]],
) -> None: ) -> None:
""" """
Save the result to S3, so it can go to Rockset Save the result to S3, so it can go to Rockset
@ -228,7 +231,7 @@ def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
Find the list of all disabled tests that should be re-enabled Find the list of all disabled tests that should be re-enabled
""" """
# Aggregated across all jobs # Aggregated across all jobs
all_tests: Dict[str, Dict[str, int]] = {} all_tests: dict[str, dict[str, int]] = {}
for report in get_test_reports( for report in get_test_reports(
args.repo, args.workflow_run_id, args.workflow_run_attempt args.repo, args.workflow_run_id, args.workflow_run_attempt

View File

@ -1,17 +1,19 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import datetime import datetime
import json import json
import os import os
import pathlib import pathlib
import shutil import shutil
from typing import Any, Callable, cast, Dict, List, Optional, Union from typing import Any, Callable, cast, Dict
from urllib.request import urlopen from urllib.request import urlopen
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
def get_disabled_issues() -> List[str]: def get_disabled_issues() -> list[str]:
reenabled_issues = os.getenv("REENABLED_ISSUES", "") reenabled_issues = os.getenv("REENABLED_ISSUES", "")
issue_numbers = reenabled_issues.split(",") issue_numbers = reenabled_issues.split(",")
print("Ignoring disabled issues: ", issue_numbers) print("Ignoring disabled issues: ", issue_numbers)
@ -34,11 +36,11 @@ FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
def fetch_and_cache( def fetch_and_cache(
dirpath: Union[str, pathlib.Path], dirpath: str | pathlib.Path,
name: str, name: str,
url: str, url: str,
process_fn: Callable[[Dict[str, Any]], Dict[str, Any]], process_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
This fetch and cache utils allows sharing between different process. This fetch and cache utils allows sharing between different process.
""" """
@ -76,7 +78,7 @@ def fetch_and_cache(
def get_slow_tests( def get_slow_tests(
dirpath: str, filename: str = SLOW_TESTS_FILE dirpath: str, filename: str = SLOW_TESTS_FILE
) -> Optional[Dict[str, float]]: ) -> dict[str, float] | None:
url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json" url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json"
try: try:
return fetch_and_cache(dirpath, filename, url, lambda x: x) return fetch_and_cache(dirpath, filename, url, lambda x: x)
@ -85,7 +87,7 @@ def get_slow_tests(
return {} return {}
def get_test_times() -> Dict[str, Dict[str, float]]: def get_test_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"test-times.json", "test-times.json",
TEST_TIMES_FILE, TEST_TIMES_FILE,
@ -93,7 +95,7 @@ def get_test_times() -> Dict[str, Dict[str, float]]:
) )
def get_test_class_times() -> Dict[str, Dict[str, float]]: def get_test_class_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"test-class-times.json", "test-class-times.json",
TEST_CLASS_TIMES_FILE, TEST_CLASS_TIMES_FILE,
@ -103,8 +105,8 @@ def get_test_class_times() -> Dict[str, Dict[str, float]]:
def get_disabled_tests( def get_disabled_tests(
dirpath: str, filename: str = DISABLED_TESTS_FILE dirpath: str, filename: str = DISABLED_TESTS_FILE
) -> Optional[Dict[str, Any]]: ) -> dict[str, Any] | None:
def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]: def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]:
# remove re-enabled tests and condense even further by getting rid of pr_num # remove re-enabled tests and condense even further by getting rid of pr_num
disabled_issues = get_disabled_issues() disabled_issues = get_disabled_issues()
disabled_test_from_issues = dict() disabled_test_from_issues = dict()
@ -124,7 +126,7 @@ def get_disabled_tests(
return {} return {}
def get_test_file_ratings() -> Dict[str, Any]: def get_test_file_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"file_test_rating.json", "file_test_rating.json",
TEST_FILE_RATINGS_FILE, TEST_FILE_RATINGS_FILE,
@ -132,7 +134,7 @@ def get_test_file_ratings() -> Dict[str, Any]:
) )
def get_test_class_ratings() -> Dict[str, Any]: def get_test_class_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"file_test_class_rating.json", "file_test_class_rating.json",
TEST_CLASS_RATINGS_FILE, TEST_CLASS_RATINGS_FILE,
@ -140,7 +142,7 @@ def get_test_class_ratings() -> Dict[str, Any]:
) )
def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]: def get_td_heuristic_historial_edited_files_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"td_heuristic_historical_edited_files.json", "td_heuristic_historical_edited_files.json",
TD_HEURISTIC_HISTORICAL_EDITED_FILES, TD_HEURISTIC_HISTORICAL_EDITED_FILES,
@ -148,7 +150,7 @@ def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
) )
def get_td_heuristic_profiling_json() -> Dict[str, Any]: def get_td_heuristic_profiling_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats( return get_from_test_infra_generated_stats(
"td_heuristic_profiling.json", "td_heuristic_profiling.json",
TD_HEURISTIC_PROFILING_FILE, TD_HEURISTIC_PROFILING_FILE,
@ -182,7 +184,7 @@ def copy_additional_previous_failures() -> None:
def get_from_test_infra_generated_stats( def get_from_test_infra_generated_stats(
from_file: str, to_file: str, failure_explanation: str from_file: str, to_file: str, failure_explanation: str
) -> Dict[str, Any]: ) -> dict[str, Any]:
url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}" url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}"
try: try:
return fetch_and_cache( return fetch_and_cache(

View File

@ -1,14 +1,17 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import datetime import datetime
import json import json
import signal import signal
import time import time
from typing import Any, Dict, List from typing import Any
import psutil # type: ignore[import] import psutil # type: ignore[import]
def get_processes_running_python_tests() -> List[Any]: def get_processes_running_python_tests() -> list[Any]:
python_processes = [] python_processes = []
for process in psutil.process_iter(): for process in psutil.process_iter():
try: try:
@ -20,7 +23,7 @@ def get_processes_running_python_tests() -> List[Any]:
return python_processes return python_processes
def get_per_process_cpu_info() -> List[Dict[str, Any]]: def get_per_process_cpu_info() -> list[dict[str, Any]]:
processes = get_processes_running_python_tests() processes = get_processes_running_python_tests()
per_process_info = [] per_process_info = []
for p in processes: for p in processes:
@ -49,7 +52,7 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]:
return per_process_info return per_process_info
def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: def get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
per_process_info = [] per_process_info = []
for p in processes: for p in processes:
@ -58,7 +61,7 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
return per_process_info return per_process_info
def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: def rocm_get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = amdsmi.amdsmi_get_gpu_process_list(handle) processes = amdsmi.amdsmi_get_gpu_process_list(handle)
per_process_info = [] per_process_info = []
for p in processes: for p in processes:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json import json
import os import os
import re import re
@ -6,7 +8,7 @@ from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, cast, Dict, List from typing import Any, cast
import requests import requests
@ -18,6 +20,7 @@ from tools.stats.upload_stats_lib import (
upload_workflow_stats_to_s3, upload_workflow_stats_to_s3,
) )
REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)" REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"
@ -56,7 +59,7 @@ def get_test_config(job_name: str) -> str:
def get_td_exclusions( def get_td_exclusions(
workflow_run_id: int, workflow_run_attempt: int workflow_run_id: int, workflow_run_attempt: int
) -> Dict[str, Any]: ) -> dict[str, Any]:
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir) print("Using temporary directory:", temp_dir)
os.chdir(temp_dir) os.chdir(temp_dir)
@ -68,7 +71,7 @@ def get_td_exclusions(
for path in s3_paths: for path in s3_paths:
unzip(path) unzip(path)
grouped_tests: Dict[str, Any] = defaultdict(lambda: defaultdict(set)) grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
for td_exclusions in Path(".").glob("**/td_exclusions*.json"): for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
with open(td_exclusions) as f: with open(td_exclusions) as f:
exclusions = json.load(f) exclusions = json.load(f)
@ -85,9 +88,9 @@ def get_td_exclusions(
return grouped_tests return grouped_tests
def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]: def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
start = time.time() start = time.time()
grouped_tests: Dict[str, Any] = defaultdict( grouped_tests: dict[str, Any] = defaultdict(
lambda: defaultdict( lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
) )
@ -112,8 +115,8 @@ def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
return grouped_tests return grouped_tests
def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
reruns: Dict[str, Any] = defaultdict( reruns: dict[str, Any] = defaultdict(
lambda: defaultdict( lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
) )
@ -136,8 +139,8 @@ def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
return reruns return reruns
def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
invoking_file_summary: Dict[str, Any] = defaultdict( invoking_file_summary: dict[str, Any] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0})) lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
) )
for build_name, build in grouped_tests.items(): for build_name, build in grouped_tests.items():
@ -157,7 +160,7 @@ def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
def upload_additional_info( def upload_additional_info(
workflow_run_id: int, workflow_run_attempt: int, test_cases: List[Dict[str, Any]] workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
) -> None: ) -> None:
grouped_tests = group_test_cases(test_cases) grouped_tests = group_test_cases(test_cases)
reruns = get_reruns(grouped_tests) reruns = get_reruns(grouped_tests)

View File

@ -5,6 +5,7 @@ from tempfile import TemporaryDirectory
from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3 from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3
ARTIFACTS = [ ARTIFACTS = [
"sccache-stats", "sccache-stats",
"test-jsons", "test-jsons",

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import argparse import argparse
import csv import csv
import os import os
import re import re
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Dict, List from typing import Any
from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset
@ -23,7 +25,7 @@ def upload_dynamo_perf_stats_to_rockset(
workflow_run_attempt: int, workflow_run_attempt: int,
head_branch: str, head_branch: str,
match_filename: str, match_filename: str,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
match_filename_regex = re.compile(match_filename) match_filename_regex = re.compile(match_filename)
perf_stats = [] perf_stats = []
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import argparse import argparse
import datetime import datetime
import json import json
import os import os
import time import time
import urllib.parse import urllib.parse
from typing import Any, Callable, cast, Dict, List, Optional, Set from typing import Any, Callable, cast, Dict, List
from urllib.error import HTTPError from urllib.error import HTTPError
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
from tools.stats.upload_stats_lib import upload_to_s3 from tools.stats.upload_stats_lib import upload_to_s3
FILTER_OUT_USERS = { FILTER_OUT_USERS = {
"pytorchmergebot", "pytorchmergebot",
"facebook-github-bot", "facebook-github-bot",
@ -23,9 +25,9 @@ FILTER_OUT_USERS = {
def _fetch_url( def _fetch_url(
url: str, url: str,
headers: Dict[str, str], headers: dict[str, str],
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
method: Optional[str] = None, method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(), reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any: ) -> Any:
token = os.environ.get("GITHUB_TOKEN") token = os.environ.get("GITHUB_TOKEN")
@ -49,9 +51,9 @@ def _fetch_url(
def fetch_json( def fetch_json(
url: str, url: str,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
headers = {"Accept": "application/vnd.github.v3+json"} headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0: if params is not None and len(params) > 0:
url += "?" + "&".join( url += "?" + "&".join(
@ -65,16 +67,16 @@ def fetch_json(
def get_external_pr_data( def get_external_pr_data(
start_date: datetime.date, end_date: datetime.date, period_length: int = 1 start_date: datetime.date, end_date: datetime.date, period_length: int = 1
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
pr_info = [] pr_info = []
period_begin_date = start_date period_begin_date = start_date
pr_count = 0 pr_count = 0
users: Set[str] = set() users: set[str] = set()
while period_begin_date < end_date: while period_begin_date < end_date:
period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1) period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1)
page = 1 page = 1
responses: List[Dict[str, Any]] = [] responses: list[dict[str, Any]] = []
while len(responses) > 0 or page == 1: while len(responses) > 0 or page == 1:
response = cast( response = cast(
Dict[str, Any], Dict[str, Any],

View File

@ -1,13 +1,15 @@
from __future__ import annotations
import datetime import datetime
import inspect import inspect
import os import os
import time import time
import uuid import uuid
from decimal import Decimal from decimal import Decimal
from typing import Any, Dict from typing import Any
from warnings import warn from warnings import warn
# boto3 is an optional dependency. If it's not installed, # boto3 is an optional dependency. If it's not installed,
# we'll just not emit the metrics. # we'll just not emit the metrics.
# Keeping this logic here so that callers don't have to # Keeping this logic here so that callers don't have to
@ -65,7 +67,7 @@ class EnvVarMetric:
return value return value
global_metrics: Dict[str, Any] = {} global_metrics: dict[str, Any] = {}
def add_global_metric(metric_name: str, metric_value: Any) -> None: def add_global_metric(metric_name: str, metric_value: Any) -> None:
@ -79,7 +81,7 @@ def add_global_metric(metric_name: str, metric_value: Any) -> None:
def emit_metric( def emit_metric(
metric_name: str, metric_name: str,
metrics: Dict[str, Any], metrics: dict[str, Any],
) -> None: ) -> None:
""" """
Upload a metric to DynamoDB (and from there, Rockset). Upload a metric to DynamoDB (and from there, Rockset).
@ -174,7 +176,7 @@ def emit_metric(
print(f"Not emitting metrics for {metric_name}. Boto wasn't imported.") print(f"Not emitting metrics for {metric_name}. Boto wasn't imported.")
def _convert_float_values_to_decimals(data: Dict[str, Any]) -> Dict[str, Any]: def _convert_float_values_to_decimals(data: dict[str, Any]) -> dict[str, Any]:
# Attempt to recurse # Attempt to recurse
def _helper(o: Any) -> Any: def _helper(o: Any) -> Any:
if isinstance(o, float): if isinstance(o, float):

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse import argparse
import json import json
import os import os
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Dict, List from typing import Any
from tools.stats.upload_stats_lib import ( from tools.stats.upload_stats_lib import (
download_s3_artifacts, download_s3_artifacts,
@ -13,7 +15,7 @@ from tools.stats.upload_stats_lib import (
def get_sccache_stats( def get_sccache_stats(
workflow_run_id: int, workflow_run_attempt: int workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir) print("Using temporary directory:", temp_dir)
os.chdir(temp_dir) os.chdir(temp_dir)

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import gzip import gzip
import io import io
import json import json
import os import os
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any
import boto3 # type: ignore[import] import boto3 # type: ignore[import]
import requests import requests
import rockset # type: ignore[import] import rockset # type: ignore[import]
PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch" PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
S3_RESOURCE = boto3.resource("s3") S3_RESOURCE = boto3.resource("s3")
@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
BATCH_SIZE = 5000 BATCH_SIZE = 5000
def _get_request_headers() -> Dict[str, str]: def _get_request_headers() -> dict[str, str]:
return { return {
"Accept": "application/vnd.github.v3+json", "Accept": "application/vnd.github.v3+json",
"Authorization": "token " + os.environ["GITHUB_TOKEN"], "Authorization": "token " + os.environ["GITHUB_TOKEN"],
} }
def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]: def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
"""Get all workflow artifacts with 'test-report' in the name.""" """Get all workflow artifacts with 'test-report' in the name."""
response = requests.get( response = requests.get(
f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100", f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
@ -78,7 +80,7 @@ def _download_artifact(
def download_s3_artifacts( def download_s3_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]: ) -> list[Path]:
bucket = S3_RESOURCE.Bucket("gha-artifacts") bucket = S3_RESOURCE.Bucket("gha-artifacts")
objs = bucket.objects.filter( objs = bucket.objects.filter(
Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}" Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
@ -104,7 +106,7 @@ def download_s3_artifacts(
def download_gha_artifacts( def download_gha_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]: ) -> list[Path]:
artifact_urls = _get_artifact_urls(prefix, workflow_run_id) artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
paths = [] paths = []
for name, url in artifact_urls.items(): for name, url in artifact_urls.items():
@ -114,7 +116,7 @@ def download_gha_artifacts(
def upload_to_rockset( def upload_to_rockset(
collection: str, collection: str,
docs: List[Any], docs: list[Any],
workspace: str = "commons", workspace: str = "commons",
client: Any = None, client: Any = None,
) -> None: ) -> None:
@ -142,7 +144,7 @@ def upload_to_rockset(
def upload_to_s3( def upload_to_s3(
bucket_name: str, bucket_name: str,
key: str, key: str,
docs: List[Dict[str, Any]], docs: list[dict[str, Any]],
) -> None: ) -> None:
print(f"Writing {len(docs)} documents to S3") print(f"Writing {len(docs)} documents to S3")
body = io.StringIO() body = io.StringIO()
@ -164,7 +166,7 @@ def upload_to_s3(
def read_from_s3( def read_from_s3(
bucket_name: str, bucket_name: str,
key: str, key: str,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
print(f"Reading from s3://{bucket_name}/{key}") print(f"Reading from s3://{bucket_name}/{key}")
body = ( body = (
S3_RESOURCE.Object( S3_RESOURCE.Object(
@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3(
workflow_run_id: int, workflow_run_id: int,
workflow_run_attempt: int, workflow_run_attempt: int,
collection: str, collection: str,
docs: List[Dict[str, Any]], docs: list[dict[str, Any]],
) -> None: ) -> None:
bucket_name = "ossci-raw-job-status" bucket_name = "ossci-raw-job-status"
key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}" key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
@ -220,7 +222,7 @@ def unzip(p: Path) -> None:
zip.extractall(unzipped_dir) zip.extractall(unzipped_dir)
def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool: def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
""" """
Check if the test report is coming from rerun_disabled_tests workflow where Check if the test report is coming from rerun_disabled_tests workflow where
each test is run multiple times each test is run multiple times
@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
) )
def get_job_id(report: Path) -> Optional[int]: def get_job_id(report: Path) -> int | None:
# [Job id in artifacts] # [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append # Retrieve the job id from the report path. In our GHA workflows, we append
# the job id to the end of the report name, so `report` looks like: # the job id to the end of the report name, so `report` looks like:

View File

@ -1,17 +1,19 @@
from __future__ import annotations
import argparse import argparse
import ast import ast
import datetime import datetime
import json import json
import os import os
import re import re
from typing import Any, List, Union from typing import Any
import rockset # type: ignore[import] import rockset # type: ignore[import]
from tools.stats.upload_stats_lib import upload_to_s3 from tools.stats.upload_stats_lib import upload_to_s3
def get_oncall_from_testfile(testfile: str) -> Union[List[str], None]: def get_oncall_from_testfile(testfile: str) -> list[str] | None:
path = f"test/{testfile}" path = f"test/{testfile}"
if not path.endswith(".py"): if not path.endswith(".py"):
path += ".py" path += ".py"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import os import os
import sys import sys
@ -5,7 +7,7 @@ import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool from multiprocessing import cpu_count, Pool
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Dict, List from typing import Any
from tools.stats.test_dashboard import upload_additional_info from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_stats_lib import ( from tools.stats.upload_stats_lib import (
@ -21,14 +23,14 @@ def parse_xml_report(
report: Path, report: Path,
workflow_id: int, workflow_id: int,
workflow_run_attempt: int, workflow_run_attempt: int,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases.""" """Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}") print(f"Parsing {tag}s for test report: {report}")
job_id = get_job_id(report) job_id = get_job_id(report)
print(f"Found job id: {job_id}") print(f"Found job id: {job_id}")
test_cases: List[Dict[str, Any]] = [] test_cases: list[dict[str, Any]] = []
root = ET.parse(report) root = ET.parse(report)
for test_case in root.iter(tag): for test_case in root.iter(tag):
@ -53,9 +55,9 @@ def parse_xml_report(
return test_cases return test_cases
def process_xml_element(element: ET.Element) -> Dict[str, Any]: def process_xml_element(element: ET.Element) -> dict[str, Any]:
"""Convert a test suite element into a JSON-serializable dict.""" """Convert a test suite element into a JSON-serializable dict."""
ret: Dict[str, Any] = {} ret: dict[str, Any] = {}
# Convert attributes directly into dict elements. # Convert attributes directly into dict elements.
# e.g. # e.g.
@ -110,7 +112,7 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
return ret return ret
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]: def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir) print("Using temporary directory:", temp_dir)
os.chdir(temp_dir) os.chdir(temp_dir)
@ -146,7 +148,7 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
def get_tests_for_circleci( def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON # Parse the reports and transform them to JSON
test_cases = [] test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"): for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
@ -159,13 +161,13 @@ def get_tests_for_circleci(
return test_cases return test_cases
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation """Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does manually instead of using the `test-suite` XML tag because xmlrunner does
not produce reliable output for it. not produce reliable output for it.
""" """
def get_key(test_case: Dict[str, Any]) -> Any: def get_key(test_case: dict[str, Any]) -> Any:
return ( return (
test_case.get("file"), test_case.get("file"),
test_case.get("classname"), test_case.get("classname"),
@ -176,7 +178,7 @@ def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any
test_case["invoking_file"], test_case["invoking_file"],
) )
def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]: def init_value(test_case: dict[str, Any]) -> dict[str, Any]:
return { return {
"file": test_case.get("file"), "file": test_case.get("file"),
"classname": test_case.get("classname"), "classname": test_case.get("classname"),

View File

@ -4,6 +4,7 @@ import sys
from tools.stats.test_dashboard import upload_additional_info from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_test_stats import get_tests from tools.stats.upload_test_stats import get_tests
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upload test stats to Rockset") parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
parser.add_argument( parser.add_argument(

View File

@ -5,7 +5,6 @@ import argparse
import json import json
import unittest import unittest
from collections import defaultdict from collections import defaultdict
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from gen_operators_yaml import ( from gen_operators_yaml import (
@ -43,10 +42,10 @@ def _mock_load_op_dep_graph():
class GenOperatorsYAMLTest(unittest.TestCase): class GenOperatorsYAMLTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
pass pass
def test_filter_creation(self): def test_filter_creation(self) -> None:
filter_func = make_filter_from_options( filter_func = make_filter_from_options(
model_name="abc", model_name="abc",
model_versions=["100", "101"], model_versions=["100", "101"],
@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
len(filtered_configs) == 2 len(filtered_configs) == 2
), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}" ), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
def test_verification_success(self): def test_verification_success(self) -> None:
filter_func = make_filter_from_options( filter_func = make_filter_from_options(
model_name="abc", model_name="abc",
model_versions=["100", "101"], model_versions=["100", "101"],
@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
"expected verify_all_specified_present to succeed instead it raised an exception" "expected verify_all_specified_present to succeed instead it raised an exception"
) )
def test_verification_fail(self): def test_verification_fail(self) -> None:
config = [ config = [
{ {
"model": { "model": {
@ -229,7 +228,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
) )
def test_fill_output_with_arguments_not_include_all_overloads( def test_fill_output_with_arguments_not_include_all_overloads(
self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
): ) -> None:
parser = argparse.ArgumentParser(description="Generate used operators YAML") parser = argparse.ArgumentParser(description="Generate used operators YAML")
options = get_parser_options(parser) options = get_parser_options(parser)

View File

@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads
class GenOplistTest(unittest.TestCase): class GenOplistTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
pass pass
def test_throw_if_any_op_includes_overloads(self): def test_throw_if_any_op_includes_overloads(self) -> None:
selective_builder = MagicMock() selective_builder = MagicMock()
selective_builder.operators = MagicMock() selective_builder.operators = MagicMock()
selective_builder.operators.items.return_value = [ selective_builder.operators.items.return_value = [

View File

@ -1,10 +1,12 @@
# For testing specific heuristics # For testing specific heuristics
from __future__ import annotations
import io import io
import json import json
import pathlib import pathlib
import sys import sys
import unittest import unittest
from typing import Any, Dict, List, Set from typing import Any
from unittest import mock from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT))
HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation." HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase: def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO() file_object = io.StringIO()
json.dump(contents, file_object) json.dump(contents, file_object)
file_object.seek(0) file_object.seek(0)
return file_object return file_object
def gen_historical_class_failures() -> Dict[str, Dict[str, float]]: def gen_historical_class_failures() -> dict[str, dict[str, float]]:
return { return {
"file1": { "file1": {
"test1::classA": 0.5, "test1::classA": 0.5,
@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD):
) )
def test_get_prediction_confidence( def test_get_prediction_confidence(
self, self,
historical_class_failures: Dict[str, Dict[str, float]], historical_class_failures: dict[str, dict[str, float]],
changed_files: List[str], changed_files: list[str],
) -> None: ) -> None:
tests_to_prioritize = ALL_TESTS tests_to_prioritize = ALL_TESTS
@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD):
class TestParsePrevTests(TestTD): class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=False) @mock.patch("os.path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None: def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set() expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures() found_tests = get_previous_failures()
@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=True) @mock.patch("os.path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True})) @mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None: def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set() expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures() found_tests = get_previous_failures()

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib import pathlib
import sys import sys
import unittest import unittest
from typing import Any, Dict, List from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.append(str(REPO_ROOT)) sys.path.append(str(REPO_ROOT))
@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT))
class TestTD(unittest.TestCase): class TestTD(unittest.TestCase):
def assert_test_scores_almost_equal( def assert_test_scores_almost_equal(
self, d1: Dict[TestRun, float], d2: Dict[TestRun, float] self, d1: dict[TestRun, float], d2: dict[TestRun, float]
) -> None: ) -> None:
# Check that dictionaries are the same, except for floating point errors # Check that dictionaries are the same, except for floating point errors
self.assertEqual(set(d1.keys()), set(d2.keys())) self.assertEqual(set(d1.keys()), set(d2.keys()))
@ -24,7 +26,7 @@ class TestTD(unittest.TestCase):
# Create a dummy heuristic class # Create a dummy heuristic class
class Heuristic(interface.HeuristicInterface): class Heuristic(interface.HeuristicInterface):
def get_prediction_confidence( def get_prediction_confidence(
self, tests: List[str] self, tests: list[str]
) -> interface.TestPrioritizations: ) -> interface.TestPrioritizations:
# Return junk # Return junk
return interface.TestPrioritizations([], {}) return interface.TestPrioritizations([], {})
@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD):
class TestAggregatedHeuristics(TestTD): class TestAggregatedHeuristics(TestTD):
def check( def check(
self, self,
tests: List[str], tests: list[str],
test_prioritizations: List[Dict[TestRun, float]], test_prioritizations: list[dict[TestRun, float]],
expected: Dict[TestRun, float], expected: dict[TestRun, float],
) -> None: ) -> None:
aggregated_heuristics = interface.AggregatedHeuristics(tests) aggregated_heuristics = interface.AggregatedHeuristics(tests)
for i, test_prioritization in enumerate(test_prioritizations): for i, test_prioritization in enumerate(test_prioritizations):
@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD):
stats3 = aggregator.get_test_stats(TestRun("test3")) stats3 = aggregator.get_test_stats(TestRun("test3"))
stats5 = aggregator.get_test_stats(TestRun("test5::classA")) stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
def assert_valid_dict(dict_contents: Dict[str, Any]) -> None: def assert_valid_dict(dict_contents: dict[str, Any]) -> None:
for key, value in dict_contents.items(): for key, value in dict_contents.items():
self.assertTrue(isinstance(key, str)) self.assertTrue(isinstance(key, str))
self.assertTrue( self.assertTrue(

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib import pathlib
import sys import sys
import unittest import unittest
from typing import Any, Dict from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT))
class TestHeuristicsUtils(unittest.TestCase): class TestHeuristicsUtils(unittest.TestCase):
def assertDictAlmostEqual( def assertDictAlmostEqual(
self, first: Dict[TestRun, Any], second: Dict[TestRun, Any] self, first: dict[TestRun, Any], second: dict[TestRun, Any]
) -> None: ) -> None:
self.assertEqual(first.keys(), second.keys()) self.assertEqual(first.keys(), second.keys())
for key in first.keys(): for key in first.keys():
self.assertAlmostEqual(first[key], second[key]) self.assertAlmostEqual(first[key], second[key])
def test_normalize_ratings(self) -> None: def test_normalize_ratings(self) -> None:
ratings: Dict[TestRun, float] = { ratings: dict[TestRun, float] = {
TestRun("test1"): 1, TestRun("test1"): 1,
TestRun("test2"): 2, TestRun("test2"): 2,
TestRun("test3"): 4, TestRun("test3"): 4,

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import contextlib import contextlib
import os import os
import typing import typing
import unittest import unittest
import unittest.mock import unittest.mock
from typing import Iterator, Optional, Sequence from typing import Iterator, Sequence
import tools.setup_helpers.cmake import tools.setup_helpers.cmake
import tools.setup_helpers.env # noqa: F401 unused but resolves circular import import tools.setup_helpers.env # noqa: F401 unused but resolves circular import
@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def env_var(key: str, value: Optional[str]) -> Iterator[None]: def env_var(key: str, value: str | None) -> Iterator[None]:
"""Sets/clears an environment variable within a Python context.""" """Sets/clears an environment variable within a Python context."""
# Get the previous value and then override it. # Get the previous value and then override it.
previous_value = os.environ.get(key) previous_value = os.environ.get(key)
@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]:
set_env_var(key, previous_value) set_env_var(key, previous_value)
def set_env_var(key: str, value: Optional[str]) -> None: def set_env_var(key: str, value: str | None) -> None:
"""Sets/clears an environment variable.""" """Sets/clears an environment variable."""
if value is None: if value is None:
os.environ.pop(key, None) os.environ.pop(key, None)

View File

@ -1,14 +1,13 @@
from __future__ import annotations
import dataclasses import dataclasses
import typing import typing
import unittest import unittest
from collections import defaultdict from collections import defaultdict
from typing import Dict, List
import yaml import yaml
from tools.autograd import gen_autograd_functions, load_derivatives from tools.autograd import gen_autograd_functions, load_derivatives
import torchgen.model
from torchgen import dest from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager from torchgen.context import native_function_manager
@ -22,6 +21,7 @@ from torchgen.model import (
BackendIndex, BackendIndex,
BackendMetadata, BackendMetadata,
DispatchKey, DispatchKey,
FunctionSchema,
Location, Location,
NativeFunction, NativeFunction,
OperatorName, OperatorName,
@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
class TestCreateDerivative(unittest.TestCase): class TestCreateDerivative(unittest.TestCase):
def test_named_grads(self) -> None: def test_named_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse( schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
) )
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_non_differentiable_output(self) -> None: def test_non_differentiable_output(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification) schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info( _, differentiability_info = load_derivatives.create_differentiability_info(
@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase):
) )
def test_indexed_grads(self) -> None: def test_indexed_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse( schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
) )
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_named_grads_and_indexed_grads(self) -> None: def test_named_grads_and_indexed_grads(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
schema = torchgen.model.FunctionSchema.parse(specification) schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex( with self.assertRaisesRegex(
@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase):
class TestGenAutogradFunctions(unittest.TestCase): class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_invalid_type(self) -> None: def test_non_differentiable_output_invalid_type(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification) schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info( _, differentiability_info = load_derivatives.create_differentiability_info(
@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_output_differentiability(self) -> None: def test_non_differentiable_output_output_differentiability(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)" specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification) schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info( _, differentiability_info = load_derivatives.create_differentiability_info(
@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_register_bogus_dispatch_key(self) -> None: def test_register_bogus_dispatch_key(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification) schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex( with self.assertRaisesRegex(
@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase):
class TestGenSchemaRegistration(unittest.TestCase): class TestGenSchemaRegistration(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.selector = SelectiveBuilder.get_nop_selector() self.selector = SelectiveBuilder.get_nop_selector()
self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml( self.custom_native_function, _ = NativeFunction.from_yaml(
{"func": "custom::func() -> bool"}, {"func": "custom::func() -> bool"},
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
( (
self.fragment_custom_native_function, self.fragment_custom_native_function,
_, _,
) = torchgen.model.NativeFunction.from_yaml( ) = NativeFunction.from_yaml(
{"func": "quantized_decomposed::func() -> bool"}, {"func": "quantized_decomposed::func() -> bool"},
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) {
) )
def test_3_namespaces_schema_registration_code_valid(self) -> None: def test_3_namespaces_schema_registration_code_valid(self) -> None:
custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml( custom2_native_function, _ = NativeFunction.from_yaml(
{"func": "custom2::func() -> bool"}, {"func": "custom2::func() -> bool"},
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
( (
@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml( self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
{"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml( self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
"func": "op_2() -> bool", "func": "op_2() -> bool",
"dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"}, "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
}, },
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = { backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {}, DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {}, DispatchKey.QuantizedCPU: {},
} }
@ -382,9 +382,9 @@ TORCH_API bool kernel_1();
# Test for native_function_generation # Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase): class TestNativeFunctionGeneratrion(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.native_functions: List[NativeFunction] = [] self.native_functions: list[NativeFunction] = []
self.backend_indices: Dict[ self.backend_indices: dict[
DispatchKey, Dict[OperatorName, BackendMetadata] DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict) ] = defaultdict(dict)
yaml_entry = """ yaml_entry = """
- func: op(Tensor self) -> Tensor - func: op(Tensor self) -> Tensor
@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
"dispatch": {"CPU": "kernel_1"}, "dispatch": {"CPU": "kernel_1"},
"autogen": "op_2.out", "autogen": "op_2.out",
}, },
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )
BackendIndex.grow_index(self.backend_indices, two_returns_backend_index) BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
# Test for static_dispatch # Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase): class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.backend_indices: Dict[ self.backend_indices: dict[
DispatchKey, Dict[OperatorName, BackendMetadata] DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict) ] = defaultdict(dict)
yaml_entry = """ yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase):
# Represents the most basic NativeFunction. Use dataclasses.replace() # Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use. # to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml( DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "func() -> bool"}, {"func": "func() -> bool"},
loc=torchgen.model.Location(__file__, 1), loc=Location(__file__, 1),
valid_tags=set(), valid_tags=set(),
) )

View File

@ -1,4 +1,6 @@
from typing import Any, List from __future__ import annotations
from typing import Any
from unittest import main, TestCase from unittest import main, TestCase
from tools.alerts.create_alerts import filter_job_names, JobStatus from tools.alerts.create_alerts import filter_job_names, JobStatus
@ -38,7 +40,7 @@ MOCK_TEST_DATA = [
class TestGitHubPR(TestCase): class TestGitHubPR(TestCase):
# Should fail when jobs are ? ? Fail Fail # Should fail when jobs are ? ? Fail Fail
def test_alert(self) -> None: def test_alert(self) -> None:
modified_data: List[Any] = [{}] modified_data: list[Any] = [{}]
modified_data.append({}) modified_data.append({})
modified_data.extend(MOCK_TEST_DATA) modified_data.extend(MOCK_TEST_DATA)
status = JobStatus(JOB_NAME, modified_data) status = JobStatus(JOB_NAME, modified_data)

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import tempfile import tempfile
import unittest import unittest
from typing import Any, Dict from typing import Any
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
import expecttest import expecttest
@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager from torchgen.utils import FileManager
SPACES = " " SPACES = " "
def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction: def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
native_function, _ = NativeFunction.from_yaml( native_function, _ = NativeFunction.from_yaml(
yaml_obj, yaml_obj,
loc=Location(__file__, 1), loc=Location(__file__, 1),
@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
""" """
def _test_function_schema_generates_correct_kernel( def _test_function_schema_generates_correct_kernel(
self, obj: Dict[str, Any], expected: str self, obj: dict[str, Any], expected: str
) -> None: ) -> None:
func = _get_native_function_from_yaml(obj) func = _get_native_function_from_yaml(obj)

View File

@ -1,13 +1,13 @@
from __future__ import annotations
import os import os
import tempfile import tempfile
import unittest import unittest
from typing import Dict
import yaml import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader from torchgen.gen import LineLoader
from torchgen.gen_executorch import ( from torchgen.gen_executorch import (
ComputeCodegenUnboxedKernels, ComputeCodegenUnboxedKernels,
gen_functions_declarations, gen_functions_declarations,
@ -24,6 +24,7 @@ from torchgen.model import (
) )
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder
TEST_YAML = """ TEST_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator device_check: NoCheck # TensorIterator
@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
valid_tags=set(), valid_tags=set(),
) )
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = { backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {}, DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {}, DispatchKey.QuantizedCPU: {},
} }

View File

@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction from torchgen.model import Location, NativeFunction
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
loc=Location(__file__, 1), loc=Location(__file__, 1),

View File

@ -1,9 +1,10 @@
# Owner(s): ["module: codegen"] # Owner(s): ["module: codegen"]
from __future__ import annotations
import os import os
import tempfile import tempfile
import unittest import unittest
from typing import Optional
import expecttest import expecttest
@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase):
run(fp.name, "", True) run(fp.name, "", True)
def get_errors_from_gen_backend_stubs( def get_errors_from_gen_backend_stubs(
self, yaml_str: str, *, kernels_str: Optional[str] = None self, yaml_str: str, *, kernels_str: str | None = None
) -> str: ) -> str:
with tempfile.NamedTemporaryFile(mode="w") as fp: with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_str) fp.write(yaml_str)

View File

@ -1,7 +1,7 @@
import unittest import unittest
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.model import Location, NativeFunction from torchgen.model import Location, NativeFunction
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.selective_build.selector import ( from torchgen.selective_build.selector import (
combine_selective_builders, combine_selective_builders,
SelectiveBuilder, SelectiveBuilder,
@ -9,7 +9,7 @@ from torchgen.selective_build.selector import (
class TestSelectiveBuild(unittest.TestCase): class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self): def test_selective_build_operator(self) -> None:
op = SelectiveBuildOperator( op = SelectiveBuildOperator(
"aten::add.int", "aten::add.int",
is_root_operator=True, is_root_operator=True,
@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase):
self.assertFalse(op.is_used_for_training) self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads) self.assertFalse(op.include_all_overloads)
def test_selector_factory(self): def test_selector_factory(self) -> None:
yaml_config_v1 = """ yaml_config_v1 = """
debug_info: debug_info:
- model1@v100 - model1@v100
@ -132,7 +132,7 @@ operators:
selector_legacy_v1.is_operator_selected_for_training("aten::add.float") selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
) )
def test_operator_combine(self): def test_operator_combine(self) -> None:
op1 = SelectiveBuildOperator( op1 = SelectiveBuildOperator(
"aten::add.int", "aten::add.int",
is_root_operator=True, is_root_operator=True,
@ -177,7 +177,7 @@ operators:
self.assertRaises(Exception, gen_new_op) self.assertRaises(Exception, gen_new_op)
def test_training_op_fetch(self): def test_training_op_fetch(self) -> None:
yaml_config = """ yaml_config = """
operators: operators:
aten::add.int: aten::add.int:
@ -194,7 +194,7 @@ operators:
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int")) self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add")) self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
def test_kernel_dtypes(self): def test_kernel_dtypes(self) -> None:
yaml_config = """ yaml_config = """
kernel_metadata: kernel_metadata:
add_kernel: add_kernel:
@ -221,7 +221,7 @@ kernel_metadata:
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
def test_merge_kernel_dtypes(self): def test_merge_kernel_dtypes(self) -> None:
yaml_config1 = """ yaml_config1 = """
kernel_metadata: kernel_metadata:
add_kernel: add_kernel:
@ -266,7 +266,7 @@ kernel_metadata:
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8")) self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32")) self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
def test_all_kernel_dtypes_selected(self): def test_all_kernel_dtypes_selected(self) -> None:
yaml_config = """ yaml_config = """
include_all_non_op_selectives: True include_all_non_op_selectives: True
""" """
@ -279,7 +279,7 @@ include_all_non_op_selectives: True
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
def test_custom_namespace_selected_correctly(self): def test_custom_namespace_selected_correctly(self) -> None:
yaml_config = """ yaml_config = """
operators: operators:
aten::add.int: aten::add.int:
@ -301,7 +301,7 @@ operators:
class TestExecuTorchSelectiveBuild(unittest.TestCase): class TestExecuTorchSelectiveBuild(unittest.TestCase):
def test_et_kernel_selected(self): def test_et_kernel_selected(self) -> None:
yaml_config = """ yaml_config = """
et_kernel_metadata: et_kernel_metadata:
aten::add.out: aten::add.out:

Some files were not shown because too many files have changed in this diff Show More