From b5655d9821b7214af200d0b8796a10ad34b85229 Mon Sep 17 00:00:00 2001
From: Aaron Orenstein
Date: Mon, 20 Jan 2025 16:32:44 -0800
Subject: [PATCH] PEP585 update - .ci android aten (#145177)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145177
Approved by: https://github.com/Skylion007
---
 .ci/aarch64_linux/aarch64_wheel_ci_build.py  |  3 +--
 .ci/aarch64_linux/build_aarch64_wheel.py     | 24 +++++++++----------
 .../smoke_test/check_binary_symbols.py       | 10 ++++----
 .../generate_test_torchscripts.py            | 20 ++++++++--------
 .../flash_attn/kernels/generate_kernels.py   |  4 ++--
 .../kernels/generate_kernels.py              | 24 +++++++++----------
 6 files changed, 42 insertions(+), 43 deletions(-)

diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py
index 8cc0a8aea29..8b68d4963ed 100755
--- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py
+++ b/.ci/aarch64_linux/aarch64_wheel_ci_build.py
@@ -4,10 +4,9 @@
 import os
 import shutil
 from subprocess import check_call, check_output
-from typing import List
 
 
-def list_dir(path: str) -> List[str]:
+def list_dir(path: str) -> list[str]:
     """'
     Helper for getting paths for Python
     """
diff --git a/.ci/aarch64_linux/build_aarch64_wheel.py b/.ci/aarch64_linux/build_aarch64_wheel.py
index e427fe924d0..874511b891e 100755
--- a/.ci/aarch64_linux/build_aarch64_wheel.py
+++ b/.ci/aarch64_linux/build_aarch64_wheel.py
@@ -12,7 +12,7 @@ import os
 import subprocess
 import sys
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import boto3
 
@@ -29,7 +29,7 @@ ubuntu18_04_ami = os_amis["ubuntu18_04"]
 ubuntu20_04_ami = os_amis["ubuntu20_04"]
 
 
-def compute_keyfile_path(key_name: Optional[str] = None) -> Tuple[str, str]:
+def compute_keyfile_path(key_name: Optional[str] = None) -> tuple[str, str]:
     if key_name is None:
         key_name = os.getenv("AWS_KEY_NAME")
         if key_name is None:
@@ -98,7 +98,7 @@ class RemoteHost:
         self.keyfile_path = keyfile_path
         self.login_name = login_name
 
-    def _gen_ssh_prefix(self) -> List[str]:
+    def _gen_ssh_prefix(self) -> list[str]:
         return [
             "ssh",
             "-o",
@@ -110,13 +110,13 @@ class RemoteHost:
         ]
 
     @staticmethod
-    def _split_cmd(args: Union[str, List[str]]) -> List[str]:
+    def _split_cmd(args: Union[str, list[str]]) -> list[str]:
         return args.split() if isinstance(args, str) else args
 
-    def run_ssh_cmd(self, args: Union[str, List[str]]) -> None:
+    def run_ssh_cmd(self, args: Union[str, list[str]]) -> None:
         subprocess.check_call(self._gen_ssh_prefix() + self._split_cmd(args))
 
-    def check_ssh_output(self, args: Union[str, List[str]]) -> str:
+    def check_ssh_output(self, args: Union[str, list[str]]) -> str:
         return subprocess.check_output(
             self._gen_ssh_prefix() + self._split_cmd(args)
         ).decode("utf-8")
@@ -159,7 +159,7 @@ class RemoteHost:
     def using_docker(self) -> bool:
         return self.container_id is not None
 
-    def run_cmd(self, args: Union[str, List[str]]) -> None:
+    def run_cmd(self, args: Union[str, list[str]]) -> None:
         if not self.using_docker():
             return self.run_ssh_cmd(args)
         assert self.container_id is not None
@@ -180,7 +180,7 @@ class RemoteHost:
         if rc != 0:
             raise subprocess.CalledProcessError(rc, docker_cmd)
 
-    def check_output(self, args: Union[str, List[str]]) -> str:
+    def check_output(self, args: Union[str, list[str]]) -> str:
         if not self.using_docker():
             return self.check_ssh_output(args)
         assert self.container_id is not None
@@ -232,7 +232,7 @@ class RemoteHost:
         )
         self.download_file(remote_file, local_file)
 
-    def list_dir(self, path: str) -> List[str]:
+    def list_dir(self, path: str) -> list[str]:
         return self.check_output(["ls", "-1", path]).split("\n")
 
 
@@ -360,7 +360,7 @@ def checkout_repo(
     branch: str = "main",
     url: str,
     git_clone_flags: str,
-    mapping: Dict[str, Tuple[str, str]],
+    mapping: dict[str, tuple[str, str]],
 ) -> Optional[str]:
     for prefix in mapping:
         if not branch.startswith(prefix):
@@ -683,7 +683,7 @@ def build_domains(
     branch: str = "main",
    use_conda: bool = True,
     git_clone_flags: str = "",
-) -> Tuple[str, str, str, str]:
+) -> tuple[str, str, str, str]:
     vision_wheel_name = build_torchvision(
         host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags
     )
@@ -710,7 +710,7 @@ def start_build(
     pytorch_build_number: Optional[str] = None,
     shallow_clone: bool = True,
     enable_mkldnn: bool = False,
-) -> Tuple[str, str, str, str, str]:
+) -> tuple[str, str, str, str, str]:
     git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
     if host.using_docker() and not use_conda:
         print("Auto-selecting conda option for docker images")
diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py
index e91d0f680f1..97d6482d63b 100755
--- a/.ci/pytorch/smoke_test/check_binary_symbols.py
+++ b/.ci/pytorch/smoke_test/check_binary_symbols.py
@@ -6,7 +6,7 @@ import itertools
 import os
 import re
 from pathlib import Path
-from typing import Any, List, Tuple
+from typing import Any
 
 
 # We also check that there are [not] cxx11 symbols in libtorch
@@ -46,17 +46,17 @@ LIBTORCH_PRE_CXX11_PATTERNS = _apply_libtorch_symbols(PRE_CXX11_SYMBOLS)
 
 
 @functools.lru_cache(100)
-def get_symbols(lib: str) -> List[Tuple[str, str, str]]:
+def get_symbols(lib: str) -> list[tuple[str, str, str]]:
     from subprocess import check_output
 
     lines = check_output(f'nm "{lib}"|c++filt', shell=True)
     return [x.split(" ", 2) for x in lines.decode("latin1").split("\n")[:-1]]
 
 
-def grep_symbols(lib: str, patterns: List[Any]) -> List[str]:
+def grep_symbols(lib: str, patterns: list[Any]) -> list[str]:
     def _grep_symbols(
-        symbols: List[Tuple[str, str, str]], patterns: List[Any]
-    ) -> List[str]:
+        symbols: list[tuple[str, str, str]], patterns: list[Any]
+    ) -> list[str]:
         rc = []
         for _s_addr, _s_type, s_name in symbols:
             for pattern in patterns:
diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py
index 9c548740968..55622da8926 100644
--- a/android/pytorch_android/generate_test_torchscripts.py
+++ b/android/pytorch_android/generate_test_torchscripts.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -44,33 +44,33 @@ class Test(torch.jit.ScriptModule):
         return input
 
     @torch.jit.script_method
-    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
+    def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
         return input
 
     @torch.jit.script_method
-    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
+    def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
         return input
 
     @torch.jit.script_method
-    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
+    def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
         return input
 
     @torch.jit.script_method
-    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
+    def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
         sum = 0
         for x in input:
             sum += x
         return (input, sum)
 
     @torch.jit.script_method
-    def listBoolConjunction(self, input: List[bool]) -> bool:
+    def listBoolConjunction(self, input: list[bool]) -> bool:
         res = True
         for x in input:
             res = res and x
         return res
 
     @torch.jit.script_method
-    def listBoolDisjunction(self, input: List[bool]) -> bool:
+    def listBoolDisjunction(self, input: list[bool]) -> bool:
         res = False
         for x in input:
             res = res or x
@@ -78,8 +78,8 @@ class Test(torch.jit.ScriptModule):
 
     @torch.jit.script_method
     def tupleIntSumReturnTuple(
-        self, input: Tuple[int, int, int]
-    ) -> Tuple[Tuple[int, int, int], int]:
+        self, input: tuple[int, int, int]
+    ) -> tuple[tuple[int, int, int], int]:
         sum = 0
         for x in input:
             sum += x
@@ -104,7 +104,7 @@ class Test(torch.jit.ScriptModule):
         return torch.tensor([int(input.item())])[0]
 
     @torch.jit.script_method
-    def testAliasWithOffset(self) -> List[Tensor]:
+    def testAliasWithOffset(self) -> list[Tensor]:
         x = torch.tensor([100, 200])
         a = [x[0], x[1]]
         return a
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py
index b125d431f49..a9276c98e65 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py
@@ -5,7 +5,7 @@ import argparse
 import itertools
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional
 
 
 DTYPE_MAP = {
@@ -61,7 +61,7 @@ class Kernel:
         return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
 
 
-def get_all_kernels() -> List[Kernel]:
+def get_all_kernels() -> list[Kernel]:
     for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
         for direction in ["fwd", "bwd", "fwd_split"]:
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
index 74df83f85f1..2ef59f42140 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py
@@ -13,7 +13,7 @@ import collections
 import itertools
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, TypeVar
+from typing import Optional, TypeVar
 
 
 DTYPES = {
@@ -48,10 +48,10 @@ KERNEL_IMPL_TEMPLATE = """__global__ void __launch_bounds__(
 
 @dataclass(order=True)
 class FwdKernel:
-    sort_index: Tuple[int, ...] = field(init=False, repr=False)
+    sort_index: tuple[int, ...] = field(init=False, repr=False)
     aligned: bool
     dtype: str
-    sm_range: Tuple[int, int]
+    sm_range: tuple[int, int]
     q: int
     k: int
     max_k: int
@@ -114,8 +114,8 @@ class FwdKernel:
         )
 
     @classmethod
-    def get_all(cls) -> List["FwdKernel"]:
-        kernels: List[FwdKernel] = []
+    def get_all(cls) -> list["FwdKernel"]:
+        kernels: list[FwdKernel] = []
         for aligned, dtype, (sm, sm_max) in itertools.product(
             [True, False], DTYPES.keys(), zip(SM, SM[1:])
         ):
@@ -145,8 +145,8 @@
 
 @dataclass(order=True)
 class BwdKernel:
-    sort_index: Tuple[int, ...] = field(init=False, repr=False)
-    sm_range: Tuple[int, int]
+    sort_index: tuple[int, ...] = field(init=False, repr=False)
+    sm_range: tuple[int, int]
     dtype: str
     aligned: bool
     apply_dropout: bool
@@ -223,8 +223,8 @@ class BwdKernel:
         )
 
     @classmethod
-    def get_all(cls) -> List["BwdKernel"]:
-        kernels: List[BwdKernel] = []
+    def get_all(cls) -> list["BwdKernel"]:
+        kernels: list[BwdKernel] = []
         for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
             [True, False],
             DTYPES.keys(),
@@ -304,7 +304,7 @@ T = TypeVar("T", FwdKernel, BwdKernel)
 
 
 def write_decl_impl(
-    kernels: List[T],
+    kernels: list[T],
     family_name: str,
     impl_file: str,
     autogen_dir: Path,
@@ -322,8 +322,8 @@ def write_decl_impl(
 
     kernels.sort()
 
-    implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
-    cat_to_kernels: Dict[Tuple[str, int, int], List[T]] = collections.defaultdict(list)
+    implfile_to_kernels: dict[str, list[T]] = collections.defaultdict(list)
+    cat_to_kernels: dict[tuple[str, int, int], list[T]] = collections.defaultdict(list)
 
     dispatch_all = ""
     declarations = cpp_file_header + "#pragma once\n"
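
Illustrative footnote (not part of the patch, and not code from the PyTorch tree): the PEP 585 pattern applied above is to use the builtin `list`, `dict`, and `tuple` types directly as generics in annotations on Python 3.9+, dropping the `typing.List`/`Dict`/`Tuple` aliases, while `Optional`, `Union`, `Any`, and `TypeVar` still come from `typing`. A minimal before/after sketch with a hypothetical function name:

# Before: pre-PEP 585 typing aliases
from typing import Dict, List, Optional, Tuple

def count_tags_old(tags: List[str], seed: Optional[Dict[str, int]] = None) -> Tuple[Dict[str, int], int]:
    # Tally tag occurrences, optionally starting from an existing count map.
    counts = dict(seed or {})
    for tag in tags:
        counts[tag] = counts.get(tag, 0) + 1
    return counts, len(counts)

# After: builtin generics on Python 3.9+; Optional is still imported from typing
from typing import Optional

def count_tags(tags: list[str], seed: Optional[dict[str, int]] = None) -> tuple[dict[str, int], int]:
    # Same behavior as above; only the annotation style changes.
    counts = dict(seed or {})
    for tag in tags:
        counts[tag] = counts.get(tag, 0) + 1
    return counts, len(counts)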