Migrate from Tuple -> tuple in benchmarks (#144259)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144259
Approved by: https://github.com/yanboliang
Authored by bobrenjc93 on 2025-01-06 08:45:43 -08:00, committed by PyTorch MergeBot
parent 2e42be0595
commit fcf9dc3b11
12 changed files with 77 additions and 80 deletions
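
Background: this commit applies the PEP 585 migration to the benchmark suite. Since Python 3.9 the builtin tuple type is subscriptable, so annotations no longer need typing.Tuple, and the import can be dropped wherever nothing else uses it. A minimal sketch of the pattern (hypothetical function, not taken from the diff below):

# Before: requires an import; typing.Tuple is deprecated since Python 3.9.
from typing import Tuple

def minmax(xs: Tuple[int, ...]) -> Tuple[int, int]:
    return min(xs), max(xs)

# After (Python 3.9+): the builtin is generic, so no typing import is needed.
# Variable-length tuples keep the Ellipsis form, e.g. tuple[float, ...].
def minmax(xs: tuple[int, ...]) -> tuple[int, int]:
    return min(xs), max(xs)

Note that only Tuple is migrated here; the List, Dict, Optional, Set, and Union spellings from typing are left untouched by this commit.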

View File

@@ -32,7 +32,6 @@ from typing import (
     NamedTuple,
     Optional,
     Sequence,
-    Tuple,
     Type,
     TYPE_CHECKING,
 )
@@ -746,7 +745,7 @@ def timed(
     return (time_total, result) if return_result else time_total
-def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
+def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
     # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
     # and consumed like `model(**example_inputs)`.
     # For other benchmarks, example_inputs are formatted as tuple and consumed

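The NOTE above pins down the normalizer's contract: dict-style example inputs (huggingface) are consumed as model(**example_inputs) and so map to kwargs, while everything else maps to positional args. A plausible sketch of the body under that reading (the actual implementation may differ in detail):

from collections.abc import Mapping
from typing import Any

def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
    # Dicts become kwargs (consumed as model(**example_inputs));
    # anything else becomes positional args (consumed as model(*example_inputs)).
    if isinstance(example_inputs, dict):
        return (), example_inputs
    return tuple(example_inputs), {}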
View File

@@ -4,7 +4,7 @@ import math
 import os
 from collections import Counter, defaultdict
 from functools import partial
-from typing import Any, Dict, Generator, Iterable, Tuple
+from typing import Any, Dict, Generator, Iterable
 import torch
 from torch.testing import make_tensor
@@ -263,7 +263,7 @@ class OperatorInputsLoader:
     def get_inputs_for_operator(
         self, operator, dtype=None, device="cuda"
-    ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]:
+    ) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]:
         assert (
             str(operator) in self.operator_db
         ), f"Could not find {operator}, must provide overload"

View File

@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 from torch import Tensor
@@ -27,12 +25,12 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
 def lstm_cell(
     input: Tensor,
-    hidden: Tuple[Tensor, Tensor],
+    hidden: tuple[Tensor, Tensor],
     w_ih: Tensor,
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
@@ -57,7 +55,7 @@ def flat_lstm_cell(
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
@@ -75,11 +73,11 @@ def flat_lstm_cell(
 def premul_lstm_cell(
     igates: Tensor,
-    hidden: Tuple[Tensor, Tensor],
+    hidden: tuple[Tensor, Tensor],
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
@@ -97,8 +95,8 @@ def premul_lstm_cell(
 def premul_lstm_cell_no_bias(
-    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
-) -> Tuple[Tensor, Tensor]:
+    igates: Tensor, hidden: tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_hh

View File

@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from collections import namedtuple
-from typing import List, Tuple
+from typing import List
 import torch
 import torch.jit as jit
@@ -131,8 +131,8 @@ class LSTMCell(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = state
         gates = (
             torch.mm(input, self.weight_ih.t())
@@ -199,8 +199,8 @@ class LayerNormLSTMCell(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = state
         igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
         hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
@@ -225,8 +225,8 @@ class LSTMLayer(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         inputs = input.unbind(0)
         outputs = torch.jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -242,8 +242,8 @@ class ReverseLSTMLayer(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         inputs = reverse(input.unbind(0))
         outputs = jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -266,11 +266,11 @@ class BidirLSTMLayer(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: [forward LSTMState, backward LSTMState]
         outputs = jit.annotate(List[Tensor], [])
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
         for direction in self.directions:
@@ -300,10 +300,10 @@ class StackedLSTM(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
@@ -330,11 +330,11 @@ class StackedLSTM2(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
-    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
+        self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]]
+    ) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]:
         # List[List[LSTMState]]: The outer list is for layers,
         # inner list is for directions.
-        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
+        output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
@@ -370,10 +370,10 @@ class StackedLSTMWithDropout(jit.ScriptModule):
     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0

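All of these forward methods are compiled by TorchScript (@jit.script_method), and the jit.annotate calls switch to the builtin spelling too, so the change relies on TorchScript's type parser accepting PEP 585 builtin generics, which recent PyTorch releases do. A standalone toy module illustrating the pattern (hypothetical, not from this file):

import torch
import torch.jit as jit
from torch import Tensor

class PassThroughCell(jit.ScriptModule):
    @jit.script_method
    def forward(self, state: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
        # TorchScript resolves the builtin tuple annotation just like typing.Tuple.
        hx, cx = state
        return cx, hx

m = PassThroughCell()
h, c = torch.zeros(2, 3), torch.ones(2, 3)
cx, hx = m((h, c))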
View File

@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import List, Tuple
+from typing import List
 import torch
 from torch import Tensor
@@ -266,12 +266,12 @@ def varlen_pytorch_lstm_creator(**kwargs):
 def varlen_lstm_factory(cell, script):
     def dynamic_rnn(
         sequences: List[Tensor],
-        hiddens: Tuple[Tensor, Tensor],
+        hiddens: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
+    ) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]:
         hx, cx = hiddens
         hxs = hx.unbind(1)
         cxs = cx.unbind(1)
@@ -406,12 +406,12 @@ def lstm_inputs(
 def lstm_factory(cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = input.unbind(0)
@@ -432,12 +432,12 @@ def lstm_factory(cell, script):
 def lstm_factory_premul(premul_cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = torch.matmul(input, wih.t()).unbind(0)
@@ -458,12 +458,12 @@ def lstm_factory_premul(premul_cell, script):
 def lstm_factory_premul_bias(premul_cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inpSize = input.size()
@@ -506,8 +506,8 @@ def lstm_factory_simple(cell, script):
 def lstm_factory_multilayer(cell, script):
     def dynamic_rnn(
-        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         params_stride = 4  # NB: this assumes that biases are there
         hx, cx = hidden
         hy, cy = hidden  # for scoping...

View File

@@ -3,7 +3,7 @@
 import math
 from collections import OrderedDict
-from typing import Optional, Tuple
+from typing import Optional
 import torch
 import torch.nn.functional as F
@@ -512,7 +512,7 @@ class MultiheadAttentionContainer(torch.nn.Module):
         attn_mask: Optional[torch.Tensor] = None,
         bias_k: Optional[torch.Tensor] = None,
         bias_v: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         r"""
         Args:
             query, key, value (Tensor): map a query and a set of key-value pairs to an output.
@@ -589,7 +589,7 @@ class ScaledDotProduct(torch.nn.Module):
         attn_mask: Optional[torch.Tensor] = None,
         bias_k: Optional[torch.Tensor] = None,
         bias_v: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         r"""Uses a scaled dot product with the projected key-value pair to update
         the projected query.
         Args:
@@ -686,7 +686,7 @@ class InProjContainer(torch.nn.Module):
     def forward(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         r"""Projects the input sequences using in-proj layers.
         Args:
             query, key, value (Tensors): sequence to be projected

View File

@@ -1,20 +1,20 @@
 from collections import defaultdict
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union
 import torch
 from torch import nn, Tensor
 # Type helpers
-InputsType = Union[Tensor, Tuple[Tensor, ...]]
+InputsType = Union[Tensor, tuple[Tensor, ...]]
 # A Getter takes in a device and returns a callable and the inputs to that callable
-GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
+GetterReturnType = tuple[Callable[..., Tensor], InputsType]
 GetterType = Callable[[torch.device], GetterReturnType]
 # V here refers to the v in either vjp, jvp, vhp or hvp
-VType = Union[None, Tensor, Tuple[Tensor, ...]]
+VType = Union[None, Tensor, tuple[Tensor, ...]]
 # Type used to store timing results. The first key is the model name, the second key
 # is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
-TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
+TimingResultType = Dict[str, Dict[str, tuple[float, ...]]]
# Utilities to make nn.Module "functional"
@@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
         _set_nested_attr(getattr(obj, names[0]), names[1:], value)
-def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
     """
     This function removes all the Parameters from the model and
     return them as a tuple as well as their original attribute names.
@@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
     return params, names
-def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
+def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None:
     """
     Reload a set of weights so that `mod` can be used again to perform a forward pass.
     Note that the `params` are regular Tensors (that can have history) and so are left
@@ -77,7 +77,7 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -
 # Utilities to read/write markdown table-like content.
 def to_markdown_table(
-    res: TimingResultType, header: Optional[Tuple[str, ...]] = None
+    res: TimingResultType, header: Optional[tuple[str, ...]] = None
 ) -> str:
     if header is None:
         header = ("model", "task", "mean", "var")

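The aliases above encode the harness protocol: a GetterType callable receives a device and returns the function under test together with the inputs it will be benchmarked on. A toy getter matching those aliases (hypothetical, for illustration only; the aliases are restated so the snippet stands alone):

from typing import Callable, Union

import torch
from torch import Tensor

InputsType = Union[Tensor, tuple[Tensor, ...]]
GetterReturnType = tuple[Callable[..., Tensor], InputsType]

def squared_norm_getter(device: torch.device) -> GetterReturnType:
    # Build the function under test and the inputs it will be timed on.
    x = torch.randn(8, device=device, requires_grad=True)

    def f(x: Tensor) -> Tensor:
        return (x * x).sum()

    return f, (x,)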
View File

@@ -2,7 +2,7 @@ import dataclasses
 import itertools
 import platform
 import time
-from typing import Optional, Tuple
+from typing import Optional
 import torchao
 from common import Experiment, register_experiment
@@ -89,7 +89,7 @@ def prefill(
 def decode_one_token(
     model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)

View File

@@ -8,7 +8,7 @@ import subprocess
 import textwrap
 import threading
 import time
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, List, Optional, Set, Union
 from worker.main import WorkerFailure, WorkerOutput
@@ -55,7 +55,7 @@ class CorePool:
             True for _ in range(min_core_id, min_core_id + self._num_cores)
         ]
-        self._reservations: Dict[str, Tuple[int, ...]] = {}
+        self._reservations: Dict[str, tuple[int, ...]] = {}
         self._lock = threading.Lock()
     def reserve(self, n: int) -> Optional[str]:
@@ -87,11 +87,11 @@ class CorePool:
 class Runner:
     def __init__(
         self,
-        work_items: Tuple[WorkOrder, ...],
+        work_items: tuple[WorkOrder, ...],
         core_pool: Optional[CorePool] = None,
         cadence: float = 1.0,
     ) -> None:
-        self._work_items: Tuple[WorkOrder, ...] = work_items
+        self._work_items: tuple[WorkOrder, ...] = work_items
         self._core_pool: CorePool = core_pool or CorePool(0, CPU_COUNT - 4)
         self._cadence: float = cadence

View File

@@ -24,7 +24,7 @@ import pickle
 import sys
 import timeit
 import traceback
-from typing import Any, Tuple, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING, Union
 if TYPE_CHECKING:
@@ -81,8 +81,8 @@ class WorkerTimerArgs:
 @dataclasses.dataclass(frozen=True)
 class WorkerOutput:
     # Only return values to reduce communication between main process and workers.
-    wall_times: Tuple[float, ...]
-    instructions: Tuple[int, ...]
+    wall_times: tuple[float, ...]
+    instructions: tuple[int, ...]
 @dataclasses.dataclass(frozen=True)
@@ -145,7 +145,7 @@ def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:
     m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
-    stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
+    stats: tuple[CallgrindStats, ...] = timer.collect_callgrind(
         number=CALLGRIND_NUMBER,
         collect_baseline=False,
         repeats=CALLGRIND_REPEATS,

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 from tabulate import tabulate
@@ -41,7 +41,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 @dataclass(frozen=True)
 class ExperimentConfig:
-    shape: Tuple[int]  # [B, Hq, M, Hkv, N, D]
+    shape: tuple[int]  # [B, Hq, M, Hkv, N, D]
     attn_type: str
     dtype: torch.dtype
     calculate_bwd_time: bool
@@ -149,7 +149,7 @@ def generate_inputs(
 def generate_jagged_inputs(
-    shape: Tuple[int],
+    shape: tuple[int],
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -611,7 +611,7 @@ softcap_value = 50
 dropout_p = 0.0
-def generate_score_mod(attn_type: str, shape: Tuple[int]) -> Callable | None:
+def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None:
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1
     from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
@@ -653,7 +653,7 @@ sliding_window_size = 512
 prefix_length = 512
-def generate_block_mask(attn_type: str, shape: Tuple[int]):
+def generate_block_mask(attn_type: str, shape: tuple[int]):
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1
@@ -728,7 +728,7 @@ def generate_block_mask(attn_type: str, shape: Tuple[int]):
     return block_mask, mask_mod_kwargs
-def get_kernel_options(attn_type: str, shape: Tuple[int]):
+def get_kernel_options(attn_type: str, shape: tuple[int]):
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1
     kernel_opt_training_dict = {
@@ -815,7 +815,7 @@ def get_backend_context(backend: str):
 def generate_FA_callable(
-    attn_type: str, shape: Tuple[int], dtype: torch.dtype, backend: str, **kwargs
+    attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs
 ) -> Callable | None:
     if dtype not in [torch.float16, torch.bfloat16]:
         return None
@@ -882,7 +882,7 @@ def generate_FA_callable(
 def generate_FD_callable(
-    attn_type: str, shape: Tuple[int], dtype: torch.dtype
+    attn_type: str, shape: tuple[int], dtype: torch.dtype
 ) -> Callable | None:
     if dtype not in [torch.float16, torch.bfloat16]:
         return None
@@ -929,7 +929,7 @@ def generate_FD_callable(
 def generate_attn_mask_linear_score_mod(
-    shape: Tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype
+    shape: tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype
 ):
     B, Hq, M, N = shape
     if block_mask is None and score_mod is None:
@@ -954,7 +954,7 @@ def generate_attn_mask_linear_score_mod(
 def generate_eager_sdpa(
     attn_type: str,
-    shape: Tuple[int],
+    shape: tuple[int],
     dtype: torch.dtype,
     block_mask: BlockMask,
     score_mod: Callable | None = None,
@@ -1025,7 +1025,7 @@ def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
     batch_sizes: List[int],
-    num_heads: List[Tuple[int, int]],
+    num_heads: List[tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
     score_mods_str: List[str],

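One caveat the migration preserves as-is: tuple[int], like Tuple[int], denotes a one-element tuple, while the comment on ExperimentConfig.shape documents six dimensions [B, Hq, M, Hkv, N, D]. A stricter spelling would use the Ellipsis form; a sketch of that alternative (not a change made by this commit):

# tuple[int] promises exactly one int; a six-element shape would want either
# an explicit tuple[int, int, int, int, int, int] or the variable-length form:
Shape = tuple[int, ...]

def unpack(shape: Shape) -> int:
    # Unpacking checks the runtime length (six here), not the annotation.
    B, Hq, M, Hkv, N, D = shape
    return B * Hq * M * D

print(unpack((4, 16, 1024, 16, 1024, 64)))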
View File

@@ -2,7 +2,7 @@ import itertools
 from collections import defaultdict
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
-from typing import Callable, List, Tuple
+from typing import Callable, List
 from tabulate import tabulate
 from tqdm import tqdm
@@ -68,7 +68,7 @@ class Experiment:
 def get_input(
     config: ExperimentConfig,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     q = torch.randn(
         (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
         dtype=config.dtype,