[BE][Easy] enable postponed annotations in tools (#129375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet

parent 58f346c874
commit 8a67daf283
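Background for the change (an editor's sketch, not text from the PR): PEP 563's `from __future__ import annotations` makes every annotation in a module be stored lazily as a string instead of being evaluated at definition time. That is what lets this diff replace `typing.List`/`Dict`/`Optional`/`Union` with PEP 585 builtin generics and PEP 604 `|` unions while the tools still run on older interpreters:

```python
from __future__ import annotations  # PEP 563: postpone evaluation of annotations


def classify(names: list[str]) -> dict[str, list[str]] | None:
    # Under the future import these annotations are stored as the strings
    # "list[str]" and "dict[str, list[str]] | None" and are never evaluated
    # at definition time, so the function parses on Python 3.7+ even though
    # builtin generics (PEP 585) and "|" unions (PEP 604) only gained
    # runtime support in 3.9 and 3.10 respectively.
    return {n[:1]: [n] for n in names} if names else None


print(classify.__annotations__)
# -> {'names': 'list[str]', 'return': 'dict[str, list[str]] | None'}
```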
@@ -1,16 +1,19 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import argparse
 import json
 import os
 import re
 from collections import defaultdict
 from difflib import SequenceMatcher
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any

 import requests
 from setuptools import distutils  # type: ignore[import]

+
 ALL_SKIPPED_THRESHOLD = 100
 SIMILARITY_THRESHOLD = 0.75
 FAILURE_CHAIN_THRESHOLD = 2

@@ -65,14 +68,14 @@ DISABLED_ALERTS = [

 class JobStatus:
     job_name: str = ""
-    jobs: List[Any] = []
+    jobs: list[Any] = []
     current_status: Any = None
-    job_statuses: List[Any] = []
-    filtered_statuses: List[Any] = []
-    failure_chain: List[Any] = []
-    flaky_jobs: List[Any] = []
+    job_statuses: list[Any] = []
+    filtered_statuses: list[Any] = []
+    failure_chain: list[Any] = []
+    flaky_jobs: list[Any] = []

-    def __init__(self, job_name: str, job_statuses: List[Any]):
+    def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
         self.job_name = job_name
         self.job_statuses = job_statuses

@@ -93,7 +96,7 @@ class JobStatus:
                 return status
         return None

-    def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]:
+    def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]:
         """
         Returns list of jobs grouped by failureCaptures from the input list
         """

@@ -120,7 +123,7 @@ class JobStatus:
         return failures

    # A flaky job is if it's the only job that has that failureCapture and is not the most recent job
-    def get_flaky_jobs(self) -> List[Any]:
+    def get_flaky_jobs(self) -> list[Any]:
         unique_failures = self.get_unique_failures(self.filtered_statuses)
         flaky_jobs = []
         for failure in unique_failures:

@@ -134,7 +137,7 @@ class JobStatus:

     # The most recent failure chain is an array of jobs that have the same-ish failures.
     # A success in the middle of the chain will terminate the chain.
-    def get_most_recent_failure_chain(self) -> List[Any]:
+    def get_most_recent_failure_chain(self) -> list[Any]:
         failures = []
         found_most_recent_failure = False

@@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any:


 # Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD
-def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]:
+def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
     jobData = defaultdict(list)
     for sha in shaGrid:
         for ind, job in enumerate(sha["jobs"]):

@@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool:
     return conclusion in (NEUTRAL, SKIPPED) or conclusion is None


-def get_failed_jobs(job_data: List[Any]) -> List[Any]:
+def get_failed_jobs(job_data: list[Any]) -> list[Any]:
     return [job for job in job_data if job["conclusion"] == "failure"]


 def classify_jobs(
-    all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str]
-) -> Tuple[List[JobStatus], List[Any]]:
+    all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str]
+) -> tuple[list[JobStatus], list[Any]]:
     """
     Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs.
     Classifies jobs into jobs to alert on and flaky jobs.

@@ -212,7 +215,7 @@ def classify_jobs(
     :return:
     """
     job_data = map_job_data(all_job_names, sha_grid)
-    job_statuses: List[JobStatus] = []
+    job_statuses: list[JobStatus] = []
     for job in job_data:
         job_statuses.append(JobStatus(job, job_data[job]))

@@ -230,7 +233,7 @@ def classify_jobs(


 # filter job names that don't match the regex
-def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
+def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]:
     if job_name_regex:
         return [
             job_name for job_name in job_names if re.match(job_name_regex, job_name)

@@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:

 def get_recurrently_failing_jobs_alerts(
     repo: str, branch: str, job_name_regex: str
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch)

     filtered_job_names = set(filter_job_names(job_names, job_name_regex))
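The `List`/`Dict`/`Set`/`Tuple` rewrites above are mechanical and behavior-preserving. A before/after sketch using a hypothetical `get_unique` helper (modeled on `get_unique_failures`): the old spelling needed eager `typing` imports for the containers, the new one needs `typing` only for `Any`.

```python
from __future__ import annotations

from typing import Any

# Old spelling, which required "from typing import Any, Dict, List":
#     def get_unique(jobs: List[Any]) -> Dict[str, List[Any]]: ...
# New spelling with builtin generics:


def get_unique(jobs: list[Any]) -> dict[str, list[Any]]:
    # Group jobs by their string form; the annotations above are
    # equivalent to the typing.List/typing.Dict versions for checkers.
    out: dict[str, list[Any]] = {}
    for job in jobs:
        out.setdefault(str(job), []).append(job)
    return out
```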
@@ -14,18 +14,17 @@ generated. In the full build system, OUTPUT_DIR is
 torch/testing/_internal/generated
 """

+from __future__ import annotations
+
 import argparse
 import os
 import textwrap
 from collections import defaultdict
-
-from typing import Any, Dict, List, Sequence
+from typing import Any, Sequence, TYPE_CHECKING

 import torchgen.api.python as python
 from torchgen.context import with_native_function
-
 from torchgen.gen import parse_native_yaml
-from torchgen.model import Argument, BaseOperatorName, NativeFunction
 from torchgen.utils import FileManager

 from .gen_python_functions import (

@@ -39,6 +38,10 @@ from .gen_python_functions import (
 )


+if TYPE_CHECKING:
+    from torchgen.model import Argument, BaseOperatorName, NativeFunction
+
+
 def gen_annotated(
     native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
 ) -> None:

@@ -53,9 +56,9 @@ def gen_annotated(
         (is_py_fft_function, "torch._C._fft"),
         (is_py_variable_method, "torch.Tensor"),
     )
-    annotated_args: List[str] = []
+    annotated_args: list[str] = []
     for pred, namespace in mappings:
-        groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
+        groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
         for f in native_functions:
             if not should_generate_py_binding(f) or not pred(f):
                 continue

@@ -77,7 +80,7 @@ def gen_annotated(

 @with_native_function
 def gen_annotated_args(f: NativeFunction) -> str:
-    def _get_kwargs_func_exclusion_list() -> List[str]:
+    def _get_kwargs_func_exclusion_list() -> list[str]:
         # functions that currently don't work with kwargs in test_overrides.py
         return [
             "diagonal",

@@ -87,12 +90,12 @@ def gen_annotated_args(f: NativeFunction) -> str:
         ]

     def _add_out_arg(
-        out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
+        out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
     ) -> None:
         for arg in args:
             if arg.default is not None:
                 continue
-            out_arg: Dict[str, Any] = {}
+            out_arg: dict[str, Any] = {}
             out_arg["is_kwarg_only"] = str(is_kwarg_only)
             out_arg["name"] = arg.name
             out_arg["simple_type"] = python.argument_type_str(

@@ -103,7 +106,7 @@ def gen_annotated_args(f: NativeFunction) -> str:
                 out_arg["size"] = size_t
             out_args.append(out_arg)

-    out_args: List[Dict[str, Any]] = []
+    out_args: list[dict[str, Any]] = []
     _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
     if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
         _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
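The hunk above also moves the `torchgen.model` import under `if TYPE_CHECKING:`. That is safe precisely because annotations are no longer evaluated at runtime; the names only need to exist for the type checker. A generic sketch of the pattern (the `decimal` module stands in for a heavyweight or cycle-prone dependency):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported during static analysis: no runtime import cost,
    # no circular-import risk.
    from decimal import Decimal


def total(amounts: list[Decimal]) -> Decimal:
    # Fine at runtime: "Decimal" is just text inside the stored annotation
    # string, so the name never has to resolve here.  (It would have to
    # resolve if something called typing.get_type_hints(total).)
    result = amounts[0]
    for a in amounts[1:]:
        result += a
    return result
```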
@@ -22,9 +22,10 @@ torch/csrc/autograd/generated/
 # gen_python_functions.py: generates Python bindings to THPVariable
 #

+from __future__ import annotations
+
 import argparse
 import os
-from typing import List

 from torchgen.api import cpp
 from torchgen.api.autograd import (

@@ -69,7 +70,7 @@ def gen_autograd(
         ),
         key=lambda f: cpp.name(f.func),
     )
-    fns_with_diff_infos: List[
+    fns_with_diff_infos: list[
         NativeFunctionWithDifferentiabilityInfo
     ] = match_differentiability_info(fns, differentiability_infos)
@@ -4,7 +4,10 @@
 # Functions.h/cpp: subclasses of autograd::Node
 # python_functions.h/cpp: Python bindings for the above classes
 #
-from typing import Dict, List, Sequence, Tuple
+
+from __future__ import annotations
+
+from typing import Sequence

 from torchgen.api.autograd import (
     Derivative,

@@ -43,6 +46,7 @@ from torchgen.utils import FileManager

 from .gen_inplace_or_view_type import VIEW_FUNCTIONS

+
 FUNCTION_DECLARATION = CodeTemplate(
     """\
 #ifdef _WIN32

@@ -443,8 +447,8 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


 def get_infos_with_derivatives_list(
-    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
-) -> List[DifferentiabilityInfo]:
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
+) -> list[DifferentiabilityInfo]:
     diff_info_list = [
         info
         for diffinfo_dict in differentiability_infos.values()

@@ -456,7 +460,7 @@ def get_infos_with_derivatives_list(

 def gen_autograd_functions_lib(
     out: str,
-    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
     template_path: str,
 ) -> None:
     """Functions.h and Functions.cpp body

@@ -490,7 +494,7 @@ def gen_autograd_functions_lib(

 def gen_autograd_functions_python(
     out: str,
-    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
     template_path: str,
 ) -> None:
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

@@ -536,17 +540,17 @@ def gen_autograd_functions_python(


 def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
-    saved_variables: List[str] = []
-    release_variables: List[str] = []
-    saved_list_sizes: List[str] = []
-    unpack: List[str] = []
-    asserts: List[str] = []
-    compute_index_ranges: List[str] = []
-    getter_definitions: List[str] = []
-    py_getsetdef_structs: List[str] = []
-    compiled_args: List[str] = []
-    apply_with_saved_before: List[str] = []
-    apply_with_saved_after: List[str] = []
+    saved_variables: list[str] = []
+    release_variables: list[str] = []
+    saved_list_sizes: list[str] = []
+    unpack: list[str] = []
+    asserts: list[str] = []
+    compute_index_ranges: list[str] = []
+    getter_definitions: list[str] = []
+    py_getsetdef_structs: list[str] = []
+    compiled_args: list[str] = []
+    apply_with_saved_before: list[str] = []
+    apply_with_saved_after: list[str] = []

     for arg in info.args_with_derivatives:
         if arg.type in TENSOR_LIST_LIKE_CTYPES:

@@ -807,7 +811,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
     else:
         will_release_variables = ""

-    body: List[str] = []
+    body: list[str] = []

     if uses_single_grad(info):
         body.append("const auto& grad = grads[0];")

@@ -821,7 +825,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
     def emit_derivative(
         derivative: Derivative,
         args_with_derivatives: Sequence[Binding],
-    ) -> Tuple[bool, str]:
+    ) -> tuple[bool, str]:
         formula = derivative.formula
         var_names = derivative.var_names
         if len(var_names) == 1:

@@ -857,7 +861,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
         else:
             grad_input_mask = ""
         idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
-        copy_ranges: List[str] = []
+        copy_ranges: list[str] = []
         for i, n in enumerate(var_names):
             copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
         return False, DERIVATIVE_MULTI.substitute(
@@ -4,7 +4,7 @@
 # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
 # The fallback is expected to mimick this codegen, so we should keep the two in sync.

-from typing import Dict, List, Optional, Tuple
+from __future__ import annotations

 from torchgen.api import cpp
 from torchgen.api.autograd import (

@@ -24,8 +24,7 @@ from torchgen.api.types import (
     OptionalCType,
     symIntArrayRefT,
     SymIntT,
-    # See Note [Nested Arg Types]
-    tensorT,
+    tensorT,  # See Note [Nested Arg Types]
 )
 from torchgen.code_template import CodeTemplate
 from torchgen.context import with_native_function

@@ -46,6 +45,7 @@ from .gen_trace_type import (
     type_wrapper_name,
 )

+
 # See NOTE [ Autograd View Variables ] in variable.h for details.
 # If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
 # you **MUST** also update the public list of view ops accordingly in

@@ -281,7 +281,7 @@ def inverse_view_name(f: NativeFunction) -> str:
     return f"{copy_variant}{overload}_inverse"


-def extract_bindings(f: NativeFunction) -> List[Binding]:
+def extract_bindings(f: NativeFunction) -> list[Binding]:
     return [
         r
         for a in f.func.schema_order_arguments()

@@ -297,9 +297,9 @@ def extract_bindings(f: NativeFunction) -> List[Binding]:


 @with_native_function
-def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
-    body: List[str] = []
-    unpacked_bindings: List[Binding] = []
+def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
+    body: list[str] = []
+    unpacked_bindings: list[Binding] = []

     for i, binding in enumerate(extract_bindings(f)):
         assert not isinstance(binding.argument, SelfArgument)

@@ -338,7 +338,7 @@ def get_base_name(f: NativeFunction) -> str:
     return f.func.name.name.base  # TODO: should be str(f.func.name.name)?


-def get_view_info(f: NativeFunction) -> Optional[str]:
+def get_view_info(f: NativeFunction) -> str | None:
     base_name = get_base_name(f)
     view_info = VIEW_FUNCTIONS.get(base_name, None)
     if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:

@@ -347,7 +347,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]:


 def emit_view_func(
-    f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None
+    f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
 ) -> str:
     """Generate an additional lambda function to recover views in backward when as_strided is not supported.
     See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.

@@ -355,8 +355,8 @@ def emit_view_func(
     # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
     input_base = "input_base"
     replay_view_func = ""
-    updated_args: List[str] = []
-    known_view_arg_simple_types: List[CType] = [
+    updated_args: list[str] = []
+    known_view_arg_simple_types: list[CType] = [
         BaseCType(longT),
         OptionalCType(BaseCType(longT)),
         BaseCType(SymIntT),

@@ -448,7 +448,7 @@ def emit_view_func(

 def emit_view_body(
     fn: NativeFunctionWithDifferentiabilityInfo, var: str
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     # See NOTE [ Autograd View Variables ] in variable.h for details.
     f = fn.func
     base_name = get_base_name(f)

@@ -523,9 +523,9 @@ def modifies_arguments(f: NativeFunction) -> bool:


 @with_native_function_with_differentiability_info
-def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
+def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
     f = fn.func
-    inplace_view_body: List[str] = []
+    inplace_view_body: list[str] = []

     dispatcher_sig = DispatcherSignature.from_schema(f.func)
     dispatcher_exprs = dispatcher_sig.exprs()

@@ -584,7 +584,7 @@ def gen_formals(f: NativeFunction) -> str:
 @with_native_function_with_differentiability_info
 def inplace_or_view_method_definition(
     fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Optional[str]:
+) -> str | None:
     f = fn.func
     if get_view_info(f) is None and (
         # For functions that modify their inputs but don't return them,

@@ -605,7 +605,7 @@ def inplace_or_view_method_definition(
 @with_native_function_with_differentiability_info
 def inplace_or_view_method_registration(
     fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Optional[str]:
+) -> str | None:
     f = fn.func
     if get_view_info(f) is None and (
         not modifies_arguments(f) or len(f.func.returns) == 0

@@ -626,7 +626,7 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:

 def gen_inplace_or_view_type_env(
     fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Dict[str, List[str]]:
+) -> dict[str, list[str]]:
     definition = inplace_or_view_method_definition(fn)
     registration = inplace_or_view_method_registration(fn)

@@ -649,7 +649,7 @@ def gen_inplace_or_view_type(
     out: str,
     native_yaml_path: str,
     tags_yaml_path: str,
-    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
+    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
     template_path: str,
 ) -> None:
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
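This file also swaps `Optional[str]` for `str | None`. With postponed annotations the `|` is never executed, so the new spelling is legal as an annotation even on Python 3.8/3.9, where `str | None` as an ordinary expression raises `TypeError`. A minimal sketch with a hypothetical lookup:

```python
from __future__ import annotations


def view_info(base_name: str) -> str | None:
    # Equivalent to Optional[str] for type checkers.  Stored as the
    # string "str | None"; evaluating it eagerly would fail before
    # Python 3.10, but under PEP 563 it never is.
    views = {"view": "self", "expand": "self"}
    return views.get(base_name)
```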
@@ -31,11 +31,12 @@
 # message, but use what's there
 #

+from __future__ import annotations
+
 import itertools
 import re
 from collections import defaultdict
-
-from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
+from typing import Callable, Iterable, Sequence

 import yaml

@@ -56,7 +57,6 @@ from torchgen.api.python import (
     signature_from_schema,
     structseq_fieldnames,
 )
-
 from torchgen.code_template import CodeTemplate
 from torchgen.context import with_native_function
 from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml

@@ -75,6 +75,7 @@ from torchgen.yaml_utils import YamlLoader
 from .gen_inplace_or_view_type import is_tensor_list_type
 from .gen_trace_type import should_trace

+
 #
 # declarations blocklist
 # We skip codegen for these functions, for various reasons.

@@ -369,7 +370,7 @@ def gen(

     valid_tags = parse_tags_yaml(tags_yaml_path)

-    def gen_tags_enum() -> Dict[str, str]:
+    def gen_tags_enum() -> dict[str, str]:
         return {
             "enum_of_valid_tags": (
                 "".join(

@@ -384,9 +385,9 @@ def gen(
 def group_filter_overloads(
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     pred: Callable[[NativeFunction], bool],
-) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
-    grouped: Dict[
-        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
+) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
+    grouped: dict[
+        BaseOperatorName, list[PythonSignatureNativeFunctionPair]
     ] = defaultdict(list)
     for pair in pairs:
         if pred(pair.function):

@@ -398,17 +399,17 @@ def create_python_bindings(
     fm: FileManager,
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     pred: Callable[[NativeFunction], bool],
-    module: Optional[str],
+    module: str | None,
     filename: str,
     *,
     method: bool,
     symint: bool = True,
 ) -> None:
     """Generates Python bindings to ATen functions"""
-    py_methods: List[str] = []
-    ops_headers: List[str] = []
-    py_method_defs: List[str] = []
-    py_forwards: List[str] = []
+    py_methods: list[str] = []
+    ops_headers: list[str] = []
+    py_method_defs: list[str] = []
+    py_forwards: list[str] = []

     grouped = group_filter_overloads(pairs, pred)

@@ -445,8 +446,8 @@ def create_python_return_type_bindings(
     Generate function to initialize and return named tuple for native functions
     which returns named tuple and registration invocations in `python_return_types.cpp`.
     """
-    py_return_types_definition: List[str] = []
-    py_return_types_registrations: List[str] = []
+    py_return_types_definition: list[str] = []
+    py_return_types_registrations: list[str] = []

     grouped = group_filter_overloads(pairs, pred)

@@ -484,7 +485,7 @@ def create_python_return_type_bindings_header(
     Generate function to initialize and return named tuple for native functions
     which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
     """
-    py_return_types_declarations: List[str] = []
+    py_return_types_declarations: list[str] = []

     grouped = group_filter_overloads(pairs, pred)

@@ -510,7 +511,7 @@ def create_python_bindings_sharded(
     fm: FileManager,
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     pred: Callable[[NativeFunction], bool],
-    module: Optional[str],
+    module: str | None,
     filename: str,
     *,
     method: bool,

@@ -521,13 +522,13 @@ def create_python_bindings_sharded(
     grouped = group_filter_overloads(pairs, pred)

     def key_func(
-        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
     ) -> str:
         return kv[0].base

     def env_func(
-        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
-    ) -> Dict[str, List[str]]:
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
+    ) -> dict[str, list[str]]:
         name, fn_pairs = kv
         return {
             "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],

@@ -553,7 +554,7 @@ def create_python_bindings_sharded(


 def load_signatures(
-    native_functions: List[NativeFunction],
+    native_functions: list[NativeFunction],
     deprecated_yaml_path: str,
     *,
     method: bool,

@@ -580,19 +581,19 @@ def load_deprecated_signatures(
     *,
     method: bool,
     pyi: bool,
-) -> List[PythonSignatureNativeFunctionPair]:
+) -> list[PythonSignatureNativeFunctionPair]:
     # The deprecated.yaml doesn't have complete type information, we need
     # find and leverage the original ATen signature (to which it delegates
     # the call) to generate the full python signature.
     # We join the deprecated and the original signatures using type-only form.

     # group the original ATen signatures by name
-    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
+    grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
     for pair in pairs:
         grouped[pair.signature.name].append(pair)

     # find matching original signatures for each deprecated signature
-    results: List[PythonSignatureNativeFunctionPair] = []
+    results: list[PythonSignatureNativeFunctionPair] = []

     with open(deprecated_yaml_path) as f:
         deprecated_defs = yaml.load(f, Loader=YamlLoader)

@@ -701,15 +702,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:

 def emit_structseq_call(
     overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> Tuple[List[str], Dict[str, str]]:
+) -> tuple[list[str], dict[str, str]]:
     """
     Generate block of named tuple type def inits, and add typeref snippets
     to declarations that use them
     """
-    typenames: Dict[
+    typenames: dict[
         str, str
     ] = {}  # map from unique name + field name lists to typedef name
-    typedefs: List[str] = []  # typedef declarations and init code
+    typedefs: list[str] = []  # typedef declarations and init code

     for overload in overloads:
         fieldnames = structseq_fieldnames(overload.function.func.returns)

@@ -732,17 +733,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""

 def generate_return_type_definition_and_registrations(
     overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     """
     Generate block of function in `python_return_types.cpp` to initialize
     and return named tuple for a native function which returns named tuple
     and registration invocations in same file.
     """
-    typenames: Dict[
+    typenames: dict[
         str, str
     ] = {}  # map from unique name + field name lists to typedef name
-    definitions: List[str] = []  # function definition to register the typedef
-    registrations: List[str] = []  # register call for the typedef
+    definitions: list[str] = []  # function definition to register the typedef
+    registrations: list[str] = []  # register call for the typedef

     for overload in overloads:
         fieldnames = structseq_fieldnames(overload.function.func.returns)

@@ -783,15 +784,15 @@ PyTypeObject* get_{name}_structseq() {{

 def generate_return_type_declarations(
     overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> List[str]:
+) -> list[str]:
     """
     Generate block of function declarations in `python_return_types.h` to initialize
     and return named tuple for a native function.
     """
-    typenames: Dict[
+    typenames: dict[
         str, str
     ] = {}  # map from unique name + field name lists to typedef name
-    declarations: List[str] = []  # function declaration to register the typedef
+    declarations: list[str] = []  # function declaration to register the typedef

     for overload in overloads:
         fieldnames = structseq_fieldnames(overload.function.func.returns)

@@ -891,7 +892,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)

 def method_impl(
     name: BaseOperatorName,
-    module: Optional[str],
+    module: str | None,
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
     method: bool,

@@ -918,8 +919,8 @@ def method_impl(
         overloads, symint=symint
     )
     is_singleton = len(grouped_overloads) == 1
-    signatures: List[str] = []
-    dispatch: List[str] = []
+    signatures: list[str] = []
+    dispatch: list[str] = []
     for overload_index, overload in enumerate(grouped_overloads):
         signature = overload.signature.signature_str(symint=symint)
         signatures.append(f"{cpp_string(str(signature))},")

@@ -959,7 +960,7 @@ def method_impl(


 def gen_has_torch_function_check(
-    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
+    name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
 ) -> str:
     if noarg:
         if method:

@@ -1007,7 +1008,7 @@ if (_r.isNone(${out_idx})) {

 def emit_dispatch_case(
     overload: PythonSignatureGroup,
-    structseq_typenames: Dict[str, str],
+    structseq_typenames: dict[str, str],
     *,
     symint: bool = True,
 ) -> str:

@@ -1050,7 +1051,7 @@ def forward_decls(
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
     method: bool,
-) -> Tuple[str, ...]:
+) -> tuple[str, ...]:
     if method:
         return ()

@@ -1078,7 +1079,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);

 def method_def(
     name: BaseOperatorName,
-    module: Optional[str],
+    module: str | None,
     overloads: Sequence[PythonSignatureNativeFunctionPair],
     *,
     method: bool,

@@ -1114,8 +1115,8 @@ def method_def(
 def group_overloads(
     overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
 ) -> Sequence[PythonSignatureGroup]:
-    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
-    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
+    bases: dict[str, PythonSignatureNativeFunctionPair] = {}
+    outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}

     # first group by signature ignoring out arguments
     for overload in overloads:

@@ -1137,7 +1138,7 @@ def group_overloads(

     for sig, out in outplaces.items():
         if sig not in bases:
-            candidates: List[str] = []
+            candidates: list[str] = []
             for overload in overloads:
                 if (
                     str(overload.function.func.name.name)

@@ -1268,7 +1269,7 @@ def sort_overloads(
     )

     # Construct the relation graph
-    larger_than: Dict[int, Set[int]] = defaultdict(set)
+    larger_than: dict[int, set[int]] = defaultdict(set)
     for i1, overload1 in enumerate(grouped_overloads):
         for i2, overload2 in enumerate(grouped_overloads):
             if is_smaller(overload1.signature, overload2.signature):

@@ -1279,7 +1280,7 @@ def sort_overloads(

     # Use a topological sort to sort overloads according to the partial order.
     N = len(grouped_overloads)
-    sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))
+    sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))

     for idx in range(N):
         # The size of sorted_ids will grow to N eventually.

@@ -1304,7 +1305,7 @@ def sort_overloads(
 def emit_single_dispatch(
     ps: PythonSignature,
     f: NativeFunction,
-    structseq_typenames: Dict[str, str],
+    structseq_typenames: dict[str, str],
     *,
     symint: bool = True,
 ) -> str:
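One trade-off of the future import worth keeping in mind: anything that introspects annotations at runtime now sees strings and must resolve them explicitly, for example with `typing.get_type_hints` (sketch assumes Python 3.9+ so that `list[str]` also evaluates):

```python
from __future__ import annotations

import typing


def group(names: list[str]) -> dict[str, int]:
    return {n: len(n) for n in names}


# Raw annotations are strings under PEP 563 ...
assert group.__annotations__["names"] == "list[str]"

# ... so runtime consumers re-evaluate them in the defining module's
# namespace to get real type objects back.
hints = typing.get_type_hints(group)
assert hints["names"] == list[str]
assert hints["return"] == dict[str, int]
```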
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import itertools
-from typing import Dict, List, Sequence, Union
+from typing import Sequence

 from torchgen.api import cpp
 from torchgen.api.types import DispatcherSignature

@@ -8,6 +10,7 @@ from torchgen.context import with_native_function
 from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
 from torchgen.utils import FileManager

+
 # Note [Manual Backend kernels]
 # For these ops, we want to manually register to dispatch key Backend and
 # skip codegen-ed registeration to all keys before Backend.

@@ -136,9 +139,7 @@ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${inpu


 def format_trace_inputs(f: NativeFunction) -> str:
-    def dispatch_trace_input(
-        arg: Union[Argument, TensorOptionsArguments]
-    ) -> Sequence[str]:
+    def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
         if isinstance(arg, TensorOptionsArguments):
             name = "options"
             return [

@@ -156,7 +157,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
         else:
             return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

-    args: List[Union[Argument, TensorOptionsArguments]] = list(
+    args: list[Argument | TensorOptionsArguments] = list(
         f.func.schema_order_arguments()
     )

@@ -399,8 +400,8 @@ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args
 )


-def emit_trace_body(f: NativeFunction) -> List[str]:
-    trace_body: List[str] = []
+def emit_trace_body(f: NativeFunction) -> list[str]:
+    trace_body: list[str] = []

     trace_body.append(format_prerecord_trace(f))

@@ -503,7 +504,7 @@ def method_registration(f: NativeFunction) -> str:
     )


-def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
+def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
     return {
         "ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"],
         "trace_method_definitions": [method_definition(fn)],

@@ -512,7 +513,7 @@ def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:


 def gen_trace_type(
-    out: str, native_functions: List[NativeFunction], template_path: str
+    out: str, native_functions: list[NativeFunction], template_path: str
 ) -> None:
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
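Note the limit of the rewrite: PEP 563 defers annotations only. Runtime expressions are untouched, which is why the code above can annotate with `Argument | TensorOptionsArguments` while its `isinstance(arg, TensorOptionsArguments)` check still needs the real class. A sketch of the distinction (class name hypothetical):

```python
from __future__ import annotations


class Options:  # stands in for TensorOptionsArguments
    pass


def dispatch(arg: int | Options) -> list[str]:
    # The annotation "int | Options" is deferred text.  The isinstance()
    # below is an ordinary runtime expression: Options must really exist,
    # and writing isinstance(arg, int | Options) would itself be a
    # TypeError before Python 3.10.
    if isinstance(arg, Options):
        return ["options"]
    return [str(arg)]


print(dispatch(3), dispatch(Options()))  # ['3'] ['options']
```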
@@ -2,18 +2,19 @@
 #
 # This writes one file: variable_factories.h

+from __future__ import annotations
+
 import re
-from typing import List, Optional

 import torchgen.api.python as python
 from torchgen.api import cpp

 from torchgen.api.types import CppSignatureGroup
 from torchgen.context import with_native_function
 from torchgen.gen import parse_native_yaml
 from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
 from torchgen.utils import FileManager, mapMaybe


 OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
 TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")

@@ -69,7 +70,7 @@ def is_factory_function(f: NativeFunction) -> bool:


 @with_native_function
-def process_function(f: NativeFunction) -> Optional[str]:
+def process_function(f: NativeFunction) -> str | None:
     name = cpp.name(f.func)
     has_tensor_options = python.has_tensor_options(f)
     is_factory = has_tensor_options or name.endswith("_like")

@@ -83,8 +84,8 @@ def process_function(f: NativeFunction) -> Optional[str]:
         sigs.append(cpp_sigs.symint_signature)
     r = ""
     for sig in sigs:
-        formals: List[str] = []
-        exprs: List[str] = []
+        formals: list[str] = []
+        exprs: list[str] = []
         requires_grad = "false"
         for arg in sig.arguments():
             qualified_type = fully_qualified_type(arg.type)
@@ -25,8 +25,11 @@
 # which will in turn dispatch back to VariableType for its
 # differentiable subcomponents.
 #
+
+from __future__ import annotations
+
 import re
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Sequence

 from torchgen.api import cpp
 from torchgen.api.autograd import (

@@ -38,7 +41,6 @@ from torchgen.api.autograd import (
     NativeFunctionWithDifferentiabilityInfo,
     SavedAttribute,
 )
-
 from torchgen.api.types import (
     ArrayRefCType,
     BaseCppType,

@@ -103,6 +105,7 @@ from .gen_trace_type import (
     type_wrapper_name,
 )

+
 # We don't set or modify grad_fn on these methods. Generally, they return
 # tensors that have requires_grad=False. In-place functions listed here will
 # not examine or modify requires_grad or grad_fn.

@@ -837,9 +840,9 @@ def gen_variable_type(
     out: str,
     native_yaml_path: str,
     tags_yaml_path: str,
-    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
+    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
     template_path: str,
-    used_keys: Set[str],
+    used_keys: set[str],
 ) -> None:
     """VariableType.h and VariableType.cpp body

@@ -858,8 +861,8 @@ def gen_variable_type(

     # helper that generates a TORCH_LIBRARY_IMPL macro for each
     # dispatch key that appears in derivatives.yaml
-    def wrapper_registrations(used_keys: Set[str]) -> str:
-        library_impl_macro_list: List[str] = []
+    def wrapper_registrations(used_keys: set[str]) -> str:
+        library_impl_macro_list: list[str] = []
         for key in sorted(used_keys):
             dispatch_key = key
             if key == "Default":

@@ -926,7 +929,7 @@ def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:

 def gen_variable_type_func(
     fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Dict[str, List[str]]:
+) -> dict[str, list[str]]:
     f = fn.func
     result = {}
     with native_function_manager(f):

@@ -1034,7 +1037,7 @@ _foreach_ops_with_different_arity = {
 @with_native_function_with_differentiability_info_and_key
 def emit_body(
     fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> List[str]:
+) -> list[str]:
     assert dispatch_strategy(fn) == "use_derived"
     f = fn.func
     info = fn.info[key] if fn.info else None

@@ -1050,8 +1053,8 @@ def emit_body(
     is_foreach = name.startswith("_foreach")
     is_inplace_foreach = is_foreach and inplace
     if is_inplace_foreach:
-        inplace_foreacharg2refarg: Dict[Argument, Argument] = {}
-        refargname2inplace_foreacharg: Dict[str, Argument] = {}
+        inplace_foreacharg2refarg: dict[Argument, Argument] = {}
+        refargname2inplace_foreacharg: dict[str, Argument] = {}
         base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
         if info is None:
             assert (

@@ -1077,8 +1080,8 @@ def emit_body(
                 refargname2inplace_foreacharg[ref_arg.name] = foreach_arg

     def gen_differentiable_input(
-        arg: Union[Argument, SelfArgument, TensorOptionsArguments]
-    ) -> Optional[DifferentiableInput]:
+        arg: Argument | SelfArgument | TensorOptionsArguments,
+    ) -> DifferentiableInput | None:
         if isinstance(arg, TensorOptionsArguments):
             return None
         a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg

@@ -1097,7 +1100,7 @@ def emit_body(
         )

     @with_native_function
-    def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
+    def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
         arguments = list(f.func.arguments.non_out)
         if is_inplace_foreach and info is not None:
             for i, arg in enumerate(f.func.arguments.flat_non_out):

@@ -1115,8 +1118,8 @@ def emit_body(
         return list(mapMaybe(gen_differentiable_input, arguments))

     def find_args_with_derivatives(
-        differentiable_inputs: List[DifferentiableInput],
-    ) -> List[DifferentiableInput]:
+        differentiable_inputs: list[DifferentiableInput],
+    ) -> list[DifferentiableInput]:
         """Find arguments that have derivative definitions"""
         if info is None or not info.has_derivatives:
             return differentiable_inputs

@@ -1178,8 +1181,8 @@ def emit_body(
         and (not returns_void)
     )

-    def emit_save_inputs() -> List[str]:
-        setup: List[str] = []
+    def emit_save_inputs() -> list[str]:
+        setup: list[str] = []
         if info is None or not info.has_derivatives:
             return setup

@@ -1189,7 +1192,7 @@ def emit_body(

         # We don't want to save tensors if we know that they will never be used
         # when computing the derivative, so we add guards to those statements
-        def guard_for(arg: SavedAttribute) -> Optional[str]:
+        def guard_for(arg: SavedAttribute) -> str | None:
             assert info is not None
|
|
||||||
# It's hard to determine the edge offset if we have TensorLists
|
# It's hard to determine the edge offset if we have TensorLists
|
||||||
|
|
@ -1276,8 +1279,8 @@ def emit_body(
|
||||||
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
|
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
|
||||||
return setup
|
return setup
|
||||||
|
|
||||||
def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
|
def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
|
||||||
body: List[str] = []
|
body: list[str] = []
|
||||||
if is_out_fn:
|
if is_out_fn:
|
||||||
# For out functions, ensure that no input or output requires grad
|
# For out functions, ensure that no input or output requires grad
|
||||||
body.append(DECLARE_GRAD_FN.substitute(op="Node"))
|
body.append(DECLARE_GRAD_FN.substitute(op="Node"))
|
||||||
|
|
@ -1343,8 +1346,8 @@ def emit_body(
|
||||||
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
|
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
|
||||||
return body
|
return body
|
||||||
|
|
||||||
def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
|
def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
|
||||||
body: List[str] = []
|
body: list[str] = []
|
||||||
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
|
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
|
||||||
return body
|
return body
|
||||||
for arg in differentiable_outputs:
|
for arg in differentiable_outputs:
|
||||||
|
|
@ -1355,11 +1358,11 @@ def emit_body(
|
||||||
return body
|
return body
|
||||||
|
|
||||||
def emit_check_no_requires_grad(
|
def emit_check_no_requires_grad(
|
||||||
tensor_args: List[DifferentiableInput],
|
tensor_args: list[DifferentiableInput],
|
||||||
args_with_derivatives: List[DifferentiableInput],
|
args_with_derivatives: list[DifferentiableInput],
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""Checks that arguments without derivatives don't require grad"""
|
"""Checks that arguments without derivatives don't require grad"""
|
||||||
body: List[str] = []
|
body: list[str] = []
|
||||||
for arg in tensor_args:
|
for arg in tensor_args:
|
||||||
if arg in args_with_derivatives:
|
if arg in args_with_derivatives:
|
||||||
continue
|
continue
|
||||||
|
|
@ -1373,8 +1376,8 @@ def emit_body(
|
||||||
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
|
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
|
||||||
return body
|
return body
|
||||||
|
|
||||||
def emit_original_self_definition() -> List[str]:
|
def emit_original_self_definition() -> list[str]:
|
||||||
body: List[str] = []
|
body: list[str] = []
|
||||||
if inplace:
|
if inplace:
|
||||||
if is_inplace_foreach:
|
if is_inplace_foreach:
|
||||||
body.append(
|
body.append(
|
||||||
|
|
@ -1412,17 +1415,17 @@ def emit_body(
|
||||||
def save_variables(
|
def save_variables(
|
||||||
saved_variables: Sequence[SavedAttribute],
|
saved_variables: Sequence[SavedAttribute],
|
||||||
is_output: bool,
|
is_output: bool,
|
||||||
guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
|
guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
|
||||||
) -> Sequence[str]:
|
) -> Sequence[str]:
|
||||||
# assign the saved variables to the generated grad_fn
|
# assign the saved variables to the generated grad_fn
|
||||||
stmts: List[str] = []
|
stmts: list[str] = []
|
||||||
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
|
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
|
||||||
name = (
|
name = (
|
||||||
arg.nctype.name.name
|
arg.nctype.name.name
|
||||||
if isinstance(arg.nctype.name, SpecialArgName)
|
if isinstance(arg.nctype.name, SpecialArgName)
|
||||||
else arg.nctype.name
|
else arg.nctype.name
|
||||||
)
|
)
|
||||||
foreacharg: Optional[Argument] = None
|
foreacharg: Argument | None = None
|
||||||
is_foreacharg_list_type: bool = False
|
is_foreacharg_list_type: bool = False
|
||||||
type = arg.nctype.type
|
type = arg.nctype.type
|
||||||
expr = arg.expr
|
expr = arg.expr
|
||||||
|
|
@ -1539,10 +1542,10 @@ def emit_body(
|
||||||
return call
|
return call
|
||||||
|
|
||||||
def wrap_output(
|
def wrap_output(
|
||||||
f: NativeFunction, unpacked_bindings: List[Binding], var: str
|
f: NativeFunction, unpacked_bindings: list[Binding], var: str
|
||||||
) -> str:
|
) -> str:
|
||||||
call = ""
|
call = ""
|
||||||
rhs_value: Optional[str] = None
|
rhs_value: str | None = None
|
||||||
if not any(r.type.is_tensor_like() for r in f.func.returns):
|
if not any(r.type.is_tensor_like() for r in f.func.returns):
|
||||||
rhs_value = var
|
rhs_value = var
|
||||||
else:
|
else:
|
||||||
|
|
@ -1554,11 +1557,11 @@ def emit_body(
|
||||||
return call
|
return call
|
||||||
|
|
||||||
def check_tensorimpl_and_storage(
|
def check_tensorimpl_and_storage(
|
||||||
call: str, unpacked_bindings: List[Binding]
|
call: str, unpacked_bindings: list[Binding]
|
||||||
) -> str:
|
) -> str:
|
||||||
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
|
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
|
||||||
stmts_before_call: List[str] = []
|
stmts_before_call: list[str] = []
|
||||||
stmts_after_call: List[str] = []
|
stmts_after_call: list[str] = []
|
||||||
|
|
||||||
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
|
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
|
||||||
return call
|
return call
|
||||||
|
|
@ -1665,7 +1668,7 @@ def emit_body(
|
||||||
return call
|
return call
|
||||||
|
|
||||||
def emit_call(
|
def emit_call(
|
||||||
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
|
f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
|
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
|
||||||
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
|
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
|
||||||
|
|
@ -1764,7 +1767,7 @@ def emit_body(
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def emit_any_requires_grad() -> List[str]:
|
def emit_any_requires_grad() -> list[str]:
|
||||||
extra_condition = ""
|
extra_condition = ""
|
||||||
if info and info.output_differentiability_conditions:
|
if info and info.output_differentiability_conditions:
|
||||||
assert len(info.output_differentiability_conditions) == 1
|
assert len(info.output_differentiability_conditions) == 1
|
||||||
|
|
@ -1782,14 +1785,14 @@ def emit_body(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
|
def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
|
||||||
if len(var_names) == 1:
|
if len(var_names) == 1:
|
||||||
return f"_any_has_forward_grad_{var_names[0]}"
|
return f"_any_has_forward_grad_{var_names[0]}"
|
||||||
else:
|
else:
|
||||||
return f'_any_has_forward_grad_{"_".join(var_names)}'
|
return f'_any_has_forward_grad_{"_".join(var_names)}'
|
||||||
|
|
||||||
def emit_any_has_forward_grad() -> List[str]:
|
def emit_any_has_forward_grad() -> list[str]:
|
||||||
content: List[str] = []
|
content: list[str] = []
|
||||||
if not is_foreach:
|
if not is_foreach:
|
||||||
for derivative in fw_derivatives:
|
for derivative in fw_derivatives:
|
||||||
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
|
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
|
||||||
|
|
@ -1844,7 +1847,7 @@ def emit_body(
|
||||||
content.append("}")
|
content.append("}")
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def emit_check_inplace() -> List[str]:
|
def emit_check_inplace() -> list[str]:
|
||||||
if not inplace:
|
if not inplace:
|
||||||
return []
|
return []
|
||||||
return [
|
return [
|
||||||
|
|
@ -1852,9 +1855,9 @@ def emit_body(
|
||||||
for arg in differentiable_outputs
|
for arg in differentiable_outputs
|
||||||
]
|
]
|
||||||
|
|
||||||
def emit_fw_derivatives() -> List[str]:
|
def emit_fw_derivatives() -> list[str]:
|
||||||
content: List[str] = []
|
content: list[str] = []
|
||||||
fw_grad_setters: List[str] = []
|
fw_grad_setters: list[str] = []
|
||||||
for derivative in fw_derivatives:
|
for derivative in fw_derivatives:
|
||||||
res = derivative.var_names
|
res = derivative.var_names
|
||||||
if f.func.name.name.inplace:
|
if f.func.name.name.inplace:
|
||||||
|
|
@ -2002,7 +2005,7 @@ def emit_body(
|
||||||
"(self.size(), c10::nullopt);"
|
"(self.size(), c10::nullopt);"
|
||||||
)
|
)
|
||||||
foreach_forward_grad_formula = derivative.formula
|
foreach_forward_grad_formula = derivative.formula
|
||||||
_foreach_arg: Union[Argument, DifferentiableInput]
|
_foreach_arg: Argument | DifferentiableInput
|
||||||
if inplace:
|
if inplace:
|
||||||
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
|
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
|
||||||
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
|
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
|
||||||
|
|
@ -2044,7 +2047,7 @@ def emit_body(
|
||||||
content.append("\n".join(fw_grad_setters))
|
content.append("\n".join(fw_grad_setters))
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
|
def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
|
||||||
#
|
#
|
||||||
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
|
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
|
||||||
#
|
#
|
||||||
|
|
@ -2053,7 +2056,7 @@ def emit_body(
|
||||||
# - Used in the out_fn case when we want to forbid fw derivatives
|
# - Used in the out_fn case when we want to forbid fw derivatives
|
||||||
# - Used in the case where the fw_derivative is not defined, but we want
|
# - Used in the case where the fw_derivative is not defined, but we want
|
||||||
# To check if there is a decomposition registered for jvp
|
# To check if there is a decomposition registered for jvp
|
||||||
to_check: List[str] = []
|
to_check: list[str] = []
|
||||||
for inp in list(
|
for inp in list(
|
||||||
mapMaybe(
|
mapMaybe(
|
||||||
gen_differentiable_input,
|
gen_differentiable_input,
|
||||||
|
|
@ -2126,7 +2129,7 @@ def emit_body(
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
|
|
||||||
body: List[str] = []
|
body: list[str] = []
|
||||||
unpack_args_stats, unpacked_bindings = unpack_args(f)
|
unpack_args_stats, unpacked_bindings = unpack_args(f)
|
||||||
|
|
||||||
body.extend(unpack_args_stats)
|
body.extend(unpack_args_stats)
|
||||||
|
|
|
||||||
|
|
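The pattern above repeats throughout the PR: with `from __future__ import annotations` (PEP 563) at the top of a module, annotations are stored as strings instead of being evaluated at import time, so builtin generics (`list[str]`, PEP 585) and union syntax (`str | None`, PEP 604) can be used in annotations even on Python 3.8. A minimal sketch of the idea, with hypothetical names that are not part of this PR:

from __future__ import annotations  # annotations below are never evaluated at runtime


def first_or_none(items: list[str]) -> str | None:
    # On Python 3.8 this defines and runs fine: `list[str]` and
    # `str | None` are kept as plain strings, never evaluated.
    return items[0] if items else None


print(first_or_none(["a", "b"]))      # a
print(first_or_none.__annotations__)  # {'items': 'list[str]', 'return': 'str | None'}
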
@@ -4,10 +4,11 @@
 # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
 # The fallback is expected to mimic this codegen, so we should keep the two in sync.
 
-from typing import List, Tuple
+from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torchgen.api.dispatcher as dispatcher
-from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
 from torchgen.api.translate import translate
 from torchgen.api.types import (
     BaseCType,
@@ -29,6 +30,11 @@ from .gen_inplace_or_view_type import (
     use_derived,
 )
 
+
+if TYPE_CHECKING:
+    from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
+
+
 FUNCTION_DECLARATION = CodeTemplate(
     """\
 #define ${uppercase_op}_AVAILABLE
@@ -155,9 +161,9 @@ def returns_multi_tensor(fn: NativeFunction) -> bool:
 # tuple: (list of getter logic strings, list of setter logic strings, string
 # with num items expression)
 def generate_state_getter_setter(
-    bindings: List[Binding],
+    bindings: list[Binding],
     state_vec_type: NamedCType,
-) -> Tuple[List[str], List[str], str]:
+) -> tuple[list[str], list[str], str]:
     getter_logic = []
     setter_logic = []
 
@@ -302,7 +308,7 @@ def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
 
 def gen_view_funcs(
     out: str,
-    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
+    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
     template_path: str,
 ) -> None:
     # don't need the info parts, just the function
 
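This file also moves an import that is only needed for annotations under `typing.TYPE_CHECKING`. Because the guarded name then exists only for the type checker, the idiom depends on the postponed annotations enabled above. A sketch with a placeholder module name:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; never at runtime, which also
    # breaks import cycles and trims import-time cost.
    from some_heavy_module import HeavyType  # hypothetical module


def describe(x: HeavyType) -> str:  # fine: the annotation is just a string
    return repr(x)
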
@@ -2,14 +2,16 @@
 #
 # Each autograd function is represented by `DifferentiabilityInfo` containing
 # a list of `Derivative`. See `torchgen.api.autograd` for the data models.
+
+from __future__ import annotations
+
 import re
 from collections import defaultdict
-from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple
+from typing import Any, Counter, Dict, Sequence, Set, Tuple
 
 import yaml
 
 from torchgen.api import cpp
-
 from torchgen.api.autograd import (
     Derivative,
     DifferentiabilityInfo,
@@ -50,9 +52,10 @@ from torchgen.model import (
 from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
 from torchgen.yaml_utils import YamlLoader
 
+
 DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]
 
-_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {}
+_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
 
 _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
 
@@ -62,11 +65,11 @@ _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
 # we generate them here instead of duplicating them in the yaml.
 # See Note [Codegen'd {view}_copy Operators]
 def add_view_copy_derivatives(
-    infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
-    view_groups: List[NativeFunctionsViewGroup],
+    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+    view_groups: list[NativeFunctionsViewGroup],
 ) -> None:
     # Get the map from each view op's name to its corresponding view group
-    view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
+    view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
         g.view.func.name: g for g in view_groups
     }
 
@@ -125,10 +128,10 @@ def load_derivatives(
     # function schema is the complete declaration including mutability annotation / default value and etc.
     # signature is the canonical schema for a group of functions (in-place/out/functional variants)
     # that are semantically related.
-    functions_by_signature: Dict[
-        FunctionSchema, List[NativeFunction]
+    functions_by_signature: dict[
+        FunctionSchema, list[NativeFunction]
     ] = defaultdict(list)
-    functions_by_schema: Dict[str, NativeFunction] = {}
+    functions_by_schema: dict[str, NativeFunction] = {}
     for function in native_functions:
         functions_by_signature[function.func.signature()].append(function)
         assert str(function.func) not in functions_by_schema
@@ -141,8 +144,8 @@ def load_derivatives(
     # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
     # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
     # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
-    infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
-    used_dispatch_keys: Set[str] = set()
+    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
+    used_dispatch_keys: set[str] = set()
     for defn_dict in definitions:
         # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
         if "dispatch" not in defn_dict:
@@ -185,11 +188,11 @@ def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
 def create_derivative(
     f: NativeFunction,
     formula: str,
-    var_names: Tuple[str, ...],
+    var_names: tuple[str, ...],
     available_named_gradients: Sequence[str],
 ) -> Derivative:
     original_formula = formula
-    arguments: List[NamedCType] = [
+    arguments: list[NamedCType] = [
         a.nctype.remove_const_ref() for a in cpp_arguments(f)
     ]
 
@@ -230,10 +233,10 @@ def create_derivative(
 
 
 def create_forward_derivative(
-    f: NativeFunction, formula: str, names: Tuple[str, ...]
+    f: NativeFunction, formula: str, names: tuple[str, ...]
 ) -> ForwardDerivative:
     var_names = names
-    var_types: Optional[Tuple[Type, ...]] = None
+    var_types: tuple[Type, ...] | None = None
     for r in f.func.returns:
         if r.name in var_names:
             if var_types is None:
@@ -269,12 +272,12 @@ def create_forward_derivative(
 def postprocess_forward_derivatives(
     f: NativeFunction,
     defn_name: str,
-    all_arg_names: List[str],
-    derivatives: List[Derivative],
-    forward_derivatives: List[ForwardDerivative],
+    all_arg_names: list[str],
+    derivatives: list[Derivative],
+    forward_derivatives: list[ForwardDerivative],
     args_with_derivatives: Sequence[Binding],
-) -> List[ForwardDerivative]:
-    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
+) -> list[ForwardDerivative]:
+    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
         is_foreach = f.func.name.name.base.startswith("_foreach_")
         required_inputs = set()
         for arg in args_with_derivatives:
@@ -300,7 +303,7 @@ def postprocess_forward_derivatives(
 
         return tuple(required_inputs)
 
-    updated_derivatives: List[ForwardDerivative] = []
+    updated_derivatives: list[ForwardDerivative] = []
 
     for defn in forward_derivatives:
         formula = defn.formula
@@ -430,7 +433,7 @@ def postprocess_forward_derivatives(
 
 
 def is_forward_derivative_definition(
-    all_arg_names: List[str], names: Tuple[str, ...]
+    all_arg_names: list[str], names: tuple[str, ...]
 ) -> bool:
     for name in names:
         if name not in all_arg_names:
@@ -441,12 +444,12 @@ def is_forward_derivative_definition(
 
 
 def create_differentiability_info(
-    defn_dict: Dict[Any, Any],
-    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
-    functions_by_schema: Dict[str, NativeFunction],
+    defn_dict: dict[Any, Any],
+    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
+    functions_by_schema: dict[str, NativeFunction],
     op_counter: Counter[str],
-    used_dispatch_keys: Set[str],
-) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]:
+    used_dispatch_keys: set[str],
+) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
     """Processes a single entry `defn` in derivatives.yaml"""
 
     def canonical_function(
@@ -463,7 +466,7 @@ def create_differentiability_info(
         assert name + "_" == cpp.name(functions[0].func)
         return functions[0]
 
-    def split_names(raw_names: str) -> Tuple[str, ...]:
+    def split_names(raw_names: str) -> tuple[str, ...]:
         """Given "foo, bar", return ["foo", "bar"]."""
         return tuple(x.strip() for x in raw_names.split(","))
 
@@ -477,7 +480,7 @@ def create_differentiability_info(
         uses_grad = False  # true if any derivative uses "grad"
         num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
         uses_named_grads = False  # true if any derivative uses "grad_{name}"
-        used_grads_indices: List[int] = []  # which indices of grads are used
+        used_grads_indices: list[int] = []  # which indices of grads are used
         for d in derivatives:
             formula = d.formula
             uses_grad = uses_grad or bool(
@@ -521,7 +524,7 @@ def create_differentiability_info(
     @with_native_function
     def set_up_derivatives(
         f: NativeFunction,
-    ) -> Tuple[
+    ) -> tuple[
         Sequence[Derivative],
         Sequence[ForwardDerivative],
         Sequence[Binding],
@@ -529,10 +532,10 @@ def create_differentiability_info(
         Sequence[str],
     ]:
         # Set up the derivative information
-        derivatives: List[Derivative] = []
-        forward_derivatives: List[ForwardDerivative] = []
-        non_differentiable_arg_names: List[str] = []
-        args_with_derivatives_set: Set[str] = set()
+        derivatives: list[Derivative] = []
+        forward_derivatives: list[ForwardDerivative] = []
+        non_differentiable_arg_names: list[str] = []
+        args_with_derivatives_set: set[str] = set()
 
         all_arg_names = [a.name for a in cpp_arguments(f)]
         all_ret_names = [
@@ -699,7 +702,7 @@ def create_differentiability_info(
         available_named_gradients,
     ) = set_up_derivatives(canonical)
 
-    used_named_gradients: Set[str] = set()
+    used_named_gradients: set[str] = set()
     for d in derivatives:
         used_named_gradients |= d.named_gradients
 
@@ -738,7 +741,7 @@ def create_differentiability_info(
 GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
 
 
-def used_gradient_indices(formula: str) -> List[int]:
+def used_gradient_indices(formula: str) -> list[int]:
     """Determine a list of gradient indices (the i in grads[i]) that
     are used by the formula.
 
@@ -750,9 +753,9 @@ def used_gradient_indices(formula: str) -> List[int]:
 
 def saved_variables(
     formula: str,
-    nctypes: List[NamedCType],
-    var_names: Tuple[str, ...],
-) -> Tuple[str, Tuple[SavedAttribute, ...]]:
+    nctypes: list[NamedCType],
+    var_names: tuple[str, ...],
+) -> tuple[str, tuple[SavedAttribute, ...]]:
     def stride_expr(name: str) -> str:
         assert var_names == (name,), (
             'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
@@ -760,7 +763,7 @@ def saved_variables(
         )
         return f'strides_or_error({name}, "{name}")'
 
-    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
+    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
         # replace self.sym_sizes() with self_sym_sizes
         (
            r"{}.sym_sizes\(\)",
@@ -914,7 +917,7 @@ def saved_variables(
     ]
 
     # find which arguments need to be saved
-    saved: List[SavedAttribute] = []
+    saved: list[SavedAttribute] = []
 
     if ".sizes()" in formula or "->sizes()" in formula:
         raise RuntimeError(
@@ -941,7 +944,7 @@ def saved_variables(
     # when the autograd Function is created to avoid saving variables
     for regex, info in REPLACEMENTS:
 
-        def repl(m: Match[str]) -> str:
+        def repl(m: re.Match[str]) -> str:
             suffix: str = (
                 info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
             )
@@ -999,8 +1002,8 @@ def _create_op_prefix(name: str) -> str:
 
 
 def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
-    seen: Set[str] = set()
-    saved: List[SavedAttribute] = []
+    seen: set[str] = set()
+    saved: list[SavedAttribute] = []
     for var in vars:
         name = (
             var.nctype.name.name
 
@@ -1,17 +1,17 @@
+from __future__ import annotations
+
 import os
 import platform
 import shutil
 from glob import glob
-from typing import Dict, Optional
 
 from setuptools import distutils  # type: ignore[import]
 
 from .setup_helpers.cmake import CMake, USE_NINJA
-
 from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS
 
 
-def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
+def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]:
     vc_arch = "x64" if IS_64BIT else "x86"
 
     if platform.machine() == "ARM64":
@@ -34,7 +34,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
             "emulation is enabled!"
         )
 
-    vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
+    vc_env: dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
     # Keys in `_get_vc_env` are always lowercase.
     # We turn them into uppercase before overlaying vcvars
     # because OS environ keys are always uppercase on Windows.
@@ -47,7 +47,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
     return vc_env
 
 
-def _create_build_env() -> Dict[str, str]:
+def _create_build_env() -> dict[str, str]:
     # XXX - our cmake file sometimes looks at the system environment
     # and not cmake flags!
     # you should NEVER add something to this list. It is bad practice to
@@ -72,8 +72,8 @@ def _create_build_env() -> Dict[str, str]:
 
 
 def build_caffe2(
-    version: Optional[str],
-    cmake_python_library: Optional[str],
+    version: str | None,
+    cmake_python_library: str | None,
     build_python: bool,
     rerun_cmake: bool,
     cmake_only: bool,
 
@@ -5,10 +5,13 @@
 # - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh
 # - Copy libs from build/lib to torch/lib folder
 
+from __future__ import annotations
+
 import subprocess
 import sys
 from pathlib import Path
-from typing import Any, List, Optional, Tuple
+from typing import Any
 
+
 PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent
 TORCH_DIR = PYTORCH_ROOTDIR / "torch"
@@ -17,7 +20,7 @@ BUILD_DIR = PYTORCH_ROOTDIR / "build"
 BUILD_LIB_DIR = BUILD_DIR / "lib"
 
 
-def check_output(args: List[str], cwd: Optional[str] = None) -> str:
+def check_output(args: list[str], cwd: str | None = None) -> str:
     return subprocess.check_output(args, cwd=cwd).decode("utf-8")
 
 
@@ -63,7 +66,7 @@ def is_devel_setup() -> bool:
     return output.strip() == str(TORCH_DIR / "__init__.py")
 
 
-def create_build_plan() -> List[Tuple[str, str]]:
+def create_build_plan() -> list[tuple[str, str]]:
     output = check_output(
         ["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR)
     )
 
@@ -8,13 +8,15 @@ For custom build with static dispatch, the op dependency graph will be omitted,
 and it will directly output root ops as the allowlist.
 """
 
-import argparse
+from __future__ import annotations
+
+import argparse
 from collections import defaultdict
-from typing import Dict, List, Set
+from typing import Dict, Set
 
 import yaml
 
 
 DepGraph = Dict[str, Set[str]]
 
 
@@ -34,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
     return dict(result)
 
 
-def load_root_ops(fname: str) -> List[str]:
+def load_root_ops(fname: str) -> list[str]:
     result = []
     with open(fname) as stream:
         for op in yaml.safe_load(stream):
@@ -44,9 +46,9 @@ def load_root_ops(fname: str) -> List[str]:
 
 def gen_transitive_closure(
     dep_graph: DepGraph,
-    root_ops: List[str],
+    root_ops: list[str],
     train: bool = False,
-) -> List[str]:
+) -> list[str]:
     result = set(root_ops)
     queue = root_ops.copy()
 
@@ -73,7 +75,7 @@ def gen_transitive_closure(
     return sorted(result)
 
 
-def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
+def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: list[str]) -> str:
     return " ".join(gen_transitive_closure(dep_graph, root_ops))
 
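Note what this file deliberately does not change: the alias `DepGraph = Dict[str, Set[str]]` keeps its `typing` forms, and `Dict`/`Set` stay imported. `from __future__ import annotations` postpones only annotations; an alias assignment is ordinary code evaluated at import time, and `dict[str, set[str]]` would raise a `TypeError` on Python 3.8. A sketch of the distinction, with illustrative names:

from __future__ import annotations

from typing import Dict, Set

# Evaluated eagerly at import time, so it must use typing generics on 3.8:
DepGraph = Dict[str, Set[str]]


def load(fname: str) -> DepGraph:  # the annotation itself is postponed, so this is fine
    ...
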
@@ -1,8 +1,11 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import argparse
 import json
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import yaml
 from gen_op_registration_allowlist import (
@@ -17,6 +20,7 @@ from torchgen.selective_build.operator import (
 )
 from torchgen.selective_build.selector import merge_kernel_metadata
 
+
 # Generate YAML file containing the operators used for a specific PyTorch model.
 # ------------------------------------------------------------------------------
 #
@@ -84,17 +88,17 @@ from torchgen.selective_build.selector import merge_kernel_metadata
 #
 
 
-def canonical_opnames(opnames: List[str]) -> List[str]:
+def canonical_opnames(opnames: list[str]) -> list[str]:
     return [canonical_name(opname) for opname in opnames]
 
 
 def make_filter_from_options(
     model_name: str,
-    model_versions: List[str],
-    model_assets: Optional[List[str]],
-    model_backends: Optional[List[str]],
+    model_versions: list[str],
+    model_assets: list[str] | None,
+    model_backends: list[str] | None,
 ):
-    def is_model_included(model_info):
+    def is_model_included(model_info) -> bool:
         model = model_info["model"]
         if model["name"] != model_name:
             return False
@@ -109,7 +113,7 @@ def make_filter_from_options(
 
 
 # Returns if a the specified rule is a new or old style pt_operator_library
-def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
+def is_new_style_rule(model_name: str, model_versions: list[str] | None):
     return model_name is not None and model_versions is not None
 
 
@@ -117,13 +121,13 @@ def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
 # appear in at least one model yaml. Throws if verification is failed,
 # returns None on success
 def verify_all_specified_present(
-    model_assets: Optional[List[str]],
-    model_versions: List[str],
-    selected_models_yaml: List[Dict[str, Any]],
+    model_assets: list[str] | None,
+    model_versions: list[str],
+    selected_models_yaml: list[dict[str, Any]],
     rule_name: str,
     model_name: str,
     new_style_rule: bool,
-):
+) -> None:
     def find_missing_items(model_items, key, selected_models_yaml):
         missing_items = []
         if not new_style_rule or not model_items:
@@ -179,10 +183,10 @@ def verify_all_specified_present(
 # Uses the selected models configs and then combines them into one dictionary,
 # formats them as a string, and places the string into output as a top level debug_info
 def create_debug_info_from_selected_models(
-    output: Dict[str, object],
-    selected_models: List[dict],
+    output: dict[str, object],
+    selected_models: list[dict],
     new_style_rule: bool,
-):
+) -> None:
     model_dict = {
         "asset_info": {},  # maps asset name -> dict of asset metadata like hashes
         "is_new_style_rule": new_style_rule,
@@ -201,7 +205,7 @@ def create_debug_info_from_selected_models(
     output["debug_info"] = [json.dumps(model_dict)]
 
 
-def fill_output(output: Dict[str, object], options: object):
+def fill_output(output: dict[str, object], options: object) -> None:
     """Populate the output dict with the information required to serialize
     the YAML file used for selective build.
     """
@@ -458,7 +462,7 @@ def fill_output(output: Dict[str, object], options: object):
     # END TRACING BASED BUILD OPS
 
     # Merge dictionaries together to remove op duplication
-    operators: Dict[str, SelectiveBuildOperator] = {}
+    operators: dict[str, SelectiveBuildOperator] = {}
     for ops_dict in bucketed_ops:
         operators = merge_operator_dicts(operators, ops_dict)
 
@@ -1,10 +1,13 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import argparse
 import json
 import os
 import sys
 from functools import reduce
-from typing import Any, List, Set
+from typing import Any
 
 import yaml
 from tools.lite_interpreter.gen_selected_mobile_ops_header import (
@@ -17,11 +20,11 @@ from torchgen.selective_build.selector import (
 )
 
 
-def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
+def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
     return set(selective_builder.operators.keys())
 
 
-def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
+def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
     ops = []
     for op_name, op in selective_builder.operators.items():
         if op.is_used_for_training:
@@ -44,7 +47,7 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
     )
 
 
-def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
+def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
     supported_mobile_models_source = """/*
 * Generated by gen_oplist.py
 */
@@ -87,7 +90,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
         out_file.write(source.encode("utf-8"))
 
 
-def main(argv: List[Any]) -> None:
+def main(argv: list[Any]) -> None:
     """This binary generates 3 files:
 
     1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
 
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import argparse
 import os
-from typing import cast, List, Optional, Tuple
+from typing import cast
 
 from ..util.setting import (
     CompilerType,
@@ -38,7 +40,7 @@ BLOCKED_PYTHON_TESTS = {
 }
 
 
-def initialization() -> Tuple[Option, TestList, List[str]]:
+def initialization() -> tuple[Option, TestList, list[str]]:
     # create folder if not exists
     create_folders()
     # add arguments
@@ -77,7 +79,7 @@ def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
 
 def parse_arguments(
     parser: argparse.ArgumentParser,
-) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]:
+) -> tuple[Option, list[str] | None, list[str] | None, bool | None]:
     # parse args
     args = parser.parse_args()
     # get option
@@ -85,9 +87,7 @@ def parse_arguments(
     return (options, args.interest_only, args.run_only, args.clean)
 
 
-def get_test_list_by_type(
-    run_only: Optional[List[str]], test_type: TestType
-) -> TestList:
+def get_test_list_by_type(run_only: list[str] | None, test_type: TestType) -> TestList:
     test_list: TestList = []
     binary_folder = get_oss_binary_folder(test_type)
     g = os.walk(binary_folder)
@@ -106,7 +106,7 @@ def get_test_list_by_type(
     return test_list
 
 
-def get_test_list(run_only: Optional[List[str]]) -> TestList:
+def get_test_list(run_only: list[str] | None) -> TestList:
     test_list: TestList = []
     # add c++ test list
     test_list.extend(get_test_list_by_type(run_only, TestType.CPP))
@@ -122,7 +122,7 @@ def get_test_list(run_only: Optional[List[str]]) -> TestList:
     return test_list
 
 
-def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]:
+def empty_list_if_none(arg_interested_folder: list[str] | None) -> list[str]:
     if arg_interested_folder is None:
         return []
     # if this argument is specified, just return itself
@@ -134,7 +134,7 @@ def gcc_export_init() -> None:
     create_folder(JSON_FOLDER_BASE_DIR)
 
 
-def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
+def get_python_run_only(args_run_only: list[str] | None) -> list[str]:
     # if user specifies run-only option
     if args_run_only:
         return args_run_only
@@ -144,7 +144,7 @@ def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
         return ["run_test.py"]
     else:
         # for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them
-        run_only: List[str] = []
+        run_only: list[str] = []
         binary_folder = get_oss_binary_folder(TestType.PY)
         g = os.walk(binary_folder)
         for _, _, file_list in g:
 
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 import os
 import subprocess
-from typing import List, Optional
 
 from ..util.setting import CompilerType, TestType, TOOLS_FOLDER
 from ..util.utils import print_error, remove_file
@@ -14,7 +15,7 @@ def get_oss_binary_folder(test_type: TestType) -> str:
     )
 
 
-def get_oss_shared_library() -> List[str]:
+def get_oss_shared_library() -> list[str]:
     lib_dir = os.path.join(get_pytorch_folder(), "build", "lib")
     return [
         os.path.join(lib_dir, lib)
@@ -48,7 +49,7 @@ def get_pytorch_folder() -> str:
     )
 
 
-def detect_compiler_type() -> Optional[CompilerType]:
+def detect_compiler_type() -> CompilerType | None:
     # check if user specifies the compiler type
     user_specify = os.environ.get("CXX", None)
     if user_specify:
@@ -76,7 +77,7 @@ def clean_up_gcda() -> None:
         remove_file(item)
 
 
-def get_gcda_files() -> List[str]:
+def get_gcda_files() -> list[str]:
     folder_has_gcda = os.path.join(get_pytorch_folder(), "build")
     if os.path.isdir(folder_has_gcda):
         # TODO use glob
 
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
 import os
 import subprocess
 import time
-from typing import List
 
 from ..util.setting import (
     JSON_FOLDER_BASE_DIR,
@@ -25,7 +26,7 @@ from .utils import get_tool_path_by_platform, run_cpp_test
 
 
 def create_corresponding_folder(
-    cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str
+    cur_path: str, prefix_cur_path: str, dir_list: list[str], new_base_folder: str
 ) -> None:
     for dir_name in dir_list:
         relative_path = convert_to_relative_path(
@@ -70,7 +71,7 @@ def export_target(
     merged_file: str,
     json_file: str,
     binary_file: str,
-    shared_library_list: List[str],
+    shared_library_list: list[str],
     platform_type: TestPlatform,
 ) -> None:
     if binary_file is None:
 
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
 import os
 import subprocess
 import time
-from typing import Dict
 
 # gcc is only used in oss
 from ..oss.utils import get_gcda_files, run_oss_python_test
@@ -10,7 +11,7 @@ from ..util.utils import print_log, print_time
 from .utils import run_cpp_test
 
 
-def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str:
+def update_gzip_dict(gzip_dict: dict[str, int], file_name: str) -> str:
     file_name = file_name.lower()
     gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1
     num = gzip_dict[file_name]
@@ -34,7 +35,7 @@ def export() -> None:
     # collect .gcda files
     gcda_files = get_gcda_files()
     # file name like utils.cpp may have same name in different folder
-    gzip_dict: Dict[str, int] = {}
+    gzip_dict: dict[str, int] = {}
     for gcda_item in gcda_files:
         # generate json.gz
         subprocess.check_call(["gcov", "-i", gcda_item])
 
@@ -1,12 +1,14 @@
-import typing as t
+from __future__ import annotations
+
+from typing import Any, NamedTuple


-class CoverageRecord(t.NamedTuple):
+class CoverageRecord(NamedTuple):
     filepath: str
-    covered_lines: t.List[int]
-    uncovered_lines: t.Optional[t.List[int]] = None
+    covered_lines: list[int]
+    uncovered_lines: list[int] | None = None

-    def to_dict(self) -> t.Dict[str, t.Any]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             "filepath": self.filepath,
             "covered_lines": self.covered_lines,

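Under postponed evaluation, the field annotations of a NamedTuple like CoverageRecord above are kept as plain strings; code that needs the real types must resolve them explicitly. A sketch under that assumption (the hypothetical `Record` stands in for CoverageRecord):

    from __future__ import annotations

    import typing
    from typing import Any, NamedTuple


    class Record(NamedTuple):
        filepath: str
        covered_lines: list[int]
        uncovered_lines: list[int] | None = None

        def to_dict(self) -> dict[str, Any]:
            return self._asdict()


    print(Record.__annotations__["covered_lines"])  # the string 'list[int]'
    # Resolving back to real type objects only works where the syntax is
    # also valid at runtime (Python 3.10+ for the `|` union below):
    print(typing.get_type_hints(Record)["uncovered_lines"])
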
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Set
+from __future__ import annotations
+
+from typing import Any

 from .coverage_record import CoverageRecord

@@ -10,7 +12,7 @@ class GcovCoverageParser:
     of CoverageRecord(s).
     """

-    def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
+    def __init__(self, llvm_coverage: dict[str, Any]) -> None:
         self._llvm_coverage = llvm_coverage

     @staticmethod
@@ -24,17 +26,17 @@ class GcovCoverageParser:
                 return True
         return False

-    def parse(self) -> List[CoverageRecord]:
+    def parse(self) -> list[CoverageRecord]:
         # The JSON format is described in the gcov source code
         # https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html
-        records: List[CoverageRecord] = []
+        records: list[CoverageRecord] = []
         for file_info in self._llvm_coverage["files"]:
             filepath = file_info["file"]
             if self._skip_coverage(filepath):
                 continue
             # parse json file
-            covered_lines: Set[int] = set()
-            uncovered_lines: Set[int] = set()
+            covered_lines: set[int] = set()
+            uncovered_lines: set[int] = set()
             for line in file_info["lines"]:
                 line_number = line["line_number"]
                 count = line["count"]

@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Set, Tuple
+from __future__ import annotations
+
+from typing import Any

 from .coverage_record import CoverageRecord
 from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments
@@ -12,7 +14,7 @@ class LlvmCoverageParser:

     """

-    def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
+    def __init__(self, llvm_coverage: dict[str, Any]) -> None:
         self._llvm_coverage = llvm_coverage

     @staticmethod
@@ -28,13 +30,13 @@ class LlvmCoverageParser:

     @staticmethod
     def _collect_coverage(
-        segments: List[LlvmCoverageSegment],
-    ) -> Tuple[List[int], List[int]]:
+        segments: list[LlvmCoverageSegment],
+    ) -> tuple[list[int], list[int]]:
         """
         Stateful parsing of coverage segments.
         """
-        covered_lines: Set[int] = set()
-        uncovered_lines: Set[int] = set()
+        covered_lines: set[int] = set()
+        uncovered_lines: set[int] = set()
         prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None)
         for segment in segments:
             covered_range, uncovered_range = segment.get_coverage(prev_segment)
@@ -45,10 +47,10 @@ class LlvmCoverageParser:
         uncovered_lines.difference_update(covered_lines)
         return sorted(covered_lines), sorted(uncovered_lines)

-    def parse(self, repo_name: str) -> List[CoverageRecord]:
+    def parse(self, repo_name: str) -> list[CoverageRecord]:
         # The JSON format is described in the LLVM source code
         # https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp
-        records: List[CoverageRecord] = []
+        records: list[CoverageRecord] = []
         for export_unit in self._llvm_coverage["data"]:
             for file_info in export_unit["files"]:
                 filepath = file_info["filename"]

@@ -1,4 +1,6 @@
-from typing import List, NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple


 class LlvmCoverageSegment(NamedTuple):
@@ -7,7 +9,7 @@ class LlvmCoverageSegment(NamedTuple):
     segment_count: int
     has_count: int
     is_region_entry: int
-    is_gap_entry: Optional[int]
+    is_gap_entry: int | None

     @property
     def has_coverage(self) -> bool:
@@ -18,8 +20,8 @@ class LlvmCoverageSegment(NamedTuple):
         return self.has_count > 0

     def get_coverage(
-        self, prev_segment: "LlvmCoverageSegment"
-    ) -> Tuple[List[int], List[int]]:
+        self, prev_segment: LlvmCoverageSegment
+    ) -> tuple[list[int], list[int]]:
         # Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py
         if not prev_segment.is_executable:
             return [], []
@@ -32,12 +34,12 @@ class LlvmCoverageSegment(NamedTuple):
         return (lines_range, []) if prev_segment.has_coverage else ([], lines_range)


-def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]:
+def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]:
     """
     Creates LlvmCoverageSegment from a list of lists in llvm export json.
     each segment is represented by 5-element array.
     """
-    ret: List[LlvmCoverageSegment] = []
+    ret: list[LlvmCoverageSegment] = []
     for raw_segment in raw_segments:
         assert (
             len(raw_segment) == 5 or len(raw_segment) == 6

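One change in this file is not a typing-module swap: `prev_segment: "LlvmCoverageSegment"` loses its quotes. With PEP 563 every annotation is effectively a forward reference, so a method may name its own class unquoted even though the class object is not bound yet while the class body executes. A sketch with a made-up class:

    from __future__ import annotations

    from typing import NamedTuple


    class Segment(NamedTuple):
        start: int
        end: int

        # Before PEP 563 this would have needed the string form "Segment",
        # since the name is not defined until the class body finishes.
        def union(self, other: Segment) -> Segment:
            return Segment(min(self.start, other.start), max(self.end, other.end))


    print(Segment(3, 5).union(Segment(1, 4)))  # Segment(start=1, end=5)
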
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import os
 import subprocess
-from typing import Dict, IO, List, Set, Tuple
+from typing import IO, Tuple

 from ..oss.utils import get_pytorch_folder
 from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType


 CoverageItem = Tuple[str, float, int, int]

@@ -16,7 +19,7 @@ def key_by_name(x: CoverageItem) -> str:
     return x[0]


-def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
+def is_intrested_file(file_path: str, interested_folders: list[str]) -> bool:
     if "cuda" in file_path:
         return False
     if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
@@ -27,7 +30,7 @@ def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
     return False


-def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
+def is_this_type_of_tests(target_name: str, test_set_by_type: set[str]) -> bool:
     # tests are divided into three types: success / partial success / fail to collect coverage
     for test in test_set_by_type:
         if target_name in test:
@@ -36,7 +39,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:


 def print_test_by_type(
-    tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
+    tests: TestList, test_set_by_type: set[str], type_name: str, summary_file: IO[str]
 ) -> None:
     print("Tests " + type_name + " to collect coverage:", file=summary_file)
     for test in tests:
@@ -48,8 +51,8 @@ def print_test_by_type(
 def print_test_condition(
     tests: TestList,
     tests_type: TestStatusType,
-    interested_folders: List[str],
-    coverage_only: List[str],
+    interested_folders: list[str],
+    coverage_only: list[str],
     summary_file: IO[str],
     summary_type: str,
 ) -> None:
@@ -77,10 +80,10 @@ def print_test_condition(
 def line_oriented_report(
     tests: TestList,
     tests_type: TestStatusType,
-    interested_folders: List[str],
-    coverage_only: List[str],
-    covered_lines: Dict[str, Set[int]],
-    uncovered_lines: Dict[str, Set[int]],
+    interested_folders: list[str],
+    coverage_only: list[str],
+    covered_lines: dict[str, set[int]],
+    uncovered_lines: dict[str, set[int]],
 ) -> None:
     with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file:
         print_test_condition(
@@ -119,13 +122,13 @@ def print_file_summary(

 def print_file_oriented_report(
     tests_type: TestStatusType,
-    coverage: List[CoverageItem],
+    coverage: list[CoverageItem],
     covered_summary: int,
     total_summary: int,
     summary_file: IO[str],
     tests: TestList,
-    interested_folders: List[str],
-    coverage_only: List[str],
+    interested_folders: list[str],
+    coverage_only: list[str],
 ) -> None:
     coverage_percentage = print_file_summary(
         covered_summary, total_summary, summary_file
@@ -155,10 +158,10 @@ def print_file_oriented_report(
 def file_oriented_report(
     tests: TestList,
     tests_type: TestStatusType,
-    interested_folders: List[str],
-    coverage_only: List[str],
-    covered_lines: Dict[str, Set[int]],
-    uncovered_lines: Dict[str, Set[int]],
+    interested_folders: list[str],
+    coverage_only: list[str],
+    covered_lines: dict[str, set[int]],
+    uncovered_lines: dict[str, set[int]],
 ) -> None:
     with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file:
         covered_summary = 0
@@ -193,7 +196,7 @@ def file_oriented_report(
     )


-def get_html_ignored_pattern() -> List[str]:
+def get_html_ignored_pattern() -> list[str]:
     return ["/usr/*", "*anaconda3/*", "*third_party/*"]

@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import json
 import os
 import time
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, TYPE_CHECKING

 from ..util.setting import (
     CompilerType,
@@ -16,7 +18,6 @@ from ..util.utils import (
     print_time,
     related_to_test_list,
 )
-from .parser.coverage_record import CoverageRecord
 from .parser.gcov_coverage_parser import GcovCoverageParser
 from .parser.llvm_coverage_parser import LlvmCoverageParser
 from .print_report import (
@@ -26,16 +27,20 @@ from .print_report import (
 )


+if TYPE_CHECKING:
+    from .parser.coverage_record import CoverageRecord
+
+
 # coverage_records: Dict[str, LineInfo] = {}
-covered_lines: Dict[str, Set[int]] = {}
-uncovered_lines: Dict[str, Set[int]] = {}
+covered_lines: dict[str, set[int]] = {}
+uncovered_lines: dict[str, set[int]] = {}
 tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}


 def transform_file_name(
-    file_path: str, interested_folders: List[str], platform: TestPlatform
+    file_path: str, interested_folders: list[str], platform: TestPlatform
 ) -> str:
-    remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
+    remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
     for pattern in remove_patterns:
         file_path = file_path.replace(pattern, "")
     # if user has specified interested folder
@@ -54,7 +59,7 @@ def transform_file_name(


 def is_intrested_file(
-    file_path: str, interested_folders: List[str], platform: TestPlatform
+    file_path: str, interested_folders: list[str], platform: TestPlatform
 ) -> bool:
     ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
     if any(pattern in file_path for pattern in ignored_patterns):
@@ -77,7 +82,7 @@ def is_intrested_file(
         return True


-def get_json_obj(json_file: str) -> Tuple[Any, int]:
+def get_json_obj(json_file: str) -> tuple[Any, int]:
     """
     Sometimes at the start of file llvm/gcov will complains "fail to find coverage data",
     then we need to skip these lines
@@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]:
             return None, 2


-def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
+def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]:
     print("start parse:", json_file)
     json_obj, read_status = get_json_obj(json_file)
     if read_status == 0:
@@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:

     cov_type = detect_compiler_type(platform)

-    coverage_records: List[CoverageRecord] = []
+    coverage_records: list[CoverageRecord] = []
     if cov_type == CompilerType.CLANG:
         coverage_records = LlvmCoverageParser(json_obj).parse("fbcode")
         # print(coverage_records)
@@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:


 def parse_jsons(
-    test_list: TestList, interested_folders: List[str], platform: TestPlatform
+    test_list: TestList, interested_folders: list[str], platform: TestPlatform
 ) -> None:
     g = os.walk(JSON_FOLDER_BASE_DIR)

@@ -152,8 +157,8 @@ def parse_jsons(


 def update_coverage(
-    coverage_records: List[CoverageRecord],
-    interested_folders: List[str],
+    coverage_records: list[CoverageRecord],
+    interested_folders: list[str],
     platform: TestPlatform,
 ) -> None:
     for item in coverage_records:
@@ -187,8 +192,8 @@ def update_set() -> None:

 def summarize_jsons(
     test_list: TestList,
-    interested_folders: List[str],
-    coverage_only: List[str],
+    interested_folders: list[str],
+    coverage_only: list[str],
     platform: TestPlatform,
 ) -> None:
     start_time = time.time()

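The summarize_jsons hunks show the companion idiom: an import needed only for annotations can move under `if TYPE_CHECKING:`, because postponed annotations guarantee the name is never looked up at runtime. A self-contained sketch (`Decimal` stands in for CoverageRecord):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to the type checker only; skipped at runtime, which is how
        # the hunk above drops the runtime dependency on .parser.coverage_record.
        from decimal import Decimal


    def first(records: list[Decimal]) -> Decimal | None:
        # Decimal appears only in the annotations, never in the body,
        # so no real import is required to call this function.
        return records[0] if records else None


    print(first([]))  # None
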
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from enum import Enum
 from typing import Dict, List, Set

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import os
 import shutil
 import sys
 import time
-from typing import Any, NoReturn, Optional
+from typing import Any, NoReturn

 from .setting import (
     CompilerType,
@@ -113,7 +115,7 @@ def get_test_name_from_whole_path(path: str) -> str:
     return path[start + 1 : end]


-def check_compiler_type(cov_type: Optional[CompilerType]) -> None:
+def check_compiler_type(cov_type: CompilerType | None) -> None:
     if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]:
         return
     raise Exception(  # noqa: TRY002

@@ -1,5 +1,6 @@
 import setuptools  # type: ignore[import]

+
 with open("README.md", encoding="utf-8") as fh:
     long_description = fh.read()

@@ -22,6 +22,7 @@ from typing import Any

 from coverage import CoverageData, CoveragePlugin  # type: ignore[import]

+
 # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
 # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
 # https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine

@@ -5,6 +5,7 @@ import sys
 from urllib.error import URLError
 from urllib.request import urlretrieve

+
 MIRRORS = [
     "http://yann.lecun.com/exdb/mnist/",
     "https://ossci-datasets.s3.amazonaws.com/mnist/",

@@ -5,6 +5,7 @@ import sys
 import traceback
 import warnings

+
 MIN_CUDA_VERSION = "11.6"
 MIN_ROCM_VERSION = "5.4"
 MIN_PYTHON_VERSION = (3, 8)
@@ -141,7 +142,7 @@ def check_rocm():
     return rocm_ver if torch.version.hip else "None"


-def check_dynamo(backend, device, err_msg):
+def check_dynamo(backend, device, err_msg) -> None:
     import torch

     if device == "cuda" and not torch.cuda.is_available():
@@ -203,7 +204,7 @@ _SANITY_CHECK_ARGS = (
 )


-def main():
+def main() -> None:
     python_ver = check_python()
     torch_ver = check_torch()
     cuda_ver = check_cuda()

@@ -1,14 +1,17 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import argparse
 import re
 import sys
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 from typing_extensions import TypedDict  # Python 3.11+

 import yaml


 Step = Dict[str, Any]

@@ -17,7 +20,7 @@ class Script(TypedDict):
     script: str


-def extract(step: Step) -> Optional[Script]:
+def extract(step: Step) -> Script | None:
     run = step.get("run")

     # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell

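Note what does not change here: `Step = Dict[str, Any]` keeps the capital-D `Dict`. The future import postpones annotations only; a type alias is an ordinary runtime assignment, and `dict[str, Any]` evaluated at module level would raise on Python 3.8 (the same reason `CoverageItem = Tuple[...]` and the typing imports in the settings module survive earlier in this commit). Sketch:

    from __future__ import annotations

    from typing import Any, Dict

    # Evaluated eagerly at runtime -- PEP 563 does not apply to this line,
    # so the typing form is still required while Python 3.8 is supported.
    Step = Dict[str, Any]


    def extract(step: Step) -> str | None:
        # Inside the annotations the new syntax is fine even on 3.8:
        # they are stored as strings and never evaluated.
        return step.get("run")


    print(extract({"run": "echo hi"}))  # echo hi
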
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import argparse
 import array
 import codecs
@@ -15,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 import subprocess
 import textwrap
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any

 import yaml
 from yaml.constructor import ConstructorError
@@ -29,14 +31,14 @@ except ImportError:
 CPP_H_NAME = "spv.h"
 CPP_SRC_NAME = "spv.cpp"

-DEFAULT_ENV: Dict[str, Any] = {
+DEFAULT_ENV: dict[str, Any] = {
     "PRECISION": "highp",
     "FLOAT_IMAGE_FORMAT": "rgba16f",
     "INT_IMAGE_FORMAT": "rgba32i",
     "UINT_IMAGE_FORMAT": "rgba32ui",
 }

-TYPES_ENV: Dict[str, Any] = {
+TYPES_ENV: dict[str, Any] = {
     "IMAGE_FORMAT": {
         "float": "rgba32f",
         "half": "rgba16f",
@@ -91,7 +93,7 @@ TYPES_ENV: Dict[str, Any] = {
     },
 }

-FUNCS_ENV: Dict[str, Any] = {
+FUNCS_ENV: dict[str, Any] = {
     "GET_POS": {
         3: lambda pos: pos,
         2: lambda pos: f"{pos}.xy",
@@ -169,7 +171,7 @@ def escape(line: str) -> str:

 # https://github.com/google/XNNPACK/blob/master/tools/xngen.py
 def preprocess(
-    input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
+    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
 ) -> str:
     input_lines = input_text.splitlines()
     python_lines = []
@@ -243,9 +245,9 @@ def preprocess(
 class SPVGenerator:
     def __init__(
         self,
-        src_dir_paths: Union[str, List[str]],
-        env: Dict[Any, Any],
-        glslc_path: Optional[str],
+        src_dir_paths: str | list[str],
+        env: dict[Any, Any],
+        glslc_path: str | None,
     ) -> None:
         if isinstance(src_dir_paths, str):
             self.src_dir_paths = [src_dir_paths]
@@ -255,18 +257,18 @@ class SPVGenerator:
         self.env = env
         self.glslc_path = glslc_path

-        self.glsl_src_files: Dict[str, str] = {}
-        self.template_yaml_files: List[str] = []
+        self.glsl_src_files: dict[str, str] = {}
+        self.template_yaml_files: list[str] = []

         self.addSrcAndYamlFiles(self.src_dir_paths)
-        self.shader_template_params: Dict[Any, Any] = {}
+        self.shader_template_params: dict[Any, Any] = {}
         for yaml_file in self.template_yaml_files:
             self.parseTemplateYaml(yaml_file)

-        self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
+        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
         self.constructOutputMap()

-    def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
+    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
         for src_path in src_dir_paths:
             # Collect glsl source files
             glsl_files = glob.glob(
@@ -285,9 +287,9 @@ class SPVGenerator:

     def generateVariantCombinations(
         self,
-        iterated_params: Dict[str, Any],
-        exclude_params: Optional[Set[str]] = None,
-    ) -> List[Any]:
+        iterated_params: dict[str, Any],
+        exclude_params: set[str] | None = None,
+    ) -> list[Any]:
         if exclude_params is None:
             exclude_params = set()
         all_iterated_params = []
@@ -362,8 +364,8 @@ class SPVGenerator:
         )

     def create_shader_params(
-        self, variant_params: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, str]:
+        self, variant_params: dict[str, Any] | None = None
+    ) -> dict[str, str]:
         if variant_params is None:
             variant_params = {}
         shader_params = copy.deepcopy(self.env)
@@ -409,7 +411,7 @@ class SPVGenerator:
             self.create_shader_params(),
         )

-    def generateSPV(self, output_dir: str) -> Dict[str, str]:
+    def generateSPV(self, output_dir: str) -> dict[str, str]:
         output_file_map = {}
         for shader_name in self.output_shader_map:
             source_glsl = self.output_shader_map[shader_name][0]
@@ -457,11 +459,11 @@ class SPVGenerator:

 @dataclass
 class ShaderInfo:
-    tile_size: List[int]
-    layouts: List[str]
+    tile_size: list[int]
+    layouts: list[str]
     weight_storage_type: str = ""
     bias_storage_type: str = ""
-    register_for: Optional[Tuple[str, List[str]]] = None
+    register_for: tuple[str, list[str]] | None = None


 def getName(filePath: str) -> str:
@@ -478,7 +480,7 @@ def isTileSizeLine(lineStr: str) -> bool:
     return re.search(tile_size_id, lineStr) is not None


-def findTileSizes(lineStr: str) -> List[int]:
+def findTileSizes(lineStr: str) -> list[int]:
     tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
     matches = re.search(tile_size_id, lineStr)
     if matches is None:
@@ -520,7 +522,7 @@ def isRegisterForLine(lineStr: str) -> bool:
     return re.search(register_for_id, lineStr) is not None


-def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
+def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
     register_for_pattern = r"'([A-Za-z0-9_]+)'"
     matches = re.findall(register_for_pattern, lineStr)
     if matches is None:
@@ -609,7 +611,7 @@ static const api::ShaderRegisterInit register_shaders(&register_fn);
 """


-def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
+def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
     with open(spvPath, "rb") as fr:
         next_bin = array.array("I", fr.read())
         sizeBytes = 4 * len(next_bin)
@@ -665,7 +667,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:


 def genCppFiles(
-    spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
+    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
 ) -> None:
     spv_bin_strs = []
     register_shader_info_strs = []
@@ -705,7 +707,7 @@ def genCppFiles(
 ##########


-def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
+def parse_arg_env(items: dict[Any, Any]) -> dict[Any, Any]:
     d = {}
     if items:
         for item in items:
@@ -716,7 +718,7 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
     return d


-def main(argv: List[str]) -> int:
+def main(argv: list[str]) -> int:
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
         "-i",

@@ -1,9 +1,10 @@
+from __future__ import annotations
+
 import argparse
 import os
 import re
 import subprocess
 from pathlib import Path
-from typing import Optional, Union

 from setuptools import distutils  # type: ignore[import]

@@ -12,7 +13,7 @@ UNKNOWN = "Unknown"
 RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")


-def get_sha(pytorch_root: Union[str, Path]) -> str:
+def get_sha(pytorch_root: str | Path) -> str:
     try:
         rev = None
         if os.path.exists(os.path.join(pytorch_root, ".git")):
@@ -30,7 +31,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str:
     return UNKNOWN


-def get_tag(pytorch_root: Union[str, Path]) -> str:
+def get_tag(pytorch_root: str | Path) -> str:
     try:
         tag = subprocess.run(
             ["git", "describe", "--tags", "--exact"],
@@ -46,8 +47,8 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
     return UNKNOWN


-def get_torch_version(sha: Optional[str] = None) -> str:
-    pytorch_root = Path(__file__).parent.parent
+def get_torch_version(sha: str | None = None) -> str:
+    pytorch_root = Path(__file__).absolute().parent.parent
     version = open(pytorch_root / "version.txt").read().strip()

     if os.getenv("PYTORCH_BUILD_VERSION"):

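Besides the Union/Optional rewrites, this file picks up one behavioral tweak: `Path(__file__)` becomes `Path(__file__).absolute()` before walking up to the repo root. `__file__` can be a relative path depending on how the script is launched, and `.parent` on a relative path manipulates the path text rather than the real location, so anchoring it first presumably makes the `version.txt` lookup independent of the working directory. Illustration:

    from pathlib import Path

    p = Path("tools/generate_torch_version.py")  # relative, as __file__ can be
    print(p.parent.parent)             # '.', correct only when run from the repo root
    print(p.absolute().parent.parent)  # an absolute path, e.g. /home/me/pytorch
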
@ -1,10 +1,10 @@
|
||||||
"""GitHub Utilities"""
|
"""GitHub Utilities"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Callable, cast, Dict
|
||||||
from typing import Any, Callable, cast, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from urllib.error import HTTPError
|
from urllib.error import HTTPError
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from urllib.request import Request, urlopen
|
from urllib.request import Request, urlopen
|
||||||
|
|
@ -13,11 +13,11 @@ from urllib.request import Request, urlopen
|
||||||
def gh_fetch_url_and_headers(
|
def gh_fetch_url_and_headers(
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: dict[str, str] | None = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
method: Optional[str] = None,
|
method: str | None = None,
|
||||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||||
) -> Tuple[Any, Any]:
|
) -> tuple[Any, Any]:
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = {}
|
headers = {}
|
||||||
token = os.environ.get("GITHUB_TOKEN")
|
token = os.environ.get("GITHUB_TOKEN")
|
||||||
|
|
@ -44,9 +44,9 @@ def gh_fetch_url_and_headers(
|
||||||
def gh_fetch_url(
|
def gh_fetch_url(
|
||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: dict[str, str] | None = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
method: Optional[str] = None,
|
method: str | None = None,
|
||||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
return gh_fetch_url_and_headers(
|
return gh_fetch_url_and_headers(
|
||||||
|
|
@ -56,8 +56,8 @@ def gh_fetch_url(
|
||||||
|
|
||||||
def _gh_fetch_json_any(
|
def _gh_fetch_json_any(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: dict[str, Any] | None = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||||
if params is not None and len(params) > 0:
|
if params is not None and len(params) > 0:
|
||||||
|
|
@ -69,13 +69,13 @@ def _gh_fetch_json_any(
|
||||||
|
|
||||||
def gh_fetch_json_dict(
|
def gh_fetch_json_dict(
|
||||||
url: str,
|
url: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: dict[str, Any] | None = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||||
|
|
||||||
|
|
||||||
def gh_fetch_commit(org: str, repo: str, sha: str) -> Dict[str, Any]:
|
def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:
|
||||||
return gh_fetch_json_dict(
|
return gh_fetch_json_dict(
|
||||||
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}"
|
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
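The github_utils signatures show the most common rewrite in this commit: `Optional[T] = None` parameters become `T | None = None`. Only the spelling changes; the runtime default and the sentinel check are untouched. A sketch mirroring the gh_fetch_url_and_headers shape (the body is made up):

    from __future__ import annotations

    from typing import Any


    def fetch(
        url: str,
        *,
        headers: dict[str, str] | None = None,
        method: str | None = None,
    ) -> tuple[Any, Any]:
        # The default is still the plain None object, so the usual
        # is-None check works exactly as before the respelling.
        if headers is None:
            headers = {}
        return headers, method


    print(fetch("https://api.github.com"))  # ({}, None)
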
@@ -1,6 +1,7 @@
 import re
 import sys

+
 QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"')
 ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>")

@@ -1,10 +1,13 @@
 # Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.

+from __future__ import annotations
+
 import argparse
 import os
-import pathlib
 import sys
 from dataclasses import dataclass
-from typing import List, Literal, Sequence, Union
+from pathlib import Path
+from typing import Literal, Sequence, TYPE_CHECKING

 import yaml
@@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments
 from torchgen.context import method_with_native_function
 from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
 from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
-from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target


+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
 # Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
 @dataclass(frozen=True)
 class ComputeUnboxingFunctions:
@@ -156,7 +162,7 @@ def gen_unboxing(
     cpu_fm: FileManager,
     selector: SelectiveBuilder,
 ) -> None:
-    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
+    def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
         return fn.root_name

     selected_op_num: int = len(selector.operators)
@@ -195,7 +201,7 @@ def gen_unboxing(
 )


-def main(args: List[str]) -> None:
+def main(args: list[str]) -> None:
     parser = argparse.ArgumentParser(description="Generate unboxing source files")
     parser.add_argument(
         "-s",
@@ -272,7 +278,7 @@ def main(args: List[str]) -> None:
     gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)

     if options.output_dependencies:
-        depfile_path = pathlib.Path(options.output_dependencies).resolve()
+        depfile_path = Path(options.output_dependencies).resolve()
         depfile_name = depfile_path.name
         depfile_stem = depfile_path.stem

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -8,7 +10,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
-from typing import List, NamedTuple, Optional, Pattern
+from typing import NamedTuple


 LINTER_CODE = "ACTIONLINT"
@@ -22,18 +24,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-RESULTS_RE: Pattern[str] = re.compile(
+RESULTS_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -47,8 +49,8 @@ RESULTS_RE: Pattern[str] = re.compile(


 def run_command(
-    args: List[str],
-) -> "subprocess.CompletedProcess[bytes]":
+    args: list[str],
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -64,7 +66,7 @@ def run_command(
 def check_file(
     binary: str,
     file: str,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         proc = run_command(
             [

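The linter adapters also migrate `Pattern[str]` to `re.Pattern[str]`: `typing.Pattern` has been deprecated since Python 3.8 and was removed in 3.12, so the stdlib class itself becomes the annotation. Subscripting `re.Pattern` at runtime needs 3.9+, but as a postponed annotation it is never evaluated at all. Sketch:

    from __future__ import annotations

    import re

    # The annotation below is stored as a string, so this line parses on
    # Python 3.8 too; at runtime only re.compile() actually runs.
    RESULTS_RE: re.Pattern[str] = re.compile(r"^(?P<file>.*?):(?P<line>\d+)")

    m = RESULTS_RE.match("torch/overrides.py:42")
    assert m is not None
    print(m.group("file"), m.group("line"))  # torch/overrides.py 42
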
@@ -5,6 +5,9 @@ archive is downloaded from some sites like GitHub because it can change. Specifi
 GitHub gives no guarantee to keep the same value forever. Check for more details at
 https://github.com/community/community/discussions/46034.
 """

+from __future__ import annotations
+
 import argparse
 import json
 import re
@@ -13,7 +16,7 @@ import subprocess
 import sys
 import xml.etree.ElementTree as ET
 from enum import Enum
-from typing import List, NamedTuple, Optional, Set
+from typing import NamedTuple
 from urllib.parse import urlparse


@@ -30,18 +33,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def is_required_checksum(urls: List[Optional[str]]) -> bool:
+def is_required_checksum(urls: list[str | None]) -> bool:
     if not urls:
         return False

@@ -58,7 +61,7 @@ def is_required_checksum(urls: List[Optional[str]]) -> bool:

 def get_disallowed_checksums(
     binary: str,
-) -> Set[str]:
+) -> set[str]:
     """
     Return the set of disallowed checksums from all http_archive rules
     """
@@ -96,8 +99,8 @@ def get_disallowed_checksums(

 def check_bazel(
     filename: str,
-    disallowed_checksums: Set[str],
-) -> List[LintMessage]:
+    disallowed_checksums: set[str],
+) -> list[LintMessage]:
     original = ""
     replacement = ""

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -7,7 +9,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
-from typing import Any, BinaryIO, List, NamedTuple, Optional
+from typing import Any, BinaryIO, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -25,15 +27,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -41,11 +43,11 @@ def as_posix(name: str) -> str:


 def _run_command(
-    args: List[str],
+    args: list[str],
     *,
     stdin: BinaryIO,
     timeout: int,
-) -> "subprocess.CompletedProcess[bytes]":
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -63,12 +65,12 @@ def _run_command(


 def run_command(
-    args: List[str],
+    args: list[str],
     *,
     stdin: BinaryIO,
     retries: int,
     timeout: int,
-) -> "subprocess.CompletedProcess[bytes]":
+) -> subprocess.CompletedProcess[bytes]:
     remaining_retries = retries
     while True:
         try:
@@ -90,7 +92,7 @@ def check_file(
     filename: str,
     retries: int,
     timeout: int,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         with open(filename, "rb") as f:
             original = f.read()

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -8,7 +10,7 @@ import sys
 import time
 from enum import Enum
 from pathlib import Path
-from typing import Any, List, NamedTuple, Optional
+from typing import Any, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -26,15 +28,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -42,10 +44,10 @@ def as_posix(name: str) -> str:


 def _run_command(
-    args: List[str],
+    args: list[str],
     *,
     timeout: int,
-) -> "subprocess.CompletedProcess[bytes]":
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -62,11 +64,11 @@ def _run_command(


 def run_command(
-    args: List[str],
+    args: list[str],
     *,
     retries: int,
     timeout: int,
-) -> "subprocess.CompletedProcess[bytes]":
+) -> subprocess.CompletedProcess[bytes]:
     remaining_retries = retries
     while True:
         try:
@@ -89,7 +91,7 @@ def check_file(
     binary: str,
     retries: int,
     timeout: int,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         with open(filename, "rb") as f:
             original = f.read()

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -11,7 +13,7 @@ import time
 from enum import Enum
 from pathlib import Path
 from sysconfig import get_paths as gp
-from typing import Any, List, NamedTuple, Optional, Pattern
+from typing import Any, NamedTuple


 # PyTorch directory root
@@ -49,15 +51,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -65,7 +67,7 @@ def as_posix(name: str) -> str:


 # c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move]
-RESULTS_RE: Pattern[str] = re.compile(
+RESULTS_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -80,8 +82,8 @@ RESULTS_RE: Pattern[str] = re.compile(


 def run_command(
-    args: List[str],
-) -> "subprocess.CompletedProcess[bytes]":
+    args: list[str],
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -103,7 +105,7 @@ severities = {
 }


-def clang_search_dirs() -> List[str]:
+def clang_search_dirs() -> list[str]:
     # Compilers are ordered based on fallback preference
     # We pick the first one that is available on the system
     compilers = ["clang", "gcc", "cpp", "cc"]
@@ -152,7 +154,7 @@ def check_file(
     filename: str,
     binary: str,
     build_dir: Path,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         proc = run_command(
             [binary, f"-p={build_dir}", *include_args, filename],

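The same future import is why `List[str]`, `Dict[str, str]`, and `Set[str]` can become the builtin generics `list[str]`, `dict[str, str]`, and `set[str]` (PEP 585), and why the `typing.Pattern` alias, deprecated since Python 3.8, can be replaced by `re.Pattern[str]`. A small sketch under the same assumptions (hypothetical names, not from the diff):

from __future__ import annotations

import re

# Matches lint codes like "E101" (illustrative pattern only).
CODE_RE: re.Pattern[str] = re.compile(r"[A-Z]\d+")


def group_codes(codes: list[str]) -> dict[str, set[str]]:
    grouped: dict[str, set[str]] = {}
    for code in codes:
        grouped.setdefault(code[0], set()).add(code)
    return grouped


print(group_codes(["E101", "E111", "B001"]))  # e.g. {'E': {'E101', 'E111'}, 'B': {'B001'}}
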
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -7,7 +9,7 @@ import re
 import subprocess
 import time
 from enum import Enum
-from typing import List, NamedTuple, Optional, Pattern
+from typing import NamedTuple


 LINTER_CODE = "CMAKE"
@@ -21,19 +23,19 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 # CMakeLists.txt:901: Lines should be <= 80 characters long [linelength]
-RESULTS_RE: Pattern[str] = re.compile(
+RESULTS_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -46,8 +48,8 @@ RESULTS_RE: Pattern[str] = re.compile(


 def run_command(
-    args: List[str],
-) -> "subprocess.CompletedProcess[bytes]":
+    args: list[str],
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -63,7 +65,7 @@ def run_command(
 def check_file(
     filename: str,
     config: str,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         proc = run_command(
             ["cmakelint", f"--config={config}", filename],

@@ -2,13 +2,15 @@
 CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues
 """

+from __future__ import annotations
+
 import argparse
 import json
 import logging
 import sys

 from enum import Enum
-from typing import NamedTuple, Optional
+from typing import NamedTuple


 CONSTEXPR = "constexpr char"
 CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char"
@@ -21,18 +23,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def check_file(filename: str) -> Optional[LintMessage]:
+def check_file(filename: str) -> LintMessage | None:
     logging.debug("Checking file %s", filename)

     with open(filename) as f:

@@ -1,14 +1,17 @@
 """
 EXEC: Ensure that source files are not executable.
 """
+
+from __future__ import annotations
+
 import argparse
 import json
 import logging
 import os
 import sys

 from enum import Enum
-from typing import NamedTuple, Optional
+from typing import NamedTuple


 LINTER_CODE = "EXEC"

@@ -21,18 +24,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def check_file(filename: str) -> Optional[LintMessage]:
+def check_file(filename: str) -> LintMessage | None:
     is_executable = os.access(filename, os.X_OK)
     if is_executable:
         return LintMessage(

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import logging
@@ -7,7 +9,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
-from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set
+from typing import Any, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -25,15 +27,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -42,7 +44,7 @@ def as_posix(name: str) -> str:


 # fmt: off
 # https://www.flake8rules.com/
-DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
+DOCUMENTED_IN_FLAKE8RULES: set[str] = {
     "E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117",
     "E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129",
     "E131", "E133",
@@ -78,14 +80,14 @@ DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
 }

 # https://pypi.org/project/flake8-comprehensions/#rules
-DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = {
+DOCUMENTED_IN_FLAKE8COMPREHENSIONS: set[str] = {
     "C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409",
     "C410",
     "C411", "C412", "C413", "C414", "C415", "C416",
 }

 # https://github.com/PyCQA/flake8-bugbear#list-of-warnings
-DOCUMENTED_IN_BUGBEAR: Set[str] = {
+DOCUMENTED_IN_BUGBEAR: set[str] = {
     "B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010",
     "B011", "B012", "B013", "B014", "B015",
     "B301", "B302", "B303", "B304", "B305", "B306",
@@ -98,7 +100,7 @@ DOCUMENTED_IN_BUGBEAR: Set[str] = {
 # stdin:3:6: T484 Name 'foo' is not defined
 # stdin:3:-100: W605 invalid escape sequence '\/'
 # stdin:3:1: E302 expected 2 blank lines, found 1
-RESULTS_RE: Pattern[str] = re.compile(
+RESULTS_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -134,10 +136,10 @@ def _test_results_re() -> None:


 def _run_command(
-    args: List[str],
+    args: list[str],
     *,
-    extra_env: Optional[Dict[str, str]],
-) -> "subprocess.CompletedProcess[str]":
+    extra_env: dict[str, str] | None,
+) -> subprocess.CompletedProcess[str]:
     logging.debug(
         "$ %s",
         " ".join(
@@ -158,11 +160,11 @@ def _run_command(


 def run_command(
-    args: List[str],
+    args: list[str],
     *,
-    extra_env: Optional[Dict[str, str]],
+    extra_env: dict[str, str] | None,
     retries: int,
-) -> "subprocess.CompletedProcess[str]":
+) -> subprocess.CompletedProcess[str]:
     remaining_retries = retries
     while True:
         try:
@@ -243,11 +245,11 @@ def get_issue_documentation_url(code: str) -> str:


 def check_files(
-    filenames: List[str],
-    flake8_plugins_path: Optional[str],
-    severities: Dict[str, LintSeverity],
+    filenames: list[str],
+    flake8_plugins_path: str | None,
+    severities: dict[str, LintSeverity],
     retries: int,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     try:
         proc = run_command(
             [sys.executable, "-mflake8", "--exit-zero"] + filenames,
@@ -351,7 +353,7 @@ def main() -> None:
         else os.path.realpath(args.flake8_plugins_path)
     )

-    severities: Dict[str, LintSeverity] = {}
+    severities: dict[str, LintSeverity] = {}
     if args.severity:
         for severity in args.severity:
             parts = severity.split(":", 1)

@@ -2,6 +2,8 @@
 Generic linter that greps for a pattern and optionally suggests replacements.
 """

+from __future__ import annotations
+
 import argparse
 import json
 import logging
@@ -10,7 +12,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
-from typing import Any, List, NamedTuple, Optional
+from typing import Any, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -28,15 +30,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -44,8 +46,8 @@ def as_posix(name: str) -> str:


 def run_command(
-    args: List[str],
-) -> "subprocess.CompletedProcess[bytes]":
+    args: list[str],
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -65,7 +67,7 @@ def lint_file(
     linter_name: str,
     error_name: str,
     error_description: str,
-) -> Optional[LintMessage]:
+) -> LintMessage | None:
     # matching_line looks like:
     # tools/linter/clangtidy_linter.py:13:import foo.bar.baz
     split = matching_line.split(":")

@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import json
 import subprocess
 import sys
 from enum import Enum
-from typing import NamedTuple, Optional, Tuple
+from typing import NamedTuple


 LINTER_CODE = "LINTRUNNER_VERSION"
@@ -16,18 +18,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def toVersionString(version_tuple: Tuple[int, int, int]) -> str:
+def toVersionString(version_tuple: tuple[int, int, int]) -> str:
     return ".".join(str(x) for x in version_tuple)


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import logging
@@ -8,7 +10,7 @@ import sys
 import time
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, NamedTuple, Optional, Pattern
+from typing import Any, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -26,15 +28,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -42,7 +44,7 @@ def as_posix(name: str) -> str:


 # tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment]
-RESULTS_RE: Pattern[str] = re.compile(
+RESULTS_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -56,7 +58,7 @@ RESULTS_RE: Pattern[str] = re.compile(
 )

 # torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
-INTERNAL_ERROR_RE: Pattern[str] = re.compile(
+INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
     r"""(?mx)
     ^
     (?P<file>.*?):
@@ -69,11 +71,11 @@ INTERNAL_ERROR_RE: Pattern[str] = re.compile(


 def run_command(
-    args: List[str],
+    args: list[str],
     *,
-    extra_env: Optional[Dict[str, str]],
+    extra_env: dict[str, str] | None,
     retries: int,
-) -> "subprocess.CompletedProcess[bytes]":
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -94,7 +96,7 @@ severities = {
 }


-def check_mypy_installed(code: str) -> List[LintMessage]:
+def check_mypy_installed(code: str) -> list[LintMessage]:
     cmd = [sys.executable, "-mmypy", "-V"]
     try:
         subprocess.run(cmd, check=True, capture_output=True)
@@ -117,11 +119,11 @@ def check_mypy_installed(code: str) -> List[LintMessage]:


 def check_files(
-    filenames: List[str],
+    filenames: list[str],
     config: str,
     retries: int,
     code: str,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     # dmypy has a bug where it won't pick up changes if you pass it absolute
     # file names, see https://github.com/python/mypy/issues/16768
     filenames = [os.path.relpath(f) for f in filenames]
@@ -224,7 +226,7 @@ def main() -> None:

     # Use a dictionary here to preserve order. mypy cares about order,
     # tragically, e.g. https://github.com/python/mypy/issues/2015
-    filenames: Dict[str, bool] = {}
+    filenames: dict[str, bool] = {}

     # If a stub file exists, have mypy check it instead of the original file, in
     # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)

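A caveat that applies to every hunk in this commit: PEP 563 only covers annotation positions. The new syntax still fails in expressions that are actually evaluated on Python 3.8/3.9, so a runtime alias like `int | None` or an `isinstance` check against a union would still need `Optional`/`Union` there. A quick illustration (assuming Python 3.8, hypothetical names):

from __future__ import annotations


def parse_line(line: str) -> int | None:  # annotation only: never evaluated
    return int(line) if line.strip().isdigit() else None


try:
    MaybeInt = int | None  # evaluated at runtime: needs Python 3.10+
except TypeError as err:
    print(f"runtime unions unavailable on this interpreter: {err}")
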
@@ -14,12 +14,14 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
 the YAML, not to be prescriptive about it.
 """

+from __future__ import annotations
+
 import argparse
 import json
 import sys
 from enum import Enum
 from io import StringIO
-from typing import NamedTuple, Optional
+from typing import NamedTuple

 import ruamel.yaml  # type: ignore[import]

@@ -32,15 +34,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 if __name__ == "__main__":

@@ -1,13 +1,16 @@
 """
 NEWLINE: Checks files to make sure there are no trailing newlines.
 """
+
+from __future__ import annotations
+
 import argparse
 import json
 import logging
 import sys

 from enum import Enum
-from typing import List, NamedTuple, Optional
+from typing import NamedTuple


 NEWLINE = 10  # ASCII "\n"
 CARRIAGE_RETURN = 13  # ASCII "\r"
@@ -22,18 +25,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def check_file(filename: str) -> Optional[LintMessage]:
+def check_file(filename: str) -> LintMessage | None:
     logging.debug("Checking file %s", filename)

     with open(filename, "rb") as f:
@@ -85,7 +88,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
                 description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
             )
     has_changes = False
-    original_lines: Optional[List[bytes]] = None
+    original_lines: list[bytes] | None = None
     for idx, line in enumerate(lines):
         if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
             if not has_changes:

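The `LintMessage(NamedTuple)` classes rewritten throughout these adapters keep working unchanged, because `typing.NamedTuple` picks the field annotations up from the class's `__annotations__` dict without evaluating them; under PEP 563 that dict simply holds strings. A compressed sketch with the field set trimmed (assuming Python 3.8+):

from __future__ import annotations

from typing import NamedTuple


class LintMessage(NamedTuple):
    path: str | None
    line: int | None
    code: str


msg = LintMessage(path="tools/example.py", line=3, code="EXAMPLE")
print(msg.path, msg.line, msg.code)
print(LintMessage.__annotations__)  # field types stored as strings
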
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -5,7 +7,7 @@ import logging
 import os
 import sys
 from enum import Enum
-from typing import Any, List, NamedTuple, Optional
+from typing import Any, NamedTuple


 IS_WINDOWS: bool = os.name == "nt"
@@ -23,18 +25,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def check_file(filename: str) -> List[LintMessage]:
+def check_file(filename: str) -> list[LintMessage]:
     with open(filename, "rb") as f:
         original = f.read().decode("utf-8")
     replacement = ""

@@ -1,6 +1,9 @@
 """
 Initializer script that installs stuff to pip.
 """
+
+from __future__ import annotations
+
 import argparse
 import logging
 import os
@@ -9,10 +12,8 @@ import subprocess
 import sys
 import time

-from typing import List
-

-def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]":
+def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:

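Before this change, the return annotation here had to be the quoted string `"subprocess.CompletedProcess[bytes]"`, because `subprocess.CompletedProcess` only became subscriptable at runtime in Python 3.9. With lazy annotations every annotation is effectively a string already, so the quotes become redundant. Roughly (hypothetical command, not from the diff):

from __future__ import annotations

import subprocess


def run(args: list[str]) -> subprocess.CompletedProcess[bytes]:
    # The unquoted generic return type is never evaluated, so this parses
    # and runs on Python 3.8 as well.
    return subprocess.run(args, capture_output=True)


result = run(["echo", "hello"])
print(result.returncode, result.stdout)
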
@@ -14,6 +14,7 @@ import sys
 import time
 from typing import Any, BinaryIO

+
 LINTER_CODE = "RUFF"
 IS_WINDOWS: bool = os.name == "nt"

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import logging
@@ -6,7 +8,7 @@ import subprocess
 import sys
 import time
 from enum import Enum
-from typing import List, NamedTuple, Optional
+from typing import NamedTuple


 LINTER_CODE = "SHELLCHECK"
@@ -20,20 +22,20 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def run_command(
-    args: List[str],
-) -> "subprocess.CompletedProcess[bytes]":
+    args: list[str],
+) -> subprocess.CompletedProcess[bytes]:
     logging.debug("$ %s", " ".join(args))
     start_time = time.monotonic()
     try:
@@ -47,8 +49,8 @@ def run_command(


 def check_files(
-    files: List[str],
-) -> List[LintMessage]:
+    files: list[str],
+) -> list[LintMessage]:
     try:
         proc = run_command(
             ["shellcheck", "--external-sources", "--format=json1"] + files

@@ -6,15 +6,19 @@ calls run_tests to ensure that the test will be run in OSS CI.

 Takes ~2 minuters to run without the multiprocessing, probably overkill.
 """

+from __future__ import annotations
+
 import argparse
 import json
 import multiprocessing as mp
 from enum import Enum
-from typing import List, NamedTuple, Optional
+from typing import NamedTuple

 import libcst as cst
 import libcst.matchers as m

+
 LINTER_CODE = "TEST_HAS_MAIN"
@@ -62,18 +66,18 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


-def check_file(filename: str) -> List[LintMessage]:
+def check_file(filename: str) -> list[LintMessage]:
     lint_messages = []

     with open(filename) as f:

@@ -8,10 +8,13 @@ has valid ownership information in a comment header. Valid means:
 - Each owner label actually exists in PyTorch
 - Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS
 """
+
+from __future__ import annotations
+
 import argparse
 import json
 from enum import Enum
-from typing import Any, List, NamedTuple, Optional
+from typing import Any, NamedTuple
 from urllib.request import urlopen


@@ -26,15 +29,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 # Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions
@@ -58,8 +61,8 @@ GLOB_EXCEPTIONS = ["**/test/run_test.py"]


 def check_labels(
-    labels: List[str], filename: str, line_number: int
-) -> List[LintMessage]:
+    labels: list[str], filename: str, line_number: int
+) -> list[LintMessage]:
     lint_messages = []
     for label in labels:
         if label not in PYTORCH_LABELS:
@@ -104,7 +107,7 @@ def check_labels(
     return lint_messages


-def check_file(filename: str) -> List[LintMessage]:
+def check_file(filename: str) -> list[LintMessage]:
     lint_messages = []
     has_ownership_info = False

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import concurrent.futures
 import json
@@ -6,7 +8,7 @@ import os
 import sys
 from enum import Enum
 from pathlib import Path
-from typing import Any, List, NamedTuple, Optional
+from typing import Any, NamedTuple

 from ufmt.core import ufmt_string
 from ufmt.util import make_black_config
@@ -28,15 +30,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def as_posix(name: str) -> str:
@@ -59,7 +61,7 @@ def format_error_message(filename: str, err: Exception) -> LintMessage:

 def check_file(
     filename: str,
-) -> List[LintMessage]:
+) -> list[LintMessage]:
     with open(filename, "rb") as f:
         original = f.read().decode("utf-8")

@@ -2,16 +2,20 @@

 Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`.
 """

+from __future__ import annotations
+
 import argparse
 import itertools
 import json
 from collections import defaultdict
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Iterable, NamedTuple, Optional
+from typing import Any, Iterable, NamedTuple

 from yaml import dump, load

+
 # Safely load fast C Yaml loader/dumper if they are available
 try:
     from yaml import CSafeLoader as Loader
@@ -27,15 +31,15 @@ class LintSeverity(str, Enum):


 class LintMessage(NamedTuple):
-    path: Optional[str]
-    line: Optional[int]
-    char: Optional[int]
+    path: str | None
+    line: int | None
+    char: int | None
     code: str
     severity: LintSeverity
     name: str
-    original: Optional[str]
-    replacement: Optional[str]
-    description: Optional[str]
+    original: str | None
+    replacement: str | None
+    description: str | None


 def glob_yamls(path: Path) -> Iterable[Path]:
@@ -51,7 +55,7 @@ def is_workflow(yaml: Any) -> bool:
     return yaml.get("jobs") is not None


-def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None:
+def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
     job_id = next(iter(job.keys()))
     with open(path) as f:
         lines = f.readlines()

@@ -1,10 +1,11 @@
+from __future__ import annotations
+
 import os
 import subprocess
 import sys
-from typing import List


-def run_cmd(cmd: List[str]) -> None:
+def run_cmd(cmd: list[str]) -> None:
     print(f"Running: {cmd}")
     result = subprocess.run(
         cmd,

@@ -1,13 +1,16 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import argparse
 import os
-from typing import Set

 import yaml

 from torchgen.code_template import CodeTemplate
 from torchgen.selective_build.selector import SelectiveBuilder

+
 # Safely load fast C Yaml loader/dumper if they are available
 try:
     from yaml import CSafeLoader as Loader
@@ -46,7 +49,7 @@ selected_mobile_ops_preamble = """#pragma once
 """


-def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
+def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
     ops = []
     for op_name, op in selective_builder.operators.items():
         if op.is_root_operator:
@@ -125,7 +128,7 @@ def write_selected_mobile_ops(
 # 2. All kernel dtypes
 def write_selected_mobile_ops_with_all_dtypes(
     output_file_path: str,
-    root_ops: Set[str],
+    root_ops: set[str],
 ) -> None:
     with open(output_file_path, "wb") as out_file:
         body_parts = [selected_mobile_ops_preamble]

@@ -1,5 +1,6 @@
 import lldb  # type: ignore[import]

+
 # load into lldb instance with:
 # command script import tools/lldb/deploy_debugger.py

@@ -24,6 +24,9 @@ well. This can be done with
 Pulling will reinstalle the conda dependencies as well as the nightly binaries into
 the repo directory.
 """
+
+from __future__ import annotations
+
 import contextlib
 import datetime
 import functools
@@ -40,23 +43,10 @@ import time
 import uuid
 from argparse import ArgumentParser
 from ast import literal_eval
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Generator,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar

-LOGGER: Optional[logging.Logger] = None
+
+LOGGER: logging.Logger | None = None
 URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
 DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
 SHA1_RE = re.compile("([0-9a-fA-F]{40})")
@@ -68,9 +58,9 @@ SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphin


 class Formatter(logging.Formatter):
-    redactions: Dict[str, str]
+    redactions: dict[str, str]

-    def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None):
+    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
         super().__init__(fmt, datefmt)
         self.redactions = {}

@@ -192,7 +182,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
     sys.exit(1)


-def check_in_repo() -> Optional[str]:
+def check_in_repo() -> str | None:
     """Ensures that we are in the PyTorch repo."""
     if not os.path.isfile("setup.py"):
         return "Not in root-level PyTorch repo, no setup.py found"
@@ -203,7 +193,7 @@ def check_in_repo() -> Optional[str]:
     return None


-def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
+def check_branch(subcommand: str, branch: str | None) -> str | None:
     """Checks that the branch name can be checked out."""
     if subcommand != "checkout":
         return None
@@ -259,7 +249,7 @@ def timed(prefix: str) -> Callable[[F], F]:
 def _make_channel_args(
     channels: Iterable[str] = ("pytorch-nightly",),
     override_channels: bool = False,
-) -> List[str]:
+) -> list[str]:
     args = []
     for channel in channels:
         args.append("--channel")
@@ -271,11 +261,11 @@ def _make_channel_args(

 @timed("Solving conda environment")
 def conda_solve(
-    name: Optional[str] = None,
-    prefix: Optional[str] = None,
+    name: str | None = None,
+    prefix: str | None = None,
     channels: Iterable[str] = ("pytorch-nightly",),
     override_channels: bool = False,
-) -> Tuple[List[str], str, str, bool, List[str]]:
+) -> tuple[list[str], str, str, bool, list[str]]:
     """Performs the conda solve and splits the deps from the package."""
     # compute what environment to use
     if prefix is not None:
@@ -329,7 +319,7 @@ def conda_solve(


 @timed("Installing dependencies")
-def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
+def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
     """Install dependencies to deps environment"""
     if not existing_env:
         # first remove previous pytorch-deps env
@@ -342,7 +332,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No


 @timed("Installing pytorch nightly binaries")
-def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
+def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
     """Install pytorch into a temporary directory"""
     pytdir = tempfile.TemporaryDirectory()
     cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
@@ -421,33 +411,33 @@ def pull_nightly_version(spdir: str) -> None:
     p = subprocess.run(cmd, check=True)


-def _get_listing_linux(source_dir: str) -> List[str]:
+def _get_listing_linux(source_dir: str) -> list[str]:
     listing = glob.glob(os.path.join(source_dir, "*.so"))
     listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
     return listing


-def _get_listing_osx(source_dir: str) -> List[str]:
+def _get_listing_osx(source_dir: str) -> list[str]:
     # oddly, these are .so files even on Mac
     listing = glob.glob(os.path.join(source_dir, "*.so"))
     listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
     return listing


-def _get_listing_win(source_dir: str) -> List[str]:
+def _get_listing_win(source_dir: str) -> list[str]:
     listing = glob.glob(os.path.join(source_dir, "*.pyd"))
     listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
     listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
     return listing


-def _glob_pyis(d: str) -> Set[str]:
+def _glob_pyis(d: str) -> set[str]:
     search = os.path.join(d, "**", "*.pyi")
     pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
     return pyis


-def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
+def _find_missing_pyi(source_dir: str, target_dir: str) -> list[str]:
     source_pyis = _glob_pyis(source_dir)
     target_pyis = _glob_pyis(target_dir)
     missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@@ -455,7 +445,7 @@ def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
     return missing_pyis


-def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
+def _get_listing(source_dir: str, target_dir: str, platform: str) -> list[str]:
     if platform.startswith("linux"):
         listing = _get_listing_linux(source_dir)
     elif platform.startswith("osx"):
@@ -510,12 +500,12 @@ def _move_single(
         mover(src, trg)


-def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
+def _copy_files(listing: list[str], source_dir: str, target_dir: str) -> None:
     for src in listing:
         _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")


-def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
+def _link_files(listing: list[str], source_dir: str, target_dir: str) -> None:
     for src in listing:
         _move_single(src, source_dir, target_dir, os.link, "Linking")
@@ -537,7 +527,7 @@ def move_nightly_files(spdir: str, platform: str) -> None:
     _copy_files(listing, source_dir, target_dir)


-def _available_envs() -> Dict[str, str]:
+def _available_envs() -> dict[str, str]:
     cmd = ["conda", "env", "list"]
     p = subprocess.run(
         cmd,
@@ -559,7 +549,7 @@ def _available_envs() -> Dict[str, str]:

 @timed("Writing pytorch-nightly.pth")
-def write_pth(env_opts: List[str], platform: str) -> None:
+def write_pth(env_opts: list[str], platform: str) -> None:
     """Writes Python path file for this dir."""
     env_type, env_dir = env_opts
     if env_type == "--name":
@@ -582,9 +572,9 @@ def install(
     *,
     logger: logging.Logger,
     subcommand: str = "checkout",
-    branch: Optional[str] = None,
-    name: Optional[str] = None,
-    prefix: Optional[str] = None,
+    branch: str | None = None,
+    name: str | None = None,
+    prefix: str | None = None,
     channels: Iterable[str] = ("pytorch-nightly",),
     override_channels: bool = False,
 ) -> None:
@@ -673,7 +663,7 @@ def make_parser() -> ArgumentParser:
     return p


-def main(args: Optional[Sequence[str]] = None) -> None:
+def main(args: Sequence[str] | None = None) -> None:
     """Main entry point"""
     global LOGGER
     p = make_parser()

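`conda_solve` above shows the same rewrite applied to a compound return type: `Tuple[List[str], str, str, bool, List[str]]` becomes `tuple[list[str], str, str, bool, list[str]]` with no typing imports at all. A reduced sketch of producing and unpacking such a value (hypothetical function, not from the diff):

from __future__ import annotations


def solve() -> tuple[list[str], str, bool]:
    # deps, chosen platform, whether an existing env was reused
    return (["pytorch", "mypy"], "linux-64", False)


deps, platform, existing = solve()
print(deps, platform, existing)
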
@@ -13,13 +13,15 @@ CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"

 """

+from __future__ import annotations
+
 import subprocess
 import sys
 from pathlib import Path
-from typing import List, Optional, TextIO
+from typing import TextIO


-def resolve_include(path: Path, include_dirs: List[Path]) -> Path:
+def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
     for include_path in include_dirs:
         abs_path = include_path / path
         if abs_path.exists():
@@ -36,7 +38,7 @@ Tried the following paths, but none existed:
     )


-def repair_depfile(depfile: TextIO, include_dirs: List[Path]) -> None:
+def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
     changes_made = False
     out = ""
     for line in depfile:
@@ -70,8 +72,8 @@ PRE_INCLUDE_ARGS = ["-include", "--pre-include"]
 POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]


-def extract_include_arg(include_dirs: List[Path], i: int, args: List[str]) -> None:
-    def extract_one(name: str, i: int, args: List[str]) -> Optional[str]:
+def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
+    def extract_one(name: str, i: int, args: list[str]) -> str | None:
         arg = args[i]
         if arg == name:
             return args[i + 1]

@@ -24,6 +24,7 @@ import yaml
 from torchgen import utils as torchgen_utils
 from torchgen.yaml_utils import YamlLoader

+
 _RULES_GENERATED_COMMENT = """\
 GENERATED CODE - DO NOT EDIT DIRECTLY
 This file is generated by gen_diagnostics.py.

@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import argparse
 import collections
 import importlib
 import sys
 from pprint import pformat
-from typing import Dict, List, Sequence
+from typing import Sequence
 from unittest.mock import Mock, patch
 from warnings import warn
 
@@ -220,7 +222,7 @@ to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
 all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
 
 
-def sig_for_ops(opname: str) -> List[str]:
+def sig_for_ops(opname: str) -> list[str]:
     """sig_for_ops(opname : str) -> List[str]
 
     Returns signatures for operator special functions (__add__ etc.)"""
@@ -254,8 +256,8 @@ def sig_for_ops(opname: str) -> List[str]:
     raise Exception("unknown op", opname)  # noqa: TRY002
 
 
-def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
-    type_hints: List[str] = []
+def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
+    type_hints: list[str] = []
 
     # Some deprecated ops that are on the blocklist are still included in pyi
     if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
@@ -285,7 +287,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
     return type_hints
 
 
-def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]:
+def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
     flag_pos = arg_list.index("{return_indices}")
     # If return_indices is positional arg, everything before should have no default
     arg_list_positional = (
@@ -329,7 +331,7 @@ def gen_nn_functional(fm: FileManager) -> None:
     )
 
     # TODO the list for `torch._C._nn` is nonexhaustive
-    unsorted_c_nn_function_hints: Dict[str, List[str]] = {}
+    unsorted_c_nn_function_hints: dict[str, list[str]] = {}
 
     for d in (2, 3):
         unsorted_c_nn_function_hints.update(
@@ -471,7 +473,7 @@ def gen_nn_functional(fm: FileManager) -> None:
         }
     )
 
-    c_nn_function_hints: List[str] = []
+    c_nn_function_hints: list[str] = []
     for _, hints in sorted(unsorted_c_nn_function_hints.items()):
         if len(hints) > 1:
             hints = ["@overload\n" + h for h in hints]
@@ -528,7 +530,7 @@ def gen_nn_functional(fm: FileManager) -> None:
     )
 
     # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
-    unsorted_dispatched_hints: Dict[str, List[str]] = {}
+    unsorted_dispatched_hints: dict[str, list[str]] = {}
 
     for d in (1, 2, 3):
         unsorted_dispatched_hints.update(
@@ -563,7 +565,7 @@ def gen_nn_functional(fm: FileManager) -> None:
     # There's no fractional_max_pool1d
     del unsorted_dispatched_hints["fractional_max_pool1d"]
 
-    dispatched_hints: List[str] = []
+    dispatched_hints: list[str] = []
     for _, hints in sorted(unsorted_dispatched_hints.items()):
         if len(hints) > 1:
             hints = ["@overload\n" + h for h in hints]
@@ -594,7 +596,7 @@ We gather the docstrings for torch with the following steps:
 """
 
 
-def gather_docstrs() -> Dict[str, str]:
+def gather_docstrs() -> dict[str, str]:
     docstrs = {}
 
     def mock_add_docstr(func: Mock, docstr: str) -> None:
@@ -648,12 +650,12 @@ def gen_pyi(
     # also needs to update the other file.
 
     # Dictionary for NamedTuple definitions
-    structseqs: Dict[str, str] = {}
+    structseqs: dict[str, str] = {}
 
     # Generate type signatures for top-level functions
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
+    unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
 
     for n, n1, n2 in [
         ("csr", "crow", "col"),
@@ -1054,7 +1056,7 @@ def gen_pyi(
     # Generate type signatures for Tensor methods
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
+    unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
     unsorted_tensor_method_hints.update(
         {
             "size": [
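Note that the rewrite above is deliberately confined to annotation positions: the docstring line `sig_for_ops(opname : str) -> List[str]` is left untouched, because PEP 563 has no effect on ordinary string content, and updating docstring text is a separate editorial choice. The boundary looks like this:

def sig_for_ops(opname: str) -> list[str]:  # annotation: rewritten to list[str]
    """sig_for_ops(opname : str) -> List[str]"""  # docstring: plain text, unchanged
    raise NotImplementedError  # body elided; only the signature matters here
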
@@ -1,8 +1,11 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import argparse
 import os
-from typing import Any, List, Union
+from typing import Any
 
+
 try:
     from junitparser import (  # type: ignore[import]
@@ -23,8 +26,8 @@ except ImportError:
     print("rich not found, for color output use 'pip install rich'")
 
 
-def parse_junit_reports(path_to_reports: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
-    def parse_file(path: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
+def parse_junit_reports(path_to_reports: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
+    def parse_file(path: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
         try:
             return convert_junit_to_testcases(JUnitXml.fromfile(path))
         except Exception as err:
@@ -46,7 +49,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]:  # type: ignore
     return ret_xml
 
 
-def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]:  # type: ignore[no-any-unimported]
+def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]:  # type: ignore[no-any-unimported]
     testcases = []
     for item in xml:
         if isinstance(item, TestSuite):
@@ -56,7 +59,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
     return testcases
 
 
-def render_tests(testcases: List[TestCase]) -> None:  # type: ignore[no-any-unimported]
+def render_tests(testcases: list[TestCase]) -> None:  # type: ignore[no-any-unimported]
     num_passed = 0
     num_skipped = 0
     num_failed = 0
@@ -1,9 +1,10 @@
+from __future__ import annotations
+
 import os
 import sys
-from typing import Optional
 
 
-def which(thefile: str) -> Optional[str]:
+def which(thefile: str) -> str | None:
     path = os.environ.get("PATH", os.defpath).split(os.pathsep)
     for d in path:
         fname = os.path.join(d, thefile)
@@ -1,5 +1,6 @@
 "Manages CMake."
 
+from __future__ import annotations
 
 import multiprocessing
 import os
@@ -8,7 +9,7 @@ import sys
 import sysconfig
 from distutils.version import LooseVersion
 from subprocess import CalledProcessError, check_call, check_output
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast
 
 from . import which
 from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file
@@ -77,7 +78,7 @@ class CMake:
         return cmake_command
 
     @staticmethod
-    def _get_version(cmd: Optional[str]) -> Any:
+    def _get_version(cmd: str | None) -> Any:
         "Returns cmake version."
 
         if cmd is None:
@@ -87,7 +88,7 @@ class CMake:
                 return LooseVersion(line.strip().split(" ")[2])
         raise RuntimeError("no version found")
 
-    def run(self, args: List[str], env: Dict[str, str]) -> None:
+    def run(self, args: list[str], env: dict[str, str]) -> None:
         "Executes cmake with arguments and an environment."
 
         command = [self._cmake_command] + args
@@ -101,13 +102,13 @@ class CMake:
             sys.exit(1)
 
     @staticmethod
-    def defines(args: List[str], **kwargs: CMakeValue) -> None:
+    def defines(args: list[str], **kwargs: CMakeValue) -> None:
        "Adds definitions to a cmake argument list."
        for key, value in sorted(kwargs.items()):
            if value is not None:
                args.append(f"-D{key}={value}")
 
-    def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
+    def get_cmake_cache_variables(self) -> dict[str, CMakeValue]:
         r"""Gets values in CMakeCache.txt into a dictionary.
         Returns:
           dict: A ``dict`` containing the value of cached CMake variables.
@@ -117,11 +118,11 @@ class CMake:
 
     def generate(
         self,
-        version: Optional[str],
-        cmake_python_library: Optional[str],
+        version: str | None,
+        cmake_python_library: str | None,
         build_python: bool,
         build_test: bool,
-        my_env: Dict[str, str],
+        my_env: dict[str, str],
         rerun: bool,
     ) -> None:
         "Runs cmake to generate native build files."
@@ -181,7 +182,7 @@ class CMake:
         _mkdir_p(self.build_dir)
 
         # Store build options that are directly stored in environment variables
-        build_options: Dict[str, CMakeValue] = {}
+        build_options: dict[str, CMakeValue] = {}
 
         # Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars.
         # This is a dict that maps environment variables to the corresponding variable name in CMake.
@@ -340,7 +341,7 @@ class CMake:
         args.append(base_dir)
         self.run(args, env=my_env)
 
-    def build(self, my_env: Dict[str, str]) -> None:
+    def build(self, my_env: dict[str, str]) -> None:
         "Runs cmake to build binaries."
 
         from .env import build_type
@@ -3,8 +3,10 @@ This is refactored from cmake.py to avoid circular imports issue with env.py,
 which calls get_cmake_cache_variables_from_file
 """
 
+from __future__ import annotations
+
 import re
-from typing import Dict, IO, Optional, Union
+from typing import IO, Optional, Union
 
 
 CMakeValue = Optional[Union[bool, str]]
@@ -42,7 +44,7 @@ def convert_cmake_value_to_python_value(
 
 def get_cmake_cache_variables_from_file(
     cmake_cache_file: IO[str],
-) -> Dict[str, CMakeValue]:
+) -> dict[str, CMakeValue]:
     r"""Gets values in CMakeCache.txt into a dictionary.
 
     Args:
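One detail worth calling out in the hunk above: `CMakeValue = Optional[Union[bool, str]]` keeps its `typing` spelling even though `Dict` is dropped from the import. The future import only defers annotations; a type-alias assignment is an ordinary expression executed at import time, so writing it as `bool | str | None` would raise `TypeError` on the Python 3.8/3.9 interpreters this tooling still supports. A sketch of the distinction:

from __future__ import annotations

from typing import Optional, Union

# Evaluated eagerly at import time, so it must use typing syntax on < 3.10:
CMakeValue = Optional[Union[bool, str]]
# CMakeValue = bool | str | None   # TypeError on Python 3.8/3.9


# Never evaluated at runtime, so new-style syntax is fine here:
def lookup(cache: dict[str, CMakeValue], key: str) -> CMakeValue:
    return cache.get(key)
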
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import os
 import platform
 import struct
 import sys
 from itertools import chain
-from typing import cast, Iterable, List, Optional
+from typing import cast, Iterable
 
 
 IS_WINDOWS = platform.system() == "Windows"
@@ -30,11 +32,11 @@ def check_negative_env_flag(name: str, default: str = "") -> bool:
     return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"]
 
 
-def gather_paths(env_vars: Iterable[str]) -> List[str]:
+def gather_paths(env_vars: Iterable[str]) -> list[str]:
     return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars)))
 
 
-def lib_paths_from_base(base_path: str) -> List[str]:
+def lib_paths_from_base(base_path: str) -> list[str]:
     return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]]
 
 
@@ -54,7 +56,7 @@ class BuildType:
 
     """
 
-    def __init__(self, cmake_build_type_env: Optional[str] = None) -> None:
+    def __init__(self, cmake_build_type_env: str | None = None) -> None:
         if cmake_build_type_env is not None:
             self.build_type_string = cmake_build_type_env
             return
@@ -2,9 +2,12 @@
 # and use the version numbers from there as substitutions for
 # an expand_template action. Since there isn't, this silly script exists.
 
+from __future__ import annotations
+
 import argparse
 import os
-from typing import cast, Dict, Tuple
+from typing import cast, Tuple
 
+
 Version = Tuple[int, int, int]
 
@@ -30,7 +33,7 @@ def parse_version(version: str) -> Version:
     return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
 
 
-def apply_replacements(replacements: Dict[str, str], text: str) -> str:
+def apply_replacements(replacements: dict[str, str], text: str) -> str:
     """
     Applies the given replacements within the text.
 
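The same runtime-vs-annotation rule explains why `Version = Tuple[int, int, int]` stays on `typing.Tuple` here: subscripting the builtin `tuple` is only legal from Python 3.9 onward, while rewritten annotations such as `dict[str, str]` cost nothing because they are never evaluated. A simplified sketch, not the file's actual implementation:

from __future__ import annotations

from typing import Tuple

Version = Tuple[int, int, int]  # runtime alias: typing.Tuple needed on Python 3.8


def parse_version_sketch(version: str) -> Version:  # annotation: lazily stored
    major, minor, patch = (int(n) for n in version.split(".")[:3])
    return (major, minor, patch)
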
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import argparse
 import os
 import pathlib
 import sys
-from typing import Any, cast, Optional
+from typing import Any, cast
 
 import yaml
 
@@ -18,10 +20,10 @@ TAGS_PATH = "aten/src/ATen/native/tags.yaml"
 
 def generate_code(
     gen_dir: pathlib.Path,
-    native_functions_path: Optional[str] = None,
-    tags_path: Optional[str] = None,
-    install_dir: Optional[str] = None,
-    subset: Optional[str] = None,
+    native_functions_path: str | None = None,
+    tags_path: str | None = None,
+    install_dir: str | None = None,
+    subset: str | None = None,
     disable_autograd: bool = False,
     force_schema_registration: bool = False,
     operator_selector: Any = None,
@@ -102,8 +104,8 @@ def get_selector_from_legacy_operator_selection_list(
 
 
 def get_selector(
-    selected_op_list_path: Optional[str],
-    operators_yaml_path: Optional[str],
+    selected_op_list_path: str | None,
+    operators_yaml_path: str | None,
 ) -> Any:
     # cwrap depends on pyyaml, so we can't import it earlier
     root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import argparse
 import json
 import os
 import xml.etree.ElementTree as ET
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, Generator, Tuple
+from typing import Any, Generator
 
 from tools.stats.upload_stats_lib import (
     download_s3_artifacts,
@@ -14,13 +16,14 @@ from tools.stats.upload_stats_lib import (
 )
 from tools.stats.upload_test_stats import process_xml_element
 
+
 TESTCASE_TAG = "testcase"
 SEPARATOR = ";"
 
 
 def process_report(
     report: Path,
-) -> Dict[str, Dict[str, int]]:
+) -> dict[str, dict[str, int]]:
     """
     Return a list of disabled tests that should be re-enabled and those that are still
     flaky (failed or skipped)
@@ -36,7 +39,7 @@ def process_report(
     # * Skipped tests from unittest
     #
     # We want to keep track of how many times the test fails (num_red) or passes (num_green)
-    all_tests: Dict[str, Dict[str, int]] = {}
+    all_tests: dict[str, dict[str, int]] = {}
 
     for test_case in root.iter(TESTCASE_TAG):
         parsed_test_case = process_xml_element(test_case)
@@ -116,7 +119,7 @@ def get_test_reports(
         yield from Path(".").glob("**/*.xml")
 
 
-def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
+def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
     """
     Follow flaky bot convention here, if that changes, this will also need to be updated
     """
@@ -133,7 +136,7 @@ def prepare_record(
     flaky: bool,
     num_red: int = 0,
     num_green: int = 0,
-) -> Tuple[Any, Dict[str, Any]]:
+) -> tuple[Any, dict[str, Any]]:
     """
     Prepare the record to save onto S3
     """
@@ -162,7 +165,7 @@ def prepare_record(
 def save_results(
     workflow_id: int,
     workflow_run_attempt: int,
-    all_tests: Dict[str, Dict[str, int]],
+    all_tests: dict[str, dict[str, int]],
 ) -> None:
     """
     Save the result to S3, so it can go to Rockset
@@ -228,7 +231,7 @@ def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
     Find the list of all disabled tests that should be re-enabled
     """
     # Aggregated across all jobs
-    all_tests: Dict[str, Dict[str, int]] = {}
+    all_tests: dict[str, dict[str, int]] = {}
 
     for report in get_test_reports(
         args.repo, args.workflow_run_id, args.workflow_run_attempt
@@ -1,17 +1,19 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import datetime
 import json
 import os
 import pathlib
 import shutil
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict
 from urllib.request import urlopen
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
 
 
-def get_disabled_issues() -> List[str]:
+def get_disabled_issues() -> list[str]:
     reenabled_issues = os.getenv("REENABLED_ISSUES", "")
     issue_numbers = reenabled_issues.split(",")
     print("Ignoring disabled issues: ", issue_numbers)
@@ -34,11 +36,11 @@ FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
 
 
 def fetch_and_cache(
-    dirpath: Union[str, pathlib.Path],
+    dirpath: str | pathlib.Path,
     name: str,
     url: str,
-    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
-) -> Dict[str, Any]:
+    process_fn: Callable[[dict[str, Any]], dict[str, Any]],
+) -> dict[str, Any]:
     """
     This fetch and cache utils allows sharing between different process.
     """
@@ -76,7 +78,7 @@ def fetch_and_cache(
 
 def get_slow_tests(
     dirpath: str, filename: str = SLOW_TESTS_FILE
-) -> Optional[Dict[str, float]]:
+) -> dict[str, float] | None:
     url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json"
     try:
         return fetch_and_cache(dirpath, filename, url, lambda x: x)
@@ -85,7 +87,7 @@ def get_slow_tests(
         return {}
 
 
-def get_test_times() -> Dict[str, Dict[str, float]]:
+def get_test_times() -> dict[str, dict[str, float]]:
     return get_from_test_infra_generated_stats(
         "test-times.json",
         TEST_TIMES_FILE,
@@ -93,7 +95,7 @@ def get_test_times() -> Dict[str, Dict[str, float]]:
     )
 
 
-def get_test_class_times() -> Dict[str, Dict[str, float]]:
+def get_test_class_times() -> dict[str, dict[str, float]]:
     return get_from_test_infra_generated_stats(
         "test-class-times.json",
         TEST_CLASS_TIMES_FILE,
@@ -103,8 +105,8 @@ def get_test_class_times() -> Dict[str, Dict[str, float]]:
 
 def get_disabled_tests(
     dirpath: str, filename: str = DISABLED_TESTS_FILE
-) -> Optional[Dict[str, Any]]:
-    def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
+) -> dict[str, Any] | None:
+    def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]:
         # remove re-enabled tests and condense even further by getting rid of pr_num
         disabled_issues = get_disabled_issues()
         disabled_test_from_issues = dict()
@@ -124,7 +126,7 @@ def get_disabled_tests(
         return {}
 
 
-def get_test_file_ratings() -> Dict[str, Any]:
+def get_test_file_ratings() -> dict[str, Any]:
     return get_from_test_infra_generated_stats(
         "file_test_rating.json",
         TEST_FILE_RATINGS_FILE,
@@ -132,7 +134,7 @@ def get_test_file_ratings() -> Dict[str, Any]:
     )
 
 
-def get_test_class_ratings() -> Dict[str, Any]:
+def get_test_class_ratings() -> dict[str, Any]:
     return get_from_test_infra_generated_stats(
         "file_test_class_rating.json",
         TEST_CLASS_RATINGS_FILE,
@@ -140,7 +142,7 @@ def get_test_class_ratings() -> Dict[str, Any]:
     )
 
 
-def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
+def get_td_heuristic_historial_edited_files_json() -> dict[str, Any]:
     return get_from_test_infra_generated_stats(
         "td_heuristic_historical_edited_files.json",
         TD_HEURISTIC_HISTORICAL_EDITED_FILES,
@@ -148,7 +150,7 @@ def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
     )
 
 
-def get_td_heuristic_profiling_json() -> Dict[str, Any]:
+def get_td_heuristic_profiling_json() -> dict[str, Any]:
     return get_from_test_infra_generated_stats(
         "td_heuristic_profiling.json",
         TD_HEURISTIC_PROFILING_FILE,
@@ -182,7 +184,7 @@ def copy_additional_previous_failures() -> None:
 
 def get_from_test_infra_generated_stats(
     from_file: str, to_file: str, failure_explanation: str
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}"
     try:
         return fetch_and_cache(
@@ -1,14 +1,17 @@
 #!/usr/bin/env python3
 
+from __future__ import annotations
+
 import datetime
 import json
 import signal
 import time
-from typing import Any, Dict, List
+from typing import Any
 
 import psutil  # type: ignore[import]
 
+
-def get_processes_running_python_tests() -> List[Any]:
+def get_processes_running_python_tests() -> list[Any]:
     python_processes = []
     for process in psutil.process_iter():
         try:
@@ -20,7 +23,7 @@ def get_processes_running_python_tests() -> List[Any]:
     return python_processes
 
 
-def get_per_process_cpu_info() -> List[Dict[str, Any]]:
+def get_per_process_cpu_info() -> list[dict[str, Any]]:
     processes = get_processes_running_python_tests()
     per_process_info = []
     for p in processes:
@@ -49,7 +52,7 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]:
     return per_process_info
 
 
-def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
+def get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
     processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
     per_process_info = []
     for p in processes:
@@ -58,7 +61,7 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
     return per_process_info
 
 
-def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
+def rocm_get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
     processes = amdsmi.amdsmi_get_gpu_process_list(handle)
     per_process_info = []
     for p in processes:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import os
 import re
@@ -6,7 +8,7 @@ from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, cast, Dict, List
+from typing import Any, cast
 
 import requests
 
@@ -18,6 +20,7 @@ from tools.stats.upload_stats_lib import (
     upload_workflow_stats_to_s3,
 )
 
+
 REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"
 
 
@@ -56,7 +59,7 @@ def get_test_config(job_name: str) -> str:
 
 def get_td_exclusions(
     workflow_run_id: int, workflow_run_attempt: int
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     with TemporaryDirectory() as temp_dir:
         print("Using temporary directory:", temp_dir)
         os.chdir(temp_dir)
@@ -68,7 +71,7 @@ def get_td_exclusions(
         for path in s3_paths:
             unzip(path)
 
-        grouped_tests: Dict[str, Any] = defaultdict(lambda: defaultdict(set))
+        grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
         for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
             with open(td_exclusions) as f:
                 exclusions = json.load(f)
@@ -85,9 +88,9 @@ def get_td_exclusions(
     return grouped_tests
 
 
-def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
+def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
     start = time.time()
-    grouped_tests: Dict[str, Any] = defaultdict(
+    grouped_tests: dict[str, Any] = defaultdict(
         lambda: defaultdict(
             lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
         )
@@ -112,8 +115,8 @@ def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
     return grouped_tests
 
 
-def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
-    reruns: Dict[str, Any] = defaultdict(
+def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
+    reruns: dict[str, Any] = defaultdict(
         lambda: defaultdict(
             lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
         )
@@ -136,8 +139,8 @@ def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
     return reruns
 
 
-def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
-    invoking_file_summary: Dict[str, Any] = defaultdict(
+def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
+    invoking_file_summary: dict[str, Any] = defaultdict(
         lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
     )
     for build_name, build in grouped_tests.items():
@@ -157,7 +160,7 @@ def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
 
 
 def upload_additional_info(
-    workflow_run_id: int, workflow_run_attempt: int, test_cases: List[Dict[str, Any]]
+    workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
 ) -> None:
     grouped_tests = group_test_cases(test_cases)
     reruns = get_reruns(grouped_tests)
@@ -5,6 +5,7 @@ from tempfile import TemporaryDirectory
 
 from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3
 
+
 ARTIFACTS = [
     "sccache-stats",
     "test-jsons",
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 import argparse
 import csv
 import os
 import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, List
+from typing import Any
 
 from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset
 
@@ -23,7 +25,7 @@ def upload_dynamo_perf_stats_to_rockset(
     workflow_run_attempt: int,
     head_branch: str,
     match_filename: str,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     match_filename_regex = re.compile(match_filename)
     perf_stats = []
     with TemporaryDirectory() as temp_dir:
@@ -1,16 +1,18 @@
+from __future__ import annotations
+
 import argparse
 import datetime
 import json
 import os
-
 import time
 import urllib.parse
-from typing import Any, Callable, cast, Dict, List, Optional, Set
+from typing import Any, Callable, cast, Dict, List
 from urllib.error import HTTPError
 from urllib.request import Request, urlopen
 
 from tools.stats.upload_stats_lib import upload_to_s3
 
+
 FILTER_OUT_USERS = {
     "pytorchmergebot",
     "facebook-github-bot",
@@ -23,9 +25,9 @@ FILTER_OUT_USERS = {
 
 def _fetch_url(
     url: str,
-    headers: Dict[str, str],
-    data: Optional[Dict[str, Any]] = None,
-    method: Optional[str] = None,
+    headers: dict[str, str],
+    data: dict[str, Any] | None = None,
+    method: str | None = None,
     reader: Callable[[Any], Any] = lambda x: x.read(),
 ) -> Any:
     token = os.environ.get("GITHUB_TOKEN")
@@ -49,9 +51,9 @@ def _fetch_url(
 
 def fetch_json(
     url: str,
-    params: Optional[Dict[str, Any]] = None,
-    data: Optional[Dict[str, Any]] = None,
-) -> List[Dict[str, Any]]:
+    params: dict[str, Any] | None = None,
+    data: dict[str, Any] | None = None,
+) -> list[dict[str, Any]]:
     headers = {"Accept": "application/vnd.github.v3+json"}
     if params is not None and len(params) > 0:
         url += "?" + "&".join(
@@ -65,16 +67,16 @@ def fetch_json(
 
 def get_external_pr_data(
     start_date: datetime.date, end_date: datetime.date, period_length: int = 1
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     pr_info = []
     period_begin_date = start_date
 
     pr_count = 0
-    users: Set[str] = set()
+    users: set[str] = set()
     while period_begin_date < end_date:
         period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1)
         page = 1
-        responses: List[Dict[str, Any]] = []
+        responses: list[dict[str, Any]] = []
         while len(responses) > 0 or page == 1:
             response = cast(
                 Dict[str, Any],
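This file keeps `Dict` and `List` in its `typing` import, and the `Dict[str, Any]` passed to `cast(...)` at the end of the hunk is deliberately not rewritten: `cast()` receives its type argument as a normal runtime value, outside any annotation, so the deferred-evaluation trick does not help there. A sketch of why (the helper name is illustrative):

from __future__ import annotations

from typing import Any, Dict, cast


def first_response(responses: list[dict[str, Any]]) -> dict[str, Any]:
    # cast()'s first argument is evaluated eagerly; on Python 3.8,
    # dict[str, Any] here would raise "type object is not subscriptable".
    return cast(Dict[str, Any], responses[0])
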
@@ -1,13 +1,15 @@
+from __future__ import annotations
+
 import datetime
 import inspect
 import os
 import time
 import uuid
 
 from decimal import Decimal
-from typing import Any, Dict
+from typing import Any
 from warnings import warn
 
 
 # boto3 is an optional dependency. If it's not installed,
 # we'll just not emit the metrics.
 # Keeping this logic here so that callers don't have to
@@ -65,7 +67,7 @@ class EnvVarMetric:
         return value
 
 
-global_metrics: Dict[str, Any] = {}
+global_metrics: dict[str, Any] = {}
 
 
 def add_global_metric(metric_name: str, metric_value: Any) -> None:
@@ -79,7 +81,7 @@ def add_global_metric(metric_name: str, metric_value: Any) -> None:
 
 def emit_metric(
     metric_name: str,
-    metrics: Dict[str, Any],
+    metrics: dict[str, Any],
 ) -> None:
     """
     Upload a metric to DynamoDB (and from there, Rockset).
@@ -174,7 +176,7 @@ def emit_metric(
         print(f"Not emitting metrics for {metric_name}. Boto wasn't imported.")
 
 
-def _convert_float_values_to_decimals(data: Dict[str, Any]) -> Dict[str, Any]:
+def _convert_float_values_to_decimals(data: dict[str, Any]) -> dict[str, Any]:
     # Attempt to recurse
     def _helper(o: Any) -> Any:
         if isinstance(o, float):
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import argparse
 import json
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, List
+from typing import Any
 
 from tools.stats.upload_stats_lib import (
     download_s3_artifacts,
@@ -13,7 +15,7 @@ from tools.stats.upload_stats_lib import (
 
 def get_sccache_stats(
     workflow_run_id: int, workflow_run_attempt: int
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     with TemporaryDirectory() as temp_dir:
         print("Using temporary directory:", temp_dir)
         os.chdir(temp_dir)
@@ -1,16 +1,18 @@
+from __future__ import annotations
+
 import gzip
 import io
 import json
 import os
 import zipfile
 
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 import boto3  # type: ignore[import]
 import requests
 import rockset  # type: ignore[import]
 
 
 PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
 S3_RESOURCE = boto3.resource("s3")
@@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
 BATCH_SIZE = 5000
 
 
-def _get_request_headers() -> Dict[str, str]:
+def _get_request_headers() -> dict[str, str]:
     return {
         "Accept": "application/vnd.github.v3+json",
         "Authorization": "token " + os.environ["GITHUB_TOKEN"],
     }
 
 
-def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
+def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
     """Get all workflow artifacts with 'test-report' in the name."""
     response = requests.get(
         f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
@@ -78,7 +80,7 @@ def _download_artifact(
 
 def download_s3_artifacts(
     prefix: str, workflow_run_id: int, workflow_run_attempt: int
-) -> List[Path]:
+) -> list[Path]:
     bucket = S3_RESOURCE.Bucket("gha-artifacts")
     objs = bucket.objects.filter(
         Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
@@ -104,7 +106,7 @@ def download_s3_artifacts(
 
 def download_gha_artifacts(
     prefix: str, workflow_run_id: int, workflow_run_attempt: int
-) -> List[Path]:
+) -> list[Path]:
     artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
     paths = []
     for name, url in artifact_urls.items():
@@ -114,7 +116,7 @@ def download_gha_artifacts(
 
 def upload_to_rockset(
     collection: str,
-    docs: List[Any],
+    docs: list[Any],
     workspace: str = "commons",
     client: Any = None,
 ) -> None:
@@ -142,7 +144,7 @@ def upload_to_rockset(
 def upload_to_s3(
     bucket_name: str,
     key: str,
-    docs: List[Dict[str, Any]],
+    docs: list[dict[str, Any]],
 ) -> None:
     print(f"Writing {len(docs)} documents to S3")
     body = io.StringIO()
@@ -164,7 +166,7 @@ def upload_to_s3(
 def read_from_s3(
     bucket_name: str,
     key: str,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     print(f"Reading from s3://{bucket_name}/{key}")
     body = (
         S3_RESOURCE.Object(
@@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3(
     workflow_run_id: int,
     workflow_run_attempt: int,
     collection: str,
-    docs: List[Dict[str, Any]],
+    docs: list[dict[str, Any]],
 ) -> None:
     bucket_name = "ossci-raw-job-status"
     key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
@@ -220,7 +222,7 @@ def unzip(p: Path) -> None:
         zip.extractall(unzipped_dir)
 
 
-def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
+def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
     """
     Check if the test report is coming from rerun_disabled_tests workflow where
     each test is run multiple times
@@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
     )
 
 
-def get_job_id(report: Path) -> Optional[int]:
+def get_job_id(report: Path) -> int | None:
     # [Job id in artifacts]
     # Retrieve the job id from the report path. In our GHA workflows, we append
     # the job id to the end of the report name, so `report` looks like:
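A caveat that applies to all the `int | None` and `str | None` rewrites, such as `get_job_id` above: the new syntax is free only as long as nothing re-evaluates the annotation. Static checkers read the stored string directly, but runtime introspection with `typing.get_type_hints` would try to evaluate `int | None` and fail on Python < 3.10; that trade-off is acceptable for CI scripts that are only checked statically. A small demonstration (the function name is a hypothetical stand-in, not the real helper):

from __future__ import annotations

from pathlib import Path


def job_id_sketch(report: Path) -> int | None:
    return None


print(job_id_sketch.__annotations__["return"])  # the string "int | None"
# typing.get_type_hints(job_id_sketch)  # would raise TypeError on Python < 3.10
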
@@ -1,17 +1,19 @@
+from __future__ import annotations
+
 import argparse
 import ast
 import datetime
 import json
 import os
 import re
-from typing import Any, List, Union
+from typing import Any
 
 import rockset  # type: ignore[import]
 
 from tools.stats.upload_stats_lib import upload_to_s3
 
 
-def get_oncall_from_testfile(testfile: str) -> Union[List[str], None]:
+def get_oncall_from_testfile(testfile: str) -> list[str] | None:
     path = f"test/{testfile}"
     if not path.endswith(".py"):
         path += ".py"
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -5,7 +7,7 @@ import xml.etree.ElementTree as ET
|
||||||
from multiprocessing import cpu_count, Pool
|
from multiprocessing import cpu_count, Pool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from tools.stats.test_dashboard import upload_additional_info
|
from tools.stats.test_dashboard import upload_additional_info
|
||||||
from tools.stats.upload_stats_lib import (
|
from tools.stats.upload_stats_lib import (
|
||||||
|
|
@ -21,14 +23,14 @@ def parse_xml_report(
|
||||||
report: Path,
|
report: Path,
|
||||||
workflow_id: int,
|
workflow_id: int,
|
||||||
workflow_run_attempt: int,
|
     workflow_run_attempt: int,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """Convert a test report xml file into a JSON-serializable list of test cases."""
     print(f"Parsing {tag}s for test report: {report}")
 
     job_id = get_job_id(report)
     print(f"Found job id: {job_id}")
 
-    test_cases: List[Dict[str, Any]] = []
+    test_cases: list[dict[str, Any]] = []
 
     root = ET.parse(report)
     for test_case in root.iter(tag):
@@ -53,9 +55,9 @@ def parse_xml_report(
     return test_cases
 
 
-def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+def process_xml_element(element: ET.Element) -> dict[str, Any]:
     """Convert a test suite element into a JSON-serializable dict."""
-    ret: Dict[str, Any] = {}
+    ret: dict[str, Any] = {}
 
     # Convert attributes directly into dict elements.
     # e.g.
@@ -110,7 +112,7 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
     return ret
 
 
-def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
+def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]:
     with TemporaryDirectory() as temp_dir:
         print("Using temporary directory:", temp_dir)
         os.chdir(temp_dir)
@@ -146,7 +148,7 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
 
 def get_tests_for_circleci(
     workflow_run_id: int, workflow_run_attempt: int
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     # Parse the reports and transform them to JSON
     test_cases = []
     for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
@@ -159,13 +161,13 @@ def get_tests_for_circleci(
     return test_cases
 
 
-def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
     """Group test cases by classname, file, and job_id. We perform the aggregation
     manually instead of using the `test-suite` XML tag because xmlrunner does
     not produce reliable output for it.
     """
 
-    def get_key(test_case: Dict[str, Any]) -> Any:
+    def get_key(test_case: dict[str, Any]) -> Any:
         return (
             test_case.get("file"),
             test_case.get("classname"),
@@ -176,7 +178,7 @@ def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any
             test_case["invoking_file"],
         )
 
-    def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]:
+    def init_value(test_case: dict[str, Any]) -> dict[str, Any]:
         return {
             "file": test_case.get("file"),
             "classname": test_case.get("classname"),
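
Throughout these hunks the typing.List/typing.Dict spellings become builtin list/dict generics. That is safe on older interpreters only because each touched module now starts with `from __future__ import annotations` (PEP 563), which stores annotations as strings and never evaluates them at definition time, so PEP 585 builtin generics are legal in annotation position even before Python 3.9. A minimal, self-contained sketch of the mechanism (the function and names here are illustrative, not taken from the PR):

    from __future__ import annotations  # must be the first statement

    from typing import Any


    def parse_cases(report: str) -> list[dict[str, Any]]:
        # With postponed annotations, this function's annotations are kept
        # as strings, so list[...]/dict[...] need not be subscriptable at
        # runtime for the definition to succeed.
        return [{"file": report, "status": "ok"}]


    # Annotations remain inspectable, just as strings:
    print(parse_cases.__annotations__["return"])  # "list[dict[str, Any]]"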

@@ -4,6 +4,7 @@ import sys
 from tools.stats.test_dashboard import upload_additional_info
 from tools.stats.upload_test_stats import get_tests
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
     parser.add_argument(

@@ -5,7 +5,6 @@ import argparse
 import json
 import unittest
 from collections import defaultdict
-
 from unittest.mock import Mock, patch
 
 from gen_operators_yaml import (
@@ -43,10 +42,10 @@ def _mock_load_op_dep_graph():
 
 
 class GenOperatorsYAMLTest(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         pass
 
-    def test_filter_creation(self):
+    def test_filter_creation(self) -> None:
         filter_func = make_filter_from_options(
             model_name="abc",
             model_versions=["100", "101"],
@@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
             len(filtered_configs) == 2
         ), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
 
-    def test_verification_success(self):
+    def test_verification_success(self) -> None:
         filter_func = make_filter_from_options(
             model_name="abc",
             model_versions=["100", "101"],
@@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
                 "expected verify_all_specified_present to succeed instead it raised an exception"
             )
 
-    def test_verification_fail(self):
+    def test_verification_fail(self) -> None:
         config = [
             {
                 "model": {
@@ -229,7 +228,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
     )
     def test_fill_output_with_arguments_not_include_all_overloads(
         self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
-    ):
+    ) -> None:
         parser = argparse.ArgumentParser(description="Generate used operators YAML")
         options = get_parser_options(parser)
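
The recurring `def setUp(self):` to `def setUp(self) -> None:` edits are not cosmetic: by default mypy skips the bodies of functions that carry no annotations at all, so an explicit `-> None` return type opts each test method into type checking. A small sketch of the pattern (the class and test names are made up):

    import unittest


    class ExampleTest(unittest.TestCase):
        # Without any annotation mypy treats this body as untyped and skips
        # it (unless --check-untyped-defs is on); "-> None" opts it in.
        def test_something(self) -> None:
            value: int = 1
            self.assertEqual(value, 1)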

@@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads
 
 
 class GenOplistTest(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         pass
 
-    def test_throw_if_any_op_includes_overloads(self):
+    def test_throw_if_any_op_includes_overloads(self) -> None:
         selective_builder = MagicMock()
         selective_builder.operators = MagicMock()
         selective_builder.operators.items.return_value = [
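
The hunk above stubs a selective builder with `MagicMock`: attribute access on a mock yields another mock, and assigning to `return_value` fixes what a call returns. A standalone sketch of the same trick (the operator name and payload are placeholders, not from the test):

    from unittest.mock import MagicMock

    selective_builder = MagicMock()
    selective_builder.operators = MagicMock()
    # items() can be canned to return a plain list of (name, op) pairs,
    # so code that iterates over operators sees controlled test data.
    selective_builder.operators.items.return_value = [("aten::add", object())]

    for name, op in selective_builder.operators.items():
        print(name)  # aten::add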

@@ -1,10 +1,12 @@
 # For testing specific heuristics
+from __future__ import annotations
+
 import io
 import json
 import pathlib
 import sys
 import unittest
-from typing import Any, Dict, List, Set
+from typing import Any
 from unittest import mock
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT))
 HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."
 
 
-def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
+def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
     file_object = io.StringIO()
     json.dump(contents, file_object)
     file_object.seek(0)
     return file_object
 
 
-def gen_historical_class_failures() -> Dict[str, Dict[str, float]]:
+def gen_historical_class_failures() -> dict[str, dict[str, float]]:
     return {
         "file1": {
             "test1::classA": 0.5,
@@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD):
     )
     def test_get_prediction_confidence(
         self,
-        historical_class_failures: Dict[str, Dict[str, float]],
-        changed_files: List[str],
+        historical_class_failures: dict[str, dict[str, float]],
+        changed_files: list[str],
     ) -> None:
         tests_to_prioritize = ALL_TESTS
@@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD):
 class TestParsePrevTests(TestTD):
     @mock.patch("os.path.exists", return_value=False)
     def test_cache_does_not_exist(self, mock_exists: Any) -> None:
-        expected_failing_test_files: Set[str] = set()
+        expected_failing_test_files: set[str] = set()
 
         found_tests = get_previous_failures()
@@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD):
     @mock.patch("os.path.exists", return_value=True)
     @mock.patch("builtins.open", return_value=mocked_file({"": True}))
     def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
-        expected_failing_test_files: Set[str] = set()
+        expected_failing_test_files: set[str] = set()
 
         found_tests = get_previous_failures()
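
Note that the future import only changes annotation positions; expressions evaluated at runtime are unaffected. A short sketch of the distinction (the names are invented for illustration):

    from __future__ import annotations

    # Annotation position: fine on 3.7+, since it is never evaluated.
    failures: dict[str, dict[str, float]] = {"file1": {"test1::classA": 0.5}}

    # Runtime position: dict[str, float] is evaluated immediately, so this
    # expression still requires Python 3.9+ (PEP 585 at runtime).
    alias = dict[str, float]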

@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import pathlib
 import sys
 import unittest
-from typing import Any, Dict, List
+from typing import Any
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
 sys.path.append(str(REPO_ROOT))
@@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT))
 
 class TestTD(unittest.TestCase):
     def assert_test_scores_almost_equal(
-        self, d1: Dict[TestRun, float], d2: Dict[TestRun, float]
+        self, d1: dict[TestRun, float], d2: dict[TestRun, float]
     ) -> None:
         # Check that dictionaries are the same, except for floating point errors
         self.assertEqual(set(d1.keys()), set(d2.keys()))
@@ -24,7 +26,7 @@ class TestTD(unittest.TestCase):
     # Create a dummy heuristic class
     class Heuristic(interface.HeuristicInterface):
         def get_prediction_confidence(
-            self, tests: List[str]
+            self, tests: list[str]
         ) -> interface.TestPrioritizations:
             # Return junk
             return interface.TestPrioritizations([], {})
@@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD):
 class TestAggregatedHeuristics(TestTD):
     def check(
         self,
-        tests: List[str],
-        test_prioritizations: List[Dict[TestRun, float]],
-        expected: Dict[TestRun, float],
+        tests: list[str],
+        test_prioritizations: list[dict[TestRun, float]],
+        expected: dict[TestRun, float],
     ) -> None:
         aggregated_heuristics = interface.AggregatedHeuristics(tests)
         for i, test_prioritization in enumerate(test_prioritizations):
@@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD):
         stats3 = aggregator.get_test_stats(TestRun("test3"))
         stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
 
-        def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
+        def assert_valid_dict(dict_contents: dict[str, Any]) -> None:
             for key, value in dict_contents.items():
                 self.assertTrue(isinstance(key, str))
                 self.assertTrue(

@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import pathlib
 import sys
 import unittest
-from typing import Any, Dict
+from typing import Any
 
 
 REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT))
 
 class TestHeuristicsUtils(unittest.TestCase):
     def assertDictAlmostEqual(
-        self, first: Dict[TestRun, Any], second: Dict[TestRun, Any]
+        self, first: dict[TestRun, Any], second: dict[TestRun, Any]
     ) -> None:
         self.assertEqual(first.keys(), second.keys())
         for key in first.keys():
             self.assertAlmostEqual(first[key], second[key])
 
     def test_normalize_ratings(self) -> None:
-        ratings: Dict[TestRun, float] = {
+        ratings: dict[TestRun, float] = {
             TestRun("test1"): 1,
             TestRun("test2"): 2,
             TestRun("test3"): 4,

@@ -1,12 +1,13 @@
+from __future__ import annotations
+
 import contextlib
 import os
 import typing
 import unittest
 import unittest.mock
-from typing import Iterator, Optional, Sequence
+from typing import Iterator, Sequence
 
 import tools.setup_helpers.cmake
-
 import tools.setup_helpers.env  # noqa: F401 unused but resolves circular import
@@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase):
 
 
 @contextlib.contextmanager
-def env_var(key: str, value: Optional[str]) -> Iterator[None]:
+def env_var(key: str, value: str | None) -> Iterator[None]:
     """Sets/clears an environment variable within a Python context."""
     # Get the previous value and then override it.
     previous_value = os.environ.get(key)
@@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]:
         set_env_var(key, previous_value)
 
 
-def set_env_var(key: str, value: Optional[str]) -> None:
+def set_env_var(key: str, value: str | None) -> None:
     """Sets/clears an environment variable."""
     if value is None:
         os.environ.pop(key, None)
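
The `env_var` helper in the hunk above scopes an environment-variable override to a `with` block and restores the previous value on exit. A self-contained sketch of the same idea, written to match the docstring's described behavior (this is a re-implementation for illustration, not the module's own code):

    from __future__ import annotations

    import contextlib
    import os
    from typing import Iterator


    @contextlib.contextmanager
    def env_var(key: str, value: str | None) -> Iterator[None]:
        # Save the old value, override (or clear) it, restore on exit.
        previous = os.environ.get(key)
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value
        try:
            yield
        finally:
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous


    with env_var("CMAKE_GENERATOR", "Ninja"):
        assert os.environ["CMAKE_GENERATOR"] == "Ninja"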

@@ -1,14 +1,13 @@
+from __future__ import annotations
+
 import dataclasses
 import typing
 import unittest
 from collections import defaultdict
-from typing import Dict, List
 
 import yaml
 
 from tools.autograd import gen_autograd_functions, load_derivatives
-
-import torchgen.model
 from torchgen import dest
 from torchgen.api.types import CppSignatureGroup, DispatcherSignature
 from torchgen.context import native_function_manager
@@ -22,6 +21,7 @@ from torchgen.model import (
     BackendIndex,
     BackendMetadata,
     DispatchKey,
+    FunctionSchema,
     Location,
     NativeFunction,
     OperatorName,
@@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
 
 class TestCreateDerivative(unittest.TestCase):
     def test_named_grads(self) -> None:
-        schema = torchgen.model.FunctionSchema.parse(
+        schema = FunctionSchema.parse(
             "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
         )
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase):
 
     def test_non_differentiable_output(self) -> None:
         specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
-        schema = torchgen.model.FunctionSchema.parse(specification)
+        schema = FunctionSchema.parse(specification)
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         _, differentiability_info = load_derivatives.create_differentiability_info(
@@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase):
         )
 
     def test_indexed_grads(self) -> None:
-        schema = torchgen.model.FunctionSchema.parse(
+        schema = FunctionSchema.parse(
             "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
         )
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase):
 
     def test_named_grads_and_indexed_grads(self) -> None:
         specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
-        schema = torchgen.model.FunctionSchema.parse(specification)
+        schema = FunctionSchema.parse(specification)
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         with self.assertRaisesRegex(
@@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase):
 class TestGenAutogradFunctions(unittest.TestCase):
     def test_non_differentiable_output_invalid_type(self) -> None:
         specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
-        schema = torchgen.model.FunctionSchema.parse(specification)
+        schema = FunctionSchema.parse(specification)
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         _, differentiability_info = load_derivatives.create_differentiability_info(
@@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
 
     def test_non_differentiable_output_output_differentiability(self) -> None:
         specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
-        schema = torchgen.model.FunctionSchema.parse(specification)
+        schema = FunctionSchema.parse(specification)
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         _, differentiability_info = load_derivatives.create_differentiability_info(
@@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
 
     def test_register_bogus_dispatch_key(self) -> None:
         specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
-        schema = torchgen.model.FunctionSchema.parse(specification)
+        schema = FunctionSchema.parse(specification)
         native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
 
         with self.assertRaisesRegex(
@@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase):
 class TestGenSchemaRegistration(unittest.TestCase):
     def setUp(self) -> None:
         self.selector = SelectiveBuilder.get_nop_selector()
-        self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml(
+        self.custom_native_function, _ = NativeFunction.from_yaml(
             {"func": "custom::func() -> bool"},
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
         (
             self.fragment_custom_native_function,
             _,
-        ) = torchgen.model.NativeFunction.from_yaml(
+        ) = NativeFunction.from_yaml(
             {"func": "quantized_decomposed::func() -> bool"},
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
@@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) {
         )
 
     def test_3_namespaces_schema_registration_code_valid(self) -> None:
-        custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
+        custom2_native_function, _ = NativeFunction.from_yaml(
             {"func": "custom2::func() -> bool"},
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
         (
@@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
     def setUp(self) -> None:
         self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
             {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
         self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
@@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
                 "func": "op_2() -> bool",
                 "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
             },
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
 
-        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
+        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
             DispatchKey.CPU: {},
             DispatchKey.QuantizedCPU: {},
         }
@@ -382,9 +382,9 @@ TORCH_API bool kernel_1();
 # Test for native_function_generation
 class TestNativeFunctionGeneratrion(unittest.TestCase):
     def setUp(self) -> None:
-        self.native_functions: List[NativeFunction] = []
-        self.backend_indices: Dict[
-            DispatchKey, Dict[OperatorName, BackendMetadata]
+        self.native_functions: list[NativeFunction] = []
+        self.backend_indices: dict[
+            DispatchKey, dict[OperatorName, BackendMetadata]
         ] = defaultdict(dict)
         yaml_entry = """
 - func: op(Tensor self) -> Tensor
@@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
                 "dispatch": {"CPU": "kernel_1"},
                 "autogen": "op_2.out",
             },
-            loc=torchgen.model.Location(__file__, 1),
+            loc=Location(__file__, 1),
             valid_tags=set(),
         )
         BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
@@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
 # Test for static_dispatch
 class TestStaticDispatchGeneratrion(unittest.TestCase):
     def setUp(self) -> None:
-        self.backend_indices: Dict[
-            DispatchKey, Dict[OperatorName, BackendMetadata]
+        self.backend_indices: dict[
+            DispatchKey, dict[OperatorName, BackendMetadata]
         ] = defaultdict(dict)
         yaml_entry = """
 - func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase):
 
 # Represents the most basic NativeFunction. Use dataclasses.replace()
 # to edit for use.
-DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
+DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
     {"func": "func() -> bool"},
-    loc=torchgen.model.Location(__file__, 1),
+    loc=Location(__file__, 1),
     valid_tags=set(),
 )
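
These tests build one frozen `DEFAULT_NATIVE_FUNCTION` and derive per-case variants with `dataclasses.replace`, which copies a frozen instance while overriding only the named fields. A toy sketch of the pattern, with an invented `Spec` dataclass standing in for `NativeFunction`:

    from __future__ import annotations

    import dataclasses


    @dataclasses.dataclass(frozen=True)
    class Spec:  # stand-in for a frozen model class like NativeFunction
        func: str
        dispatch: dict[str, str] = dataclasses.field(default_factory=dict)


    DEFAULT_SPEC = Spec(func="func() -> bool")

    # replace() copies the frozen instance with selected fields overridden,
    # so each test starts from one shared default without mutating it.
    case = dataclasses.replace(DEFAULT_SPEC, func="op(Tensor self) -> Tensor")
    print(case.func, DEFAULT_SPEC.func)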

@@ -1,4 +1,6 @@
-from typing import Any, List
+from __future__ import annotations
+
+from typing import Any
 from unittest import main, TestCase
 
 from tools.alerts.create_alerts import filter_job_names, JobStatus
@@ -38,7 +40,7 @@ MOCK_TEST_DATA = [
 class TestGitHubPR(TestCase):
     # Should fail when jobs are ? ? Fail Fail
     def test_alert(self) -> None:
-        modified_data: List[Any] = [{}]
+        modified_data: list[Any] = [{}]
         modified_data.append({})
         modified_data.extend(MOCK_TEST_DATA)
         status = JobStatus(JOB_NAME, modified_data)

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import tempfile
 import unittest
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import ANY, Mock, patch
 
 import expecttest
@@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction
 from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import FileManager
 
+
 SPACES = " "
 
 
-def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction:
+def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
     native_function, _ = NativeFunction.from_yaml(
         yaml_obj,
         loc=Location(__file__, 1),
@@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
     """
 
     def _test_function_schema_generates_correct_kernel(
-        self, obj: Dict[str, Any], expected: str
+        self, obj: dict[str, Any], expected: str
     ) -> None:
         func = _get_native_function_from_yaml(obj)

@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import os
 import tempfile
 import unittest
-from typing import Dict
 
 import yaml
 
 from torchgen.executorch.model import ETKernelIndex, ETKernelKey
 from torchgen.gen import LineLoader
-
 from torchgen.gen_executorch import (
     ComputeCodegenUnboxedKernels,
     gen_functions_declarations,
@@ -24,6 +24,7 @@ from torchgen.model import (
 )
 from torchgen.selective_build.selector import SelectiveBuilder
 
+
 TEST_YAML = """
 - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
             valid_tags=set(),
         )
 
-        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
+        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
             DispatchKey.CPU: {},
             DispatchKey.QuantizedCPU: {},
         }

@@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature
 from torchgen.local import parametrize
 from torchgen.model import Location, NativeFunction
 
+
 DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
     {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
     loc=Location(__file__, 1),

@@ -1,9 +1,10 @@
 # Owner(s): ["module: codegen"]
 
+from __future__ import annotations
+
 import os
 import tempfile
 import unittest
-from typing import Optional
 
 import expecttest
@@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase):
             run(fp.name, "", True)
 
     def get_errors_from_gen_backend_stubs(
-        self, yaml_str: str, *, kernels_str: Optional[str] = None
+        self, yaml_str: str, *, kernels_str: str | None = None
    ) -> str:
         with tempfile.NamedTemporaryFile(mode="w") as fp:
             fp.write(yaml_str)
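
The `Optional[str]` to `str | None` change is PEP 604 union syntax; under `from __future__ import annotations` it is legal in annotations even on interpreters older than 3.10, because the annotation is never evaluated. A minimal sketch (the function and parameter names are invented):

    from __future__ import annotations


    def render(template: str, *, suffix: str | None = None) -> str:
        # "str | None" is only ever a string annotation here (PEP 563), so
        # the PEP 604 syntax works even where types.UnionType is unavailable.
        return template if suffix is None else template + suffix


    print(render("stub", suffix=".cpp"))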

@@ -1,7 +1,7 @@
 import unittest
 
-from torchgen.selective_build.operator import *  # noqa: F403
 from torchgen.model import Location, NativeFunction
+from torchgen.selective_build.operator import *  # noqa: F403
 from torchgen.selective_build.selector import (
     combine_selective_builders,
     SelectiveBuilder,
@@ -9,7 +9,7 @@ from torchgen.selective_build.selector import (
 
 
 class TestSelectiveBuild(unittest.TestCase):
-    def test_selective_build_operator(self):
+    def test_selective_build_operator(self) -> None:
         op = SelectiveBuildOperator(
             "aten::add.int",
             is_root_operator=True,
@@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase):
         self.assertFalse(op.is_used_for_training)
         self.assertFalse(op.include_all_overloads)
 
-    def test_selector_factory(self):
+    def test_selector_factory(self) -> None:
         yaml_config_v1 = """
 debug_info:
   - model1@v100
@@ -132,7 +132,7 @@ operators:
             selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
         )
 
-    def test_operator_combine(self):
+    def test_operator_combine(self) -> None:
         op1 = SelectiveBuildOperator(
             "aten::add.int",
             is_root_operator=True,
@@ -177,7 +177,7 @@ operators:
 
         self.assertRaises(Exception, gen_new_op)
 
-    def test_training_op_fetch(self):
+    def test_training_op_fetch(self) -> None:
         yaml_config = """
 operators:
   aten::add.int:
@@ -194,7 +194,7 @@ operators:
         self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
         self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
 
-    def test_kernel_dtypes(self):
+    def test_kernel_dtypes(self) -> None:
         yaml_config = """
 kernel_metadata:
   add_kernel:
@@ -221,7 +221,7 @@ kernel_metadata:
         self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
         self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
 
-    def test_merge_kernel_dtypes(self):
+    def test_merge_kernel_dtypes(self) -> None:
         yaml_config1 = """
 kernel_metadata:
   add_kernel:
@@ -266,7 +266,7 @@ kernel_metadata:
         self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
         self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
 
-    def test_all_kernel_dtypes_selected(self):
+    def test_all_kernel_dtypes_selected(self) -> None:
         yaml_config = """
 include_all_non_op_selectives: True
 """
@@ -279,7 +279,7 @@ include_all_non_op_selectives: True
         self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
         self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
 
-    def test_custom_namespace_selected_correctly(self):
+    def test_custom_namespace_selected_correctly(self) -> None:
         yaml_config = """
 operators:
   aten::add.int:
@@ -301,7 +301,7 @@ operators:
 
 
 class TestExecuTorchSelectiveBuild(unittest.TestCase):
-    def test_et_kernel_selected(self):
+    def test_et_kernel_selected(self) -> None:
         yaml_config = """
 et_kernel_metadata:
   aten::add.out:
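
The selective-build tests above feed YAML config strings to a selector and then query selection predicates such as `is_kernel_dtype_selected`. A rough stand-in for that flow using `yaml.safe_load` directly; the real `SelectiveBuilder` construction is not shown in this excerpt, so this toy function only mimics the shape of the query:

    import yaml

    yaml_config = """
    kernel_metadata:
      add_kernel:
        - int8
        - int32
    """

    # Parse the config and answer the same kind of question the tests ask.
    config = yaml.safe_load(yaml_config)


    def is_kernel_dtype_selected(kernel: str, dtype: str) -> bool:
        return dtype in config.get("kernel_metadata", {}).get(kernel, [])


    assert is_kernel_dtype_selected("add_kernel", "int8")
    assert not is_kernel_dtype_selected("add_kernel", "float")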

Some files were not shown because too many files have changed in this diff.