PEP585 update - .ci android aten (#145177)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145177
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2025-01-20 16:32:44 -08:00 committed by PyTorch MergeBot
parent 00ffeca1b1
commit b5655d9821
6 changed files with 42 additions and 43 deletions

View File

@ -4,10 +4,9 @@
import os
import shutil
from subprocess import check_call, check_output
from typing import List
def list_dir(path: str) -> List[str]:
def list_dir(path: str) -> list[str]:
"""'
Helper for getting paths for Python
"""

View File

@ -12,7 +12,7 @@ import os
import subprocess
import sys
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
import boto3
@ -29,7 +29,7 @@ ubuntu18_04_ami = os_amis["ubuntu18_04"]
ubuntu20_04_ami = os_amis["ubuntu20_04"]
def compute_keyfile_path(key_name: Optional[str] = None) -> Tuple[str, str]:
def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
if key_name is None:
key_name = os.getenv("AWS_KEY_NAME")
if key_name is None:
@ -98,7 +98,7 @@ class RemoteHost:
self.keyfile_path = keyfile_path
self.login_name = login_name
def _gen_ssh_prefix(self) -> List[str]:
def _gen_ssh_prefix(self) -> list[str]:
return [
"ssh",
"-o",
@ -110,13 +110,13 @@ class RemoteHost:
]
@staticmethod
def _split_cmd(args: Union[str, List[str]]) -> List[str]:
def _split_cmd(args: Union[str, list[str]]) -> list[str]:
return args.split() if isinstance(args, str) else args
def run_ssh_cmd(self, args: Union[str, List[str]]) -> None:
def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
def check_ssh_output(self, args: Union[str, List[str]]) -> str:
def check_ssh_output(self, args: Union[str, list[str]]) -> str:
return subprocess.check_output(
self._gen_ssh_prefix() + self._split_cmd(args)
).decode("utf-8")
@ -159,7 +159,7 @@ class RemoteHost:
def using_docker(self) -> bool:
return self.container_id is not None
def run_cmd(self, args: Union[str, List[str]]) -> None:
def run_cmd(self, args: Union[str, list[str]]) -> None:
if not self.using_docker():
return self.run_ssh_cmd(args)
assert self.container_id is not None
@ -180,7 +180,7 @@ class RemoteHost:
if rc != 0:
raise subprocess.CalledProcessError(rc, docker_cmd)
def check_output(self, args: Union[str, List[str]]) -> str:
def check_output(self, args: Union[str, list[str]]) -> str:
if not self.using_docker():
return self.check_ssh_output(args)
assert self.container_id is not None
@ -232,7 +232,7 @@ class RemoteHost:
)
self.download_file(remote_file, local_file)
def list_dir(self, path: str) -> List[str]:
def list_dir(self, path: str) -> list[str]:
return self.check_output(["ls", "-1", path]).split("\n")
@ -360,7 +360,7 @@ def checkout_repo(
branch: str = "main",
url: str,
git_clone_flags: str,
mapping: Dict[str, Tuple[str, str]],
mapping: dict[str, tuple[str, str]],
) -> Optional[str]:
for prefix in mapping:
if not branch.startswith(prefix):
@ -683,7 +683,7 @@ def build_domains(
branch: str = "main",
use_conda: bool = True,
git_clone_flags: str = "",
) -> Tuple[str, str, str, str]:
) -> tuple[str, str, str, str]:
vision_wheel_name = build_torchvision(
host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
)
@ -710,7 +710,7 @@ def start_build(
pytorch_build_number: Optional[str] = None,
shallow_clone: bool = True,
enable_mkldnn: bool = False,
) -> Tuple[str, str, str, str, str]:
) -> tuple[str, str, str, str, str]:
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
if host.using_docker() and not use_conda:
print("Auto-selecting conda option for docker images")

View File

@ -6,7 +6,7 @@ import itertools
import os
import re
from pathlib import Path
from typing import Any, List, Tuple
from typing import Any
# We also check that there are [not] cxx11 symbols in libtorch
@ -46,17 +46,17 @@ LIBTORCH_PRE_CXX11_PATTERNS = _apply_libtorch_symbols(PRE_CXX11_SYMBOLS)
@functools.lru_cache(100)
def get_symbols(lib: str) -> List[Tuple[str, str, str]]:
def get_symbols(lib: str) -> list[tuple[str, str, str]]:
from subprocess import check_output
lines = check_output(f'nm "{lib}"|c++filt', shell=True)
return [x.split(" ", 2) for x in lines.decode("latin1").split("\n")[:-1]]
def grep_symbols(lib: str, patterns: List[Any]) -> List[str]:
def grep_symbols(lib: str, patterns: list[Any]) -> list[str]:
def _grep_symbols(
symbols: List[Tuple[str, str, str]], patterns: List[Any]
) -> List[str]:
symbols: list[tuple[str, str, str]], patterns: list[Any]
) -> list[str]:
rc = []
for _s_addr, _s_type, s_name in symbols:
for pattern in patterns:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Optional
import torch
from torch import Tensor
@ -44,33 +44,33 @@ class Test(torch.jit.ScriptModule):
return input
@torch.jit.script_method
def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
return input
@torch.jit.script_method
def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
return input
@torch.jit.script_method
def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
return input
@torch.jit.script_method
def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
@torch.jit.script_method
def listBoolConjunction(self, input: List[bool]) -> bool:
def listBoolConjunction(self, input: list[bool]) -> bool:
res = True
for x in input:
res = res and x
return res
@torch.jit.script_method
def listBoolDisjunction(self, input: List[bool]) -> bool:
def listBoolDisjunction(self, input: list[bool]) -> bool:
res = False
for x in input:
res = res or x
@ -78,8 +78,8 @@ class Test(torch.jit.ScriptModule):
@torch.jit.script_method
def tupleIntSumReturnTuple(
self, input: Tuple[int, int, int]
) -> Tuple[Tuple[int, int, int], int]:
self, input: tuple[int, int, int]
) -> tuple[tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
@ -104,7 +104,7 @@ class Test(torch.jit.ScriptModule):
return torch.tensor([int(input.item())])[0]
@torch.jit.script_method
def testAliasWithOffset(self) -> List[Tensor]:
def testAliasWithOffset(self) -> list[Tensor]:
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a

View File

@ -5,7 +5,7 @@ import argparse
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import Optional
DTYPE_MAP = {
@ -61,7 +61,7 @@ class Kernel:
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
def get_all_kernels() -> List[Kernel]:
def get_all_kernels() -> list[Kernel]:
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
for direction in ["fwd", "bwd", "fwd_split"]:
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)

View File

@ -13,7 +13,7 @@ import collections
import itertools
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import Optional, TypeVar
DTYPES = {
@ -48,10 +48,10 @@ KERNEL_IMPL_TEMPLATE = """__global__ void __launch_bounds__(
@dataclass(order=True)
class FwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
sort_index: tuple[int, ...] = field(init=False, repr=False)
aligned: bool
dtype: str
sm_range: Tuple[int, int]
sm_range: tuple[int, int]
q: int
k: int
max_k: int
@ -114,8 +114,8 @@ class FwdKernel:
)
@classmethod
def get_all(cls) -> List["FwdKernel"]:
kernels: List[FwdKernel] = []
def get_all(cls) -> list["FwdKernel"]:
kernels: list[FwdKernel] = []
for aligned, dtype, (sm, sm_max) in itertools.product(
[True, False], DTYPES.keys(), zip(SM, SM[1:])
):
@ -145,8 +145,8 @@ class FwdKernel:
@dataclass(order=True)
class BwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
sm_range: Tuple[int, int]
sort_index: tuple[int, ...] = field(init=False, repr=False)
sm_range: tuple[int, int]
dtype: str
aligned: bool
apply_dropout: bool
@ -223,8 +223,8 @@ class BwdKernel:
)
@classmethod
def get_all(cls) -> List["BwdKernel"]:
kernels: List[BwdKernel] = []
def get_all(cls) -> list["BwdKernel"]:
kernels: list[BwdKernel] = []
for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
[True, False],
DTYPES.keys(),
@ -304,7 +304,7 @@ T = TypeVar("T", FwdKernel, BwdKernel)
def write_decl_impl(
kernels: List[T],
kernels: list[T],
family_name: str,
impl_file: str,
autogen_dir: Path,
@ -322,8 +322,8 @@ def write_decl_impl(
kernels.sort()
implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
cat_to_kernels: Dict[Tuple[str, int, int], List[T]] = collections.defaultdict(list)
implfile_to_kernels: dict[str, list[T]] = collections.defaultdict(list)
cat_to_kernels: dict[tuple[str, int, int], list[T]] = collections.defaultdict(list)
dispatch_all = ""
declarations = cpp_file_header + "#pragma once\n"