tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)

High-level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between the raw features (a rough sketch of this training setup follows below).
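
The actual data and training script live in the pastebin links above; purely for illustration, a minimal sketch of steps 2 and 2a could look like the following (the CSV name, column names, label, and hyperparameter grid here are my assumptions, not taken from those scripts):
```python
# Illustrative only: train a shallow decision tree to predict which kernel wins.
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

df = pd.read_csv("sdpa_benchmarks.csv")  # hypothetical dump of the timing data

# Step 2a: feature augmentation so a shallow tree can capture interactions,
# e.g. ratios like seq_len / (num_heads * batch_size).
df["threads"] = df["num_heads"] * df["batch_size"]
df["seq_per_thread"] = df["seq_len"] / df["threads"]

features = ["batch_size", "num_heads", "seq_len", "head_dim", "threads", "seq_per_thread"]
X_train, X_val, y_train, y_val = train_test_split(
    df[features], df["flash_faster"], test_size=0.2, random_state=0
)

# Step 2: small hyperparameter sweep over tree depth, scored on the validation split.
best = max(
    (DecisionTreeClassifier(max_depth=d, random_state=0).fit(X_train, y_train) for d in (1, 2, 3, 4)),
    key=lambda clf: clf.score(X_val, y_val),
)
print(best.get_depth(), best.score(X_val, y_val))
```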

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On the cases in my dataset where FlashAttention and Cutlass differ by more than 10%, the existing heuristic picks the faster kernel 69% of the time, while my new heuristic achieves 94% accuracy.
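
For concreteness, a hedged sketch of how those accuracy numbers could be reproduced from such a dataset (the column names and the exact definition of "differ by more than 10%" are assumptions on my part):
```python
# Illustrative only: score the simple heuristic on rows where the two kernels
# differ by more than 10%.
import pandas as pd

df = pd.read_csv("sdpa_benchmarks.csv")  # hypothetical timing data, one row per shape

faster = df[["flash_time", "cutlass_time"]].min(axis=1)
big_gap = (df["flash_time"] - df["cutlass_time"]).abs() / faster > 0.10
df = df[big_gap]

predict_flash = df["seq_len"] / (df["num_heads"] * df["batch_size"]) > 6
flash_wins = df["flash_time"] < df["cutlass_time"]
print(f"heuristic accuracy on >10% cases: {(predict_flash == flash_wins).mean():.1%}")
```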

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
Horace He 2023-04-21 17:10:41 +00:00 committed by PyTorch MergeBot
parent bb830224e3
commit 547bef11ee
14 changed files with 54 additions and 473 deletions

View File

@@ -14034,7 +14034,7 @@
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offse, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
variants: function
dispatch:
CUDA: _scaled_dot_product_flash_attention_backward_cuda

View File

@@ -39,22 +39,38 @@ struct sdp_params {
bool is_causal;
};
inline bool check_requires_grad(sdp_params params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if ((any_inputs_require_grad && gradmode_enabled)) {
if (debug) {
TORCH_WARN("Flash Attention does not currently support training.");
}
return false;
}
return true;
}
inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
constexpr std::array<SDPBackend, num_backends> default_order{
SDPBackend::flash_attention,
SDPBackend::efficient_attention,
SDPBackend::math};
constexpr std::array<SDPBackend, num_backends> efficient_first{
SDPBackend::efficient_attention,
SDPBackend::flash_attention,
SDPBackend::math};
// Logic is taken from xformers
// FlashAttention parallelizes across "batch_size * num_heads"
// MemEff parallelizes across "batch_size * num_heads * num_queries" and can
// be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
if (params.query.is_nested() || params.key.is_nested() ||
params.value.is_nested()) {
// See check_for_nested_inputs for details
return {
SDPBackend::efficient_attention,
SDPBackend::flash_attention,
SDPBackend::math};
return efficient_first;
}
if (params.query.dim() != 4) {
return default_order;
@@ -70,13 +86,14 @@ inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
bool small_threads_flash = threads_flash < 60;
bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
return {
SDPBackend::efficient_attention,
SDPBackend::flash_attention,
SDPBackend::math};
// The training heuristic is taken from https://github.com/pytorch/pytorch/pull/99644
// Revisit when updated cutlass kernel is upstreamed.
if (check_requires_grad(params, false)) {
if (6 * threads_flash > query_lengths) return efficient_first;
} else if ((small_threads_flash && more_threads_cutlass) || large_head_dim)
return efficient_first;
}
}
return default_order;
}
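
Read alongside the commit message, the new branch is just the seq_len / (num_heads * batch_size) > 6 rule rearranged to avoid division. A small illustrative Python paraphrase (variable names mirror the C++ locals; this is not a PyTorch API):
```python
def prefer_flash(batch_size, num_heads, seq_len):
    # FlashAttention parallelizes across roughly batch_size * num_heads thread blocks.
    threads_flash = batch_size * num_heads
    # "seq_len / (num_heads * batch_size) > 6" rewritten without division; its
    # negation is the "6 * threads_flash > query_lengths" check above, which
    # puts the memory-efficient (Cutlass) kernel first in the priority order.
    return seq_len > 6 * threads_flash
```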
@@ -253,19 +270,6 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
return true;
}
inline bool check_requires_grad(sdp_params params, bool debug) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
if ((any_inputs_require_grad && gradmode_enabled)) {
if (debug) {
TORCH_WARN("Flash Attention does not currently support training.");
}
return false;
}
return true;
}
inline bool check_requires_grad_and_nested(sdp_params params, bool debug) {
// If we fail both checks then we return false
if (check_for_nested_inputs(params) && !check_requires_grad(params, false)){

View File

@@ -1,5 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,190 +0,0 @@
import torch
from functorch.compile import memory_efficient_fusion
import benchmark_helper
device = "cuda"
dtype = torch.float16
# LightSeq pattern 1
class DropoutResBias:
@staticmethod
def fn(input, bias, residual):
a = torch.add(input, bias)
b = torch.nn.functional.dropout(a, p=0.7, training=True)
c = b + residual
return c
@staticmethod
def args():
batch_size, seq_len, hidden_size = 32, 196, 1024
input = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=True,
device=device,
dtype=dtype,
)
bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
residual = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=False,
device=device,
dtype=dtype,
)
args = (input, bias, residual)
return args
class DropoutResBiasScalar:
@staticmethod
def fn(input, bias, residual, p: float):
a = torch.add(input, bias)
b = torch.nn.functional.dropout(a, p, training=True)
c = b + residual
return c
@staticmethod
def args():
batch_size, seq_len, hidden_size = 32, 196, 1024
input = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=True,
device=device,
dtype=dtype,
)
bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
residual = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=False,
device=device,
dtype=dtype,
)
args = (input, bias, residual, 0.7)
return args
# LightSeq pattern 2
class BiasReluDropout:
@staticmethod
def fn(input, bias):
a = torch.add(input, bias)
b = torch.nn.functional.relu(a)
c = torch.nn.functional.dropout(b, p=0.6, training=True)
return c
@staticmethod
def args():
batch_size = 32
seq_len = 196
intermediate_size = 4096
input = torch.randn(
batch_size,
seq_len,
intermediate_size,
requires_grad=True,
device=device,
dtype=dtype,
)
bias = torch.randn(
intermediate_size, requires_grad=True, device=device, dtype=dtype
)
args = (input, bias)
return args
class BiasDropoutResLayerNorm:
@staticmethod
def fn(input, bias, residual):
hidden_size = 1024
a = torch.add(input, bias)
b = torch.nn.functional.dropout(a, p=0.7, training=True)
c = b + residual
d = torch.nn.functional.layer_norm(c, normalized_shape=(hidden_size,))
return d
@staticmethod
def args():
batch_size = 32
seq_len = 196
hidden_size = 1024
input = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=True,
device=device,
dtype=dtype,
)
bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
residual = torch.randn(
batch_size,
seq_len,
hidden_size,
requires_grad=False,
device=device,
dtype=dtype,
)
args = (input, bias, residual)
return args
class LayerNormSigmoid:
@staticmethod
def fn(inp):
hidden_size = 512
a = torch.nn.functional.layer_norm(inp, normalized_shape=(hidden_size,))
b = torch.sigmoid(a)
return b
@staticmethod
def args():
batch_size = 8192
hidden_size = 512
inp = torch.randn(
batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype
)
args = (inp,)
return args
for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]:
# Clear the compile cache
# Get the function and inputs
obj = cl()
fn = obj.fn
args = obj.args()
# Find the static args
static_argnums = []
for idx, arg in enumerate(args):
if not isinstance(arg, torch.Tensor):
static_argnums.append(idx)
# Get the optimized function
opt_fn = memory_efficient_fusion(fn, static_argnums)
# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd")
# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd")
# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, args, "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd")

View File

@@ -1,148 +0,0 @@
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.benchmark import Timer
import time
def profile_cuda_kernels(fn, args, string_id="Model time"):
print("################################################")
print(f"#### Profiling for {string_id} starts #########")
print("################################################")
warmup = 50
old_args = args[:]
n_repeats = 1
n_layers = 1
ref = fn(*old_args)
gO = torch.rand_like(ref)
for _ in range(0, warmup // n_layers):
args = list(old_args[:])
ref = fn(*args)
ref.backward(gO)
torch.cuda.synchronize()
# Forward profile
def fwd_run():
for _ in range(0, n_repeats // n_layers):
args = list(old_args[:])
for arg in args:
if isinstance(arg, torch.Tensor):
arg.grad = None
ref = fn(*args)
print(f"###### Forward profile for {string_id} starts #####")
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
with record_function("baseline"):
fwd_run()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
print(f"###### Forward profile for {string_id} ends #####")
# Backward profile
def bwd_run():
for _ in range(0, n_repeats // n_layers):
args = list(old_args[:])
for arg in args:
if isinstance(arg, torch.Tensor):
arg.grad = None
ref = fn(*args)
print(f"###### Backward profile for {string_id} starts #####")
torch.cuda.synchronize()
with profile(
activities=[ProfilerActivity.CUDA], record_shapes=True
) as prof:
with record_function("baseline"):
ref.backward(gO)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
torch.cuda.synchronize()
print(f"###### Backward profile for {string_id} ends #####")
bwd_run()
print("################################################")
print(f"#### Profiling for {string_id} ends #########")
print("################################################\n\n\n\n")
def time_with_torch_timer(fn, args, string_id, kwargs=None):
if kwargs is None:
kwargs = {}
print("################################################")
print(f"#### Torch Timer for {string_id} starts #########")
print("################################################")
ref = fn(*args, **kwargs)
gO = torch.rand_like(ref)
env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn}
grad_none = {"for x in args: x.grad=None"}
fn_call = "fn(*args, **kwargs)"
# Measure end-to-end fwd time
timer = Timer(stmt=f"{fn_call}", globals=env)
fwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3)
timer_blocked = timer.blocked_autorange()
print(f"Forward = {fwd_latency}")
# Measure end-to-end fwd bwd
timer = Timer(
stmt=f"{grad_none}; fwd = {fn_call}; fwd.backward(gO)",
globals=env,
)
fwd_bwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3)
timer_blocked = timer.blocked_autorange()
# print(f"Forward + sum + Backward = {fwd_sum_bwd_latency}")
bwd_latency = round(fwd_bwd_latency - fwd_latency, 3)
print(f"Backward = {bwd_latency}")
print("################################################")
print(f"#### Torch Timer for {string_id} ends ###############")
print("################################################\n\n\n\n")
def time_with_manual_timer(fn, args, string_id):
print("################################################")
print(f"#### Manual Timer for {string_id} starts #########")
print("################################################")
warmup = 50
repeats = 1000
old_args = args[:]
ref = fn(*old_args)
gO = torch.rand_like(ref)
for _ in range(0, warmup):
args = list(old_args[:])
for arg in args:
if isinstance(arg, torch.Tensor):
arg.grad = None
ref = fn(*args)
ref.backward(gO)
torch.cuda.synchronize()
fwd_times = []
bwd_times = []
for _ in range(0, repeats):
args = list(old_args[:])
for arg in args:
if isinstance(arg, torch.Tensor):
arg.grad = None
fwd_start = time.time()
ref = fn(*args)
torch.cuda.synchronize()
fwd_end = time.time()
bwd_start = time.time()
ref.backward(gO)
torch.cuda.synchronize()
bwd_end = time.time()
fwd_times.append(fwd_end - fwd_start)
bwd_times.append(bwd_end - bwd_start)
avg_fwd = round(sum(fwd_times) / repeats * 10 ** 6, 2)
avg_bwd = round(sum(bwd_times) / repeats * 10 ** 6, 2)
avg_total = round(avg_fwd + avg_bwd, 2)
print(f"Forward = {avg_fwd}")
print(f"Backward = {avg_bwd}")
print("################################################")
print(f"#### Manual Timer for {string_id} ends #########")
print("################################################\n\n\n")

View File

@@ -1,65 +0,0 @@
import torch
from functorch.compile import memory_efficient_pointwise_fusion
import benchmark_helper
# ALL comments regarding the patetrns
def bias_gelu_dropout(input, bias):
a = torch.add(input, bias)
b = torch.nn.functional.gelu(a)
c = torch.nn.functional.dropout(b, p=0.6, training=True)
return c
def aot_fn(input, bias):
a = torch.add(input, bias)
b = a * 0.5 * (1.0 + torch.tanh(0.79788456 * a * (1 + 0.044715 * a * a)))
c = torch.nn.functional.dropout(b, p=0.6, training=True)
return c
fn = bias_gelu_dropout
# Set inputs
device = "cuda"
dtype = torch.float16
batch_size = 32
seq_len = 196
intermediate_size = 4096
# batch_size = 2
# seq_len = 4
# intermediate_size = 3
input = torch.randn(
batch_size,
seq_len,
intermediate_size,
requires_grad=True,
device=device,
dtype=dtype,
)
bias = torch.randn(intermediate_size, requires_grad=True, device=device, dtype=dtype)
# Get the optimized function
opt_fn = memory_efficient_pointwise_fusion(
aot_fn, compiler_name="torchscript_nvfuser"
)
# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.profile_cuda_kernels(opt_fn, (input, bias), "AOTAutograd")
# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.time_with_torch_timer(opt_fn, (input, bias), "AOTAutograd")
# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
benchmark_helper.time_with_manual_timer(opt_fn, (input, bias), "AOTAutograd")

View File

@@ -357,6 +357,7 @@ ALLOW_LIST = [
("aten::_nested_view_from_buffer_copy.out", datetime.date(2023, 5, 1)),
("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)),
# These ops were moved to python under the c10d_functional namespace
("aten::wait_tensor", datetime.date(9999, 1, 30)),
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),

View File

@@ -148,7 +148,7 @@ class TestFlopCounter(TestCase):
self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")
def test_custom(self):
mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out: 5})
mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
with mode:
a = T(4, 5)
a + a

View File

@@ -1352,7 +1352,7 @@ class TestSDPA(NNTestCase):
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.MATH
if PLATFORM_SUPPORTS_FUSED_SDPA:
batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
shape = (batch_size, seq_len, num_heads, head_dim)
device = "cuda"
make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float16, packed=True)

View File

@@ -2932,8 +2932,6 @@ def aot_function(
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
num_params_buffers: int = 0,
hasher_type=None, # deprecated
static_argnums: Optional[Tuple[int]] = None, # deprecated
keep_inference_input_mutations: bool = False,
inference_compiler: Optional[Callable] = None,
*,
@@ -2992,10 +2990,6 @@ def aot_function(
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)
"""
if static_argnums is not None:
raise RuntimeError(
"static_argnums has been deprecated - manually wrap your function or use torchdynamo."
)
if bw_compiler is None:
bw_compiler = fw_compiler
@@ -3127,8 +3121,6 @@ def aot_module_simplified(
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
hasher_type=None,
static_argnums=None,
keep_inference_input_mutations=False,
inference_compiler: Optional[Callable] = None,
) -> nn.Module:
@@ -3175,6 +3167,7 @@ def aot_module_simplified(
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
):
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
@@ -3193,7 +3186,6 @@
)
return out
assert static_argnums is None
if bw_compiler is None:
bw_compiler = fw_compiler
if inference_compiler is None:

View File

@@ -5,7 +5,7 @@ import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Union
import sympy
import torch
@@ -188,8 +188,8 @@ def simple_ts_compile(fx_g, _):
return f
def nnc_jit(f, static_argnums=None):
return aot_function(f, simple_ts_compile, static_argnums=static_argnums)
def nnc_jit(f):
return aot_function(f, simple_ts_compile)
aten = torch.ops.aten
@@ -229,7 +229,6 @@ def print_compile(fx_g, _):
def memory_efficient_fusion(
fn: Union[Callable, nn.Module],
static_argnums: Optional[Tuple[int]] = None,
**kwargs,
):
"""
@@ -245,8 +244,6 @@
Args:
fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
that takes one ore more arguments. Must return one or more Tensors.
static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark
the arguments of the function as static.
**kwargs: Any other overrides you want to make to the settings
Returns:
@@ -261,7 +258,6 @@
"bw_compiler": ts_compile,
"partition_fn": min_cut_rematerialization_partition,
"decompositions": default_decompositions,
"static_argnums": static_argnums,
}
config.update(kwargs)
if isinstance(fn, torch.nn.Module):

View File

@@ -94,7 +94,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
func1.__name__ = name
return func1
FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])
FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])
func_dict = {s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused"]}
buf_to_fx_node = {}
@@ -135,7 +135,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
name = snode.get_name()
fx_node.name = name
fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)
fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)
if isinstance(snode, FusedSchedulerNode):
for x in snode.snodes:

View File

@@ -1,5 +1,5 @@
from functools import wraps
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import (
@@ -19,8 +19,6 @@ def patched_aot_function(
partition_fn: Callable[..., object] = default_partition,
decompositions: Optional[Dict[object, object]] = None,
num_params_buffers: int = 0,
hasher_type: object = None, # deprecated
static_argnums: Optional[Tuple[int]] = None, # deprecated
keep_inference_input_mutations: bool = False,
pre_compile_fn: Optional[Callable[..., object]] = None,
) -> Callable[..., object]:
@@ -98,11 +96,6 @@ def patched_aot_function(
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)
"""
if static_argnums is not None:
raise RuntimeError(
"static_argnums has been deprecated - manually wrap your function or use torchdynamo."
)
if bw_compiler is None:
bw_compiler = fw_compiler

View File

@@ -14,7 +14,7 @@ def get_shape(i):
return i.shape
return i
def mm_flop(a_shape, b_shape, out=None) -> int:
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
"""
Count flops for matmul.
"""
@@ -26,13 +26,13 @@ def mm_flop(a_shape, b_shape, out=None) -> int:
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
return m * n * 2 * k
def addmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
"""
Count flops for addmm
"""
return mm_flop(a_shape, b_shape)
def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int:
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
"""
Count flops for the bmm operation.
"""
@@ -46,7 +46,7 @@ def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int:
flop = b * m * n * 2 * k
return flop
def baddbmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
"""
Count flops for the baddbmm operation.
"""
@@ -83,11 +83,11 @@ def conv_flop_count(
flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
return flop
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out=None, **kwargs) -> int:
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
"""
Count flops for convolution.
"""
return conv_flop_count(x_shape, w_shape, out, transposed=transposed)
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])
@@ -104,14 +104,14 @@ def conv_backward_flop(
_output_padding,
_groups,
output_mask,
out) -> int:
out_shape) -> int:
flop_count = 0
if output_mask[0]:
grad_input_shape = get_shape(out[0])
grad_input_shape = get_shape(out_shape[0])
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
if output_mask[1]:
grad_weight_shape = get_shape(out[1])
grad_weight_shape = get_shape(out_shape[1])
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed)
return flop_count
@@ -134,7 +134,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape):
def sdpa_flop(query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
"""
Count flops for self-attention.
"""
@@ -169,7 +169,7 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape
return total_flops
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
"""
Count flops for self-attention backward.
"""
@@ -306,6 +306,9 @@ class FlopCounterMode(TorchDispatchMode):
return PopState.apply
def get_total_flops(self) -> int:
return sum(self.flop_counts['Global'].values())
def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
"""Returns the flop counts as a dictionary of dictionaries. The outer
dictionary is keyed by module name, and the inner dictionary is keyed by
@@ -326,7 +329,7 @@
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []
global_flops = sum(self.flop_counts['Global'].values())
global_flops = self.get_total_flops()
global_suffix = get_suffix_str(global_flops)
is_global_subsumed = False
@@ -394,7 +397,7 @@
if func_packet in self.flop_mapping:
flop_count_func = self.flop_mapping[func_packet]
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
flop_count = flop_count_func(*args, **kwargs, out=out_shape) # type: ignore[operator]
flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape) # type: ignore[operator]
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count
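
Finally, a hedged usage sketch of the renamed keyword in custom flop mappings, modeled on the test change above (the import path and constructor arguments are assumed, not part of this diff):
```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Hypothetical custom flop formula for aten::add; note the keyword is now
# out_shape rather than out, matching the updated convention in this commit.
custom = {torch.ops.aten.add: lambda *args, out_shape=None, **kwargs: 5}

counter = FlopCounterMode(custom_mapping=custom, display=False)
with counter:
    a = torch.randn(4, 5)
    a + a
print(counter.get_total_flops())  # 5, from the custom mapping
```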