Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
PEP585 update - .ci android aten (#145177)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145177
Approved by: https://github.com/Skylion007
This commit is contained in:
parent 00ffeca1b1
commit b5655d9821
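The change is mechanical: the typing aliases deprecated by PEP 585 (List, Dict, Tuple, ...) are dropped in favor of subscripting the builtin containers, while names with no builtin equivalent (Optional, Union, TypeVar, Any) stay in typing. A minimal before/after sketch of the pattern, not taken from the diff below (the function names are illustrative); bare builtin subscriptions in annotations need Python 3.9+ at runtime, or from __future__ import annotations on older interpreters:

# Before PEP 585: container types imported from typing
from typing import Dict, List, Optional, Tuple


def index_files_old(paths: List[str], tag: Optional[str] = None) -> Dict[str, Tuple[str, int]]:
    ...


# After PEP 585: only Optional still comes from typing
from typing import Optional


def index_files_new(paths: list[str], tag: Optional[str] = None) -> dict[str, tuple[str, int]]:
    ...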
@@ -4,10 +4,9 @@
 import os
 import shutil
 from subprocess import check_call, check_output
-from typing import List


-def list_dir(path: str) -> List[str]:
+def list_dir(path: str) -> list[str]:
     """'
     Helper for getting paths for Python
     """
@@ -12,7 +12,7 @@ import os
 import subprocess
 import sys
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union

 import boto3

@@ -29,7 +29,7 @@ ubuntu18_04_ami = os_amis["ubuntu18_04"]
 ubuntu20_04_ami = os_amis["ubuntu20_04"]


-def compute_keyfile_path(key_name: Optional[str] = None) -> Tuple[str, str]:
+def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
     if key_name is None:
         key_name = os.getenv("AWS_KEY_NAME")
         if key_name is None:
@@ -98,7 +98,7 @@ class RemoteHost:
         self.keyfile_path = keyfile_path
         self.login_name = login_name

-    def _gen_ssh_prefix(self) -> List[str]:
+    def _gen_ssh_prefix(self) -> list[str]:
         return [
             "ssh",
             "-o",
@@ -110,13 +110,13 @@ class RemoteHost:
         ]

     @staticmethod
-    def _split_cmd(args: Union[str, List[str]]) -> List[str]:
+    def _split_cmd(args: Union[str, list[str]]) -> list[str]:
         return args.split() if isinstance(args, str) else args

-    def run_ssh_cmd(self, args: Union[str, List[str]]) -> None:
+    def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
         subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))

-    def check_ssh_output(self, args: Union[str, List[str]]) -> str:
+    def check_ssh_output(self, args: Union[str, list[str]]) -> str:
         return subprocess.check_output(
             self._gen_ssh_prefix() + self._split_cmd(args)
         ).decode("utf-8")
@@ -159,7 +159,7 @@ class RemoteHost:
     def using_docker(self) -> bool:
         return self.container_id is not None

-    def run_cmd(self, args: Union[str, List[str]]) -> None:
+    def run_cmd(self, args: Union[str, list[str]]) -> None:
         if not self.using_docker():
             return self.run_ssh_cmd(args)
         assert self.container_id is not None
@@ -180,7 +180,7 @@ class RemoteHost:
         if rc != 0:
             raise subprocess.CalledProcessError(rc, docker_cmd)

-    def check_output(self, args: Union[str, List[str]]) -> str:
+    def check_output(self, args: Union[str, list[str]]) -> str:
         if not self.using_docker():
             return self.check_ssh_output(args)
         assert self.container_id is not None
@@ -232,7 +232,7 @@ class RemoteHost:
         )
         self.download_file(remote_file, local_file)

-    def list_dir(self, path: str) -> List[str]:
+    def list_dir(self, path: str) -> list[str]:
         return self.check_output(["ls", "-1", path]).split("\n")


@@ -360,7 +360,7 @@ def checkout_repo(
     branch: str = "main",
     url: str,
     git_clone_flags: str,
-    mapping: Dict[str, Tuple[str, str]],
+    mapping: dict[str, tuple[str, str]],
 ) -> Optional[str]:
     for prefix in mapping:
         if not branch.startswith(prefix):
@@ -683,7 +683,7 @@ def build_domains(
     branch: str = "main",
     use_conda: bool = True,
     git_clone_flags: str = "",
-) -> Tuple[str, str, str, str]:
+) -> tuple[str, str, str, str]:
     vision_wheel_name = build_torchvision(
         host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
     )
@@ -710,7 +710,7 @@ def start_build(
     pytorch_build_number: Optional[str] = None,
     shallow_clone: bool = True,
     enable_mkldnn: bool = False,
-) -> Tuple[str, str, str, str, str]:
+) -> tuple[str, str, str, str, str]:
     git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
     if host.using_docker() and not use_conda:
         print("Auto-selecting conda option for docker images")
@@ -6,7 +6,7 @@ import itertools
 import os
 import re
 from pathlib import Path
-from typing import Any, List, Tuple
+from typing import Any


 # We also check that there are [not] cxx11 symbols in libtorch
@@ -46,17 +46,17 @@ LIBTORCH_PRE_CXX11_PATTERNS = _apply_libtorch_symbols(PRE_CXX11_SYMBOLS)


 @functools.lru_cache(100)
-def get_symbols(lib: str) -> List[Tuple[str, str, str]]:
+def get_symbols(lib: str) -> list[tuple[str, str, str]]:
     from subprocess import check_output

     lines = check_output(f'nm "{lib}"|c++filt', shell=True)
     return [x.split(" ", 2) for x in lines.decode("latin1").split("\n")[:-1]]


-def grep_symbols(lib: str, patterns: List[Any]) -> List[str]:
+def grep_symbols(lib: str, patterns: list[Any]) -> list[str]:
     def _grep_symbols(
-        symbols: List[Tuple[str, str, str]], patterns: List[Any]
-    ) -> List[str]:
+        symbols: list[tuple[str, str, str]], patterns: list[Any]
+    ) -> list[str]:
         rc = []
         for _s_addr, _s_type, s_name in symbols:
             for pattern in patterns:
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

 import torch
 from torch import Tensor
@@ -44,33 +44,33 @@ class Test(torch.jit.ScriptModule):
         return input

     @torch.jit.script_method
-    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
+    def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
         return input

     @torch.jit.script_method
-    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
+    def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
         return input

     @torch.jit.script_method
-    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
+    def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
         return input

     @torch.jit.script_method
-    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
+    def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
         sum = 0
         for x in input:
             sum += x
         return (input, sum)

     @torch.jit.script_method
-    def listBoolConjunction(self, input: List[bool]) -> bool:
+    def listBoolConjunction(self, input: list[bool]) -> bool:
         res = True
         for x in input:
             res = res and x
         return res

     @torch.jit.script_method
-    def listBoolDisjunction(self, input: List[bool]) -> bool:
+    def listBoolDisjunction(self, input: list[bool]) -> bool:
         res = False
         for x in input:
             res = res or x
@@ -78,8 +78,8 @@ class Test(torch.jit.ScriptModule):

     @torch.jit.script_method
     def tupleIntSumReturnTuple(
-        self, input: Tuple[int, int, int]
-    ) -> Tuple[Tuple[int, int, int], int]:
+        self, input: tuple[int, int, int]
+    ) -> tuple[tuple[int, int, int], int]:
         sum = 0
         for x in input:
             sum += x
@@ -104,7 +104,7 @@ class Test(torch.jit.ScriptModule):
         return torch.tensor([int(input.item())])[0]

     @torch.jit.script_method
-    def testAliasWithOffset(self) -> List[Tensor]:
+    def testAliasWithOffset(self) -> list[Tensor]:
         x = torch.tensor([100, 200])
         a = [x[0], x[1]]
         return a
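The hunks above touch a torch.jit.ScriptModule, so the rewrite also relies on TorchScript parsing PEP 585 builtin generics in script-method annotations the same way it parses the old typing aliases. A minimal sketch of the same idea with a free scripted function (assumes a PyTorch version recent enough to accept builtin generics, as the test module in this PR does):

import torch


@torch.jit.script
def list_int_sum(xs: list[int]) -> tuple[list[int], int]:
    # Same shape as listIntSumReturnTuple above, using builtin generics only.
    total = 0
    for x in xs:
        total += x
    return xs, total


print(list_int_sum([1, 2, 3]))  # ([1, 2, 3], 6)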
@@ -5,7 +5,7 @@ import argparse
 import itertools
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional


 DTYPE_MAP = {
@@ -61,7 +61,7 @@ class Kernel:
         return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"


-def get_all_kernels() -> List[Kernel]:
+def get_all_kernels() -> list[Kernel]:
     for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
         for direction in ["fwd", "bwd", "fwd_split"]:
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
@@ -13,7 +13,7 @@ import collections
 import itertools
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, TypeVar
+from typing import Optional, TypeVar


 DTYPES = {
@@ -48,10 +48,10 @@ KERNEL_IMPL_TEMPLATE = """__global__ void __launch_bounds__(

 @dataclass(order=True)
 class FwdKernel:
-    sort_index: Tuple[int, ...] = field(init=False, repr=False)
+    sort_index: tuple[int, ...] = field(init=False, repr=False)
     aligned: bool
     dtype: str
-    sm_range: Tuple[int, int]
+    sm_range: tuple[int, int]
     q: int
     k: int
     max_k: int
@@ -114,8 +114,8 @@ class FwdKernel:
         )

     @classmethod
-    def get_all(cls) -> List["FwdKernel"]:
-        kernels: List[FwdKernel] = []
+    def get_all(cls) -> list["FwdKernel"]:
+        kernels: list[FwdKernel] = []
         for aligned, dtype, (sm, sm_max) in itertools.product(
             [True, False], DTYPES.keys(), zip(SM, SM[1:])
         ):
@@ -145,8 +145,8 @@ class FwdKernel:

 @dataclass(order=True)
 class BwdKernel:
-    sort_index: Tuple[int, ...] = field(init=False, repr=False)
-    sm_range: Tuple[int, int]
+    sort_index: tuple[int, ...] = field(init=False, repr=False)
+    sm_range: tuple[int, int]
     dtype: str
     aligned: bool
     apply_dropout: bool
@@ -223,8 +223,8 @@ class BwdKernel:
         )

     @classmethod
-    def get_all(cls) -> List["BwdKernel"]:
-        kernels: List[BwdKernel] = []
+    def get_all(cls) -> list["BwdKernel"]:
+        kernels: list[BwdKernel] = []
         for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
             [True, False],
             DTYPES.keys(),
@@ -304,7 +304,7 @@ T = TypeVar("T", FwdKernel, BwdKernel)


 def write_decl_impl(
-    kernels: List[T],
+    kernels: list[T],
     family_name: str,
     impl_file: str,
     autogen_dir: Path,
@@ -322,8 +322,8 @@ def write_decl_impl(

     kernels.sort()

-    implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
-    cat_to_kernels: Dict[Tuple[str, int, int], List[T]] = collections.defaultdict(list)
+    implfile_to_kernels: dict[str, list[T]] = collections.defaultdict(list)
+    cat_to_kernels: dict[tuple[str, int, int], list[T]] = collections.defaultdict(list)

     dispatch_all = ""
     declarations = cpp_file_header + "#pragma once\n"
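Two details in the kernel-generator hunks are worth noting: on Python 3.9+ builtin subscriptions accept string forward references (as in list["FwdKernel"]) and TypeVars (as in dict[str, list[T]]) at runtime, and the collections.defaultdict(list) objects themselves stay untyped; the PEP 585 annotation only documents the intended contents. A small self-contained sketch with hypothetical names, not code from the PR:

import collections
from typing import TypeVar

T = TypeVar("T")


class Bucket:
    @classmethod
    def make(cls, n: int) -> list["Bucket"]:
        # Forward reference inside a builtin subscription, like list["FwdKernel"].
        return [cls() for _ in range(n)]


def group_by_key(pairs: list[tuple[str, T]]) -> dict[str, list[T]]:
    # The defaultdict is untyped at runtime; the annotation is documentation only.
    grouped: dict[str, list[T]] = collections.defaultdict(list)
    for key, value in pairs:
        grouped[key].append(value)
    return dict(grouped)


print(len(Bucket.make(3)))                           # 3
print(group_by_key([("a", 1), ("a", 2), ("b", 3)]))  # {'a': [1, 2], 'b': [3]}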