Migrate from Tuple -> tuple in benchmarks (#144259)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144259
Approved by: https://github.com/yanboliang
parent 2e42be0595
commit fcf9dc3b11
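The mechanical change throughout is typing.Tuple -> builtin tuple in annotations.
Since PEP 585 (Python 3.9), the builtin collection types are subscriptable at
runtime, so the typing aliases are redundant. A minimal before/after sketch
(illustrative only, not taken from the diff):

    # Before: needs an import that exists purely for annotations
    from typing import Tuple

    def minmax(xs: list[float]) -> Tuple[float, float]:
        return min(xs), max(xs)

    # After: same meaning to a type checker, no typing import needed (3.9+)
    def minmax(xs: list[float]) -> tuple[float, float]:
        return min(xs), max(xs)

Note that only Tuple moves in this commit; Dict, List, Optional and friends are
still imported from typing in these files.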
@@ -32,7 +32,6 @@ from typing import (
     NamedTuple,
     Optional,
     Sequence,
-    Tuple,
     Type,
     TYPE_CHECKING,
 )
@@ -746,7 +745,7 @@ def timed(
     return (time_total, result) if return_result else time_total


-def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
+def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
     # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
     # and consumed like `model(**example_inputs)`.
     # For other benchmarks, example_inputs are formatted as tuple and consumed
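The NOTE above states the normalization contract: dict-style inputs are consumed
as model(**example_inputs), tuple-style inputs positionally. A minimal sketch of
a helper obeying that contract (the real body is not shown in this hunk, so this
is an assumption, not the actual implementation):

    from typing import Any, Mapping

    def _normalize_bench_inputs_sketch(example_inputs) -> tuple[tuple[Any, ...], Mapping[str, Any]]:
        # Hypothetical: dicts become kwargs, everything else becomes positional
        # args, so callers can always invoke fn(*args, **kwargs).
        if isinstance(example_inputs, dict):
            return (), example_inputs
        return tuple(example_inputs), {}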
@@ -4,7 +4,7 @@ import math
 import os
 from collections import Counter, defaultdict
 from functools import partial
-from typing import Any, Dict, Generator, Iterable, Tuple
+from typing import Any, Dict, Generator, Iterable

 import torch
 from torch.testing import make_tensor
@@ -263,7 +263,7 @@ class OperatorInputsLoader:

     def get_inputs_for_operator(
         self, operator, dtype=None, device="cuda"
-    ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]:
+    ) -> Generator[tuple[Iterable[Any], Dict[str, Any]], None, None]:
         assert (
             str(operator) in self.operator_db
         ), f"Could not find {operator}, must provide overload"
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 from torch import Tensor

@@ -27,12 +25,12 @@ def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):

 def lstm_cell(
     input: Tensor,
-    hidden: Tuple[Tensor, Tensor],
+    hidden: tuple[Tensor, Tensor],
     w_ih: Tensor,
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

@@ -57,7 +55,7 @@ def flat_lstm_cell(
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
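The cells above only show the gate pre-activations; the standard LSTM recurrence
they implement continues with the usual nonlinearities. A self-contained sketch
of one step (textbook formulation, not copied from the file):

    import torch

    def lstm_step(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
        # Pre-activations of the four gates, exactly as in the hunks above.
        gates = torch.mm(x, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
        i, f, g, o = gates.chunk(4, 1)  # input, forget, cell, output gates
        cy = torch.sigmoid(f) * cx + torch.sigmoid(i) * torch.tanh(g)
        hy = torch.sigmoid(o) * torch.tanh(cy)
        return hy, cy  # a plain tuple[Tensor, Tensor], matching the new annotations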
@@ -75,11 +73,11 @@ def flat_lstm_cell(

 def premul_lstm_cell(
     igates: Tensor,
-    hidden: Tuple[Tensor, Tensor],
+    hidden: tuple[Tensor, Tensor],
     w_hh: Tensor,
     b_ih: Tensor,
     b_hh: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh

@@ -97,8 +95,8 @@ def premul_lstm_cell(


 def premul_lstm_cell_no_bias(
-    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
-) -> Tuple[Tensor, Tensor]:
+    igates: Tensor, hidden: tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
+) -> tuple[Tensor, Tensor]:
     hx, cx = hidden
     gates = igates + torch.mm(hx, w_hh.t()) + b_hh

@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from collections import namedtuple
-from typing import List, Tuple
+from typing import List

 import torch
 import torch.jit as jit
@@ -131,8 +131,8 @@ class LSTMCell(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = state
         gates = (
             torch.mm(input, self.weight_ih.t())
@@ -199,8 +199,8 @@ class LayerNormLSTMCell(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = state
         igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
         hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
@@ -225,8 +225,8 @@ class LSTMLayer(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         inputs = input.unbind(0)
         outputs = torch.jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -242,8 +242,8 @@ class ReverseLSTMLayer(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, state: Tuple[Tensor, Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        self, input: Tensor, state: tuple[Tensor, Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         inputs = reverse(input.unbind(0))
         outputs = jit.annotate(List[Tensor], [])
         for i in range(len(inputs)):
@@ -266,11 +266,11 @@ class BidirLSTMLayer(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: [forward LSTMState, backward LSTMState]
         outputs = jit.annotate(List[Tensor], [])
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
         for direction in self.directions:
@@ -300,10 +300,10 @@ class StackedLSTM(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
@@ -330,11 +330,11 @@ class StackedLSTM2(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
-    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
+        self, input: Tensor, states: List[List[tuple[Tensor, Tensor]]]
+    ) -> tuple[Tensor, List[List[tuple[Tensor, Tensor]]]]:
         # List[List[LSTMState]]: The outer list is for layers,
         # inner list is for directions.
-        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
+        output_states = jit.annotate(List[List[tuple[Tensor, Tensor]]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
@@ -370,10 +370,10 @@ class StackedLSTMWithDropout(jit.ScriptModule):

     @jit.script_method
     def forward(
-        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
-    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
+        self, input: Tensor, states: List[tuple[Tensor, Tensor]]
+    ) -> tuple[Tensor, List[tuple[Tensor, Tensor]]]:
         # List[LSTMState]: One state per layer
-        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output_states = jit.annotate(List[tuple[Tensor, Tensor]], [])
         output = input
         # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
         i = 0
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import List, Tuple
+from typing import List

 import torch
 from torch import Tensor
@@ -266,12 +266,12 @@ def varlen_pytorch_lstm_creator(**kwargs):
 def varlen_lstm_factory(cell, script):
     def dynamic_rnn(
         sequences: List[Tensor],
-        hiddens: Tuple[Tensor, Tensor],
+        hiddens: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
+    ) -> tuple[List[Tensor], tuple[List[Tensor], List[Tensor]]]:
         hx, cx = hiddens
         hxs = hx.unbind(1)
         cxs = cx.unbind(1)
@@ -406,12 +406,12 @@ def lstm_inputs(
 def lstm_factory(cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = input.unbind(0)
@@ -432,12 +432,12 @@ def lstm_factory(cell, script):
 def lstm_factory_premul(premul_cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inputs = torch.matmul(input, wih.t()).unbind(0)
@@ -458,12 +458,12 @@ def lstm_factory_premul(premul_cell, script):
 def lstm_factory_premul_bias(premul_cell, script):
     def dynamic_rnn(
         input: Tensor,
-        hidden: Tuple[Tensor, Tensor],
+        hidden: tuple[Tensor, Tensor],
         wih: Tensor,
         whh: Tensor,
         bih: Tensor,
         bhh: Tensor,
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         hx, cx = hidden
         outputs = []
         inpSize = input.size()
@@ -506,8 +506,8 @@ def lstm_factory_simple(cell, script):

 def lstm_factory_multilayer(cell, script):
     def dynamic_rnn(
-        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        input: Tensor, hidden: tuple[Tensor, Tensor], params: List[Tensor]
+    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
         params_stride = 4  # NB: this assumes that biases are there
         hx, cx = hidden
         hy, cy = hidden  # for scoping...
@@ -3,7 +3,7 @@

 import math
 from collections import OrderedDict
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -512,7 +512,7 @@ class MultiheadAttentionContainer(torch.nn.Module):
         attn_mask: Optional[torch.Tensor] = None,
         bias_k: Optional[torch.Tensor] = None,
         bias_v: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         r"""
         Args:
             query, key, value (Tensor): map a query and a set of key-value pairs to an output.
@@ -589,7 +589,7 @@ class ScaledDotProduct(torch.nn.Module):
         attn_mask: Optional[torch.Tensor] = None,
         bias_k: Optional[torch.Tensor] = None,
         bias_v: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         r"""Uses a scaled dot product with the projected key-value pair to update
         the projected query.
         Args:
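The ScaledDotProduct docstring summarizes the computation: softmax(QK^T / sqrt(d))V,
returned together with the attention weights, which is why the annotation is a
two-element tuple. A minimal sketch of that math (illustrative, not this module's
actual forward):

    import torch
    import torch.nn.functional as F

    def scaled_dot_product(q, k, v):
        d = q.size(-1)
        attn = F.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)
        return attn @ v, attn  # tuple[Tensor, Tensor]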
@@ -686,7 +686,7 @@ class InProjContainer(torch.nn.Module):

     def forward(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         r"""Projects the input sequences using in-proj layers.
         Args:
             query, key, value (Tensors): sequence to be projected
@@ -1,20 +1,20 @@
 from collections import defaultdict
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union

 import torch
 from torch import nn, Tensor


 # Type helpers
-InputsType = Union[Tensor, Tuple[Tensor, ...]]
+InputsType = Union[Tensor, tuple[Tensor, ...]]
 # A Getter takes in a device and returns a callable and the inputs to that callable
-GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
+GetterReturnType = tuple[Callable[..., Tensor], InputsType]
 GetterType = Callable[[torch.device], GetterReturnType]
 # V here refers to the v in either vjp, jvp, vhp or hvp
-VType = Union[None, Tensor, Tuple[Tensor, ...]]
+VType = Union[None, Tensor, tuple[Tensor, ...]]
 # Type used to store timing results. The first key is the model name, the second key
 # is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
-TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
+TimingResultType = Dict[str, Dict[str, tuple[float, ...]]]


 # Utilities to make nn.Module "functional"
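One wrinkle worth noting: unlike function annotations, which
from __future__ import annotations can leave as unevaluated strings, these
module-level aliases are ordinary assignments executed at import time, so
tuple[...] here genuinely requires the runtime support added in Python 3.9
(PEP 585). A quick check:

    import types

    # Builtin tuple subscription yields a real runtime object on 3.9+.
    alias = tuple[float, ...]
    assert isinstance(alias, types.GenericAlias)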
@@ -44,7 +44,7 @@ def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
         _set_nested_attr(getattr(obj, names[0]), names[1:], value)


-def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
+def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], List[str]]:
     """
     This function removes all the Parameters from the model and
     return them as a tuple as well as their original attribute names.
@@ -65,7 +65,7 @@ def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
     return params, names


-def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
+def load_weights(mod: nn.Module, names: List[str], params: tuple[Tensor, ...]) -> None:
     """
     Reload a set of weights so that `mod` can be used again to perform a forward pass.
     Note that the `params` are regular Tensors (that can have history) and so are left
@@ -77,7 +77,7 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -

 # Utilities to read/write markdown table-like content.
 def to_markdown_table(
-    res: TimingResultType, header: Optional[Tuple[str, ...]] = None
+    res: TimingResultType, header: Optional[tuple[str, ...]] = None
 ) -> str:
     if header is None:
         header = ("model", "task", "mean", "var")
@@ -2,7 +2,7 @@ import dataclasses
 import itertools
 import platform
 import time
-from typing import Optional, Tuple
+from typing import Optional

 import torchao
 from common import Experiment, register_experiment
@@ -89,7 +89,7 @@ def prefill(

 def decode_one_token(
     model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
@@ -8,7 +8,7 @@ import subprocess
 import textwrap
 import threading
 import time
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, List, Optional, Set, Union

 from worker.main import WorkerFailure, WorkerOutput

@@ -55,7 +55,7 @@ class CorePool:
             True for _ in range(min_core_id, min_core_id + self._num_cores)
         ]

-        self._reservations: Dict[str, Tuple[int, ...]] = {}
+        self._reservations: Dict[str, tuple[int, ...]] = {}
         self._lock = threading.Lock()

     def reserve(self, n: int) -> Optional[str]:
@@ -87,11 +87,11 @@ class CorePool:
 class Runner:
     def __init__(
         self,
-        work_items: Tuple[WorkOrder, ...],
+        work_items: tuple[WorkOrder, ...],
         core_pool: Optional[CorePool] = None,
         cadence: float = 1.0,
     ) -> None:
-        self._work_items: Tuple[WorkOrder, ...] = work_items
+        self._work_items: tuple[WorkOrder, ...] = work_items
         self._core_pool: CorePool = core_pool or CorePool(0, CPU_COUNT - 4)
         self._cadence: float = cadence

@@ -24,7 +24,7 @@ import pickle
 import sys
 import timeit
 import traceback
-from typing import Any, Tuple, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING, Union


 if TYPE_CHECKING:
@@ -81,8 +81,8 @@ class WorkerTimerArgs:
 @dataclasses.dataclass(frozen=True)
 class WorkerOutput:
     # Only return values to reduce communication between main process and workers.
-    wall_times: Tuple[float, ...]
-    instructions: Tuple[int, ...]
+    wall_times: tuple[float, ...]
+    instructions: tuple[int, ...]


 @dataclasses.dataclass(frozen=True)
@@ -145,7 +145,7 @@ def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:

     m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)

-    stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
+    stats: tuple[CallgrindStats, ...] = timer.collect_callgrind(
         number=CALLGRIND_NUMBER,
         collect_baseline=False,
         repeats=CALLGRIND_REPEATS,
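For context, the stats above come from the Callgrind integration in
torch.utils.benchmark; with repeats given, collect_callgrind returns one
CallgrindStats per repeat, hence the tuple[CallgrindStats, ...] annotation. A
minimal usage sketch (valgrind must be installed; the stmt/setup strings and
numbers below are placeholders, not this file's CALLGRIND_* constants):

    import torch
    from torch.utils.benchmark import Timer

    timer = Timer(stmt="x + y", setup="x = torch.ones(8); y = torch.ones(8)")
    stats = timer.collect_callgrind(number=100, repeats=3, collect_baseline=False)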
@@ -6,7 +6,7 @@ from collections import defaultdict
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union

 import numpy as np
 from tabulate import tabulate
@@ -41,7 +41,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->

 @dataclass(frozen=True)
 class ExperimentConfig:
-    shape: Tuple[int]  # [B, Hq, M, Hkv, N, D]
+    shape: tuple[int]  # [B, Hq, M, Hkv, N, D]
     attn_type: str
     dtype: torch.dtype
     calculate_bwd_time: bool
@@ -149,7 +149,7 @@ def generate_inputs(


 def generate_jagged_inputs(
-    shape: Tuple[int],
+    shape: tuple[int],
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -611,7 +611,7 @@ softcap_value = 50
 dropout_p = 0.0


-def generate_score_mod(attn_type: str, shape: Tuple[int]) -> Callable | None:
+def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None:
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1
     from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
@@ -653,7 +653,7 @@ sliding_window_size = 512
 prefix_length = 512


-def generate_block_mask(attn_type: str, shape: Tuple[int]):
+def generate_block_mask(attn_type: str, shape: tuple[int]):
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1

@@ -728,7 +728,7 @@ def generate_block_mask(attn_type: str, shape: Tuple[int]):
     return block_mask, mask_mod_kwargs


-def get_kernel_options(attn_type: str, shape: Tuple[int]):
+def get_kernel_options(attn_type: str, shape: tuple[int]):
     B, Hq, M, Hkv, N, D = shape
     is_decoding = M == 1
     kernel_opt_training_dict = {
@@ -815,7 +815,7 @@ def get_backend_context(backend: str):


 def generate_FA_callable(
-    attn_type: str, shape: Tuple[int], dtype: torch.dtype, backend: str, **kwargs
+    attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs
 ) -> Callable | None:
     if dtype not in [torch.float16, torch.bfloat16]:
         return None
@@ -882,7 +882,7 @@ def generate_FA_callable(


 def generate_FD_callable(
-    attn_type: str, shape: Tuple[int], dtype: torch.dtype
+    attn_type: str, shape: tuple[int], dtype: torch.dtype
 ) -> Callable | None:
     if dtype not in [torch.float16, torch.bfloat16]:
         return None
@@ -929,7 +929,7 @@ def generate_FD_callable(


 def generate_attn_mask_linear_score_mod(
-    shape: Tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype
+    shape: tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype
 ):
     B, Hq, M, N = shape
     if block_mask is None and score_mod is None:
@@ -954,7 +954,7 @@ def generate_attn_mask_linear_score_mod(

 def generate_eager_sdpa(
     attn_type: str,
-    shape: Tuple[int],
+    shape: tuple[int],
     dtype: torch.dtype,
     block_mask: BlockMask,
     score_mod: Callable | None = None,
@@ -1025,7 +1025,7 @@ def generate_experiment_configs(
     calculate_bwd: bool,
     dtype: torch.dtype,
     batch_sizes: List[int],
-    num_heads: List[Tuple[int, int]],
+    num_heads: List[tuple[int, int]],
     seq_lens: List[int],
     head_dims: List[int],
     score_mods_str: List[str],
@@ -2,7 +2,7 @@ import itertools
 from collections import defaultdict
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
-from typing import Callable, List, Tuple
+from typing import Callable, List

 from tabulate import tabulate
 from tqdm import tqdm
@@ -68,7 +68,7 @@ class Experiment:

 def get_input(
     config: ExperimentConfig,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     q = torch.randn(
         (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
         dtype=config.dtype,