Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:

1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
   2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:

```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy. My new heuristic achieves 94% accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
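For illustration, here is a minimal Python sketch of the learned rule; the helper name and example shapes are hypothetical and are not code from this PR. The C++ change in the diff below encodes the same idea, falling back to the memory-efficient (Cutlass) kernel during training when the sequence length is small relative to batch_size * num_heads:

```
# Minimal sketch of the learned heuristic; prefer_flash_attention is a
# hypothetical helper name, not part of PyTorch.
def prefer_flash_attention(batch_size: int, num_heads: int, seq_len: int) -> bool:
    """Return True when FlashAttention is expected to beat the Cutlass
    (memory-efficient) kernel for this problem shape."""
    return seq_len / (num_heads * batch_size) > 6


# A long sequence spread over few batch*head blocks favors FlashAttention:
print(prefer_flash_attention(batch_size=2, num_heads=8, seq_len=2048))   # True  (2048 / 16 = 128)
# Many batch*head blocks over a short sequence favor the Cutlass kernel:
print(prefer_flash_attention(batch_size=32, num_heads=16, seq_len=128))  # False (128 / 512 = 0.25)
```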
This commit is contained in:
parent bb830224e3
commit 547bef11ee
@@ -14034,7 +14034,7 @@
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offse, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
  variants: function
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_backward_cuda
@@ -39,22 +39,38 @@ struct sdp_params {
  bool is_causal;
};

inline bool check_requires_grad(sdp_params params, bool debug) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  if ((any_inputs_require_grad && gradmode_enabled)) {
    if (debug) {
      TORCH_WARN("Flash Attention does not currently support training.");
    }
    return false;
  }
  return true;
}

inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
  constexpr std::array<SDPBackend, num_backends> default_order{
      SDPBackend::flash_attention,
      SDPBackend::efficient_attention,
      SDPBackend::math};

  constexpr std::array<SDPBackend, num_backends> efficient_first{
      SDPBackend::efficient_attention,
      SDPBackend::flash_attention,
      SDPBackend::math};
  // Logic is taken from xformers
  // FlashAttention parallelizes across "batch_size * num_heads"
  // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
  // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape

  if (params.query.is_nested() || params.key.is_nested() ||
      params.value.is_nested()) {
    // See check_for_nested_inputs for details
    return {
        SDPBackend::efficient_attention,
        SDPBackend::flash_attention,
        SDPBackend::math};
    return efficient_first;
  }
  if (params.query.dim() != 4) {
    return default_order;

@@ -70,13 +86,14 @@ inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
    bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
    bool small_threads_flash = threads_flash < 60;
    bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
    if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
      return {
          SDPBackend::efficient_attention,
          SDPBackend::flash_attention,
          SDPBackend::math};

    // The training heuristic is taken from https://github.com/pytorch/pytorch/pull/99644
    // Revisit when updated cutlass kernel is upstreamed.
    if (check_requires_grad(params, false)) {
      if (6 * threads_flash > query_lengths) return efficient_first;
    } else if ((small_threads_flash && more_threads_cutlass) || large_head_dim)
      return efficient_first;
    }
  }
  return default_order;
}

@@ -253,19 +270,6 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
  return true;
}

inline bool check_requires_grad(sdp_params params, bool debug) {
  const bool any_inputs_require_grad = params.query.requires_grad() ||
      params.key.requires_grad() || params.value.requires_grad();
  const bool gradmode_enabled = at::GradMode::is_enabled();
  if ((any_inputs_require_grad && gradmode_enabled)) {
    if (debug) {
      TORCH_WARN("Flash Attention does not currently support training.");
    }
    return false;
  }
  return true;
}

inline bool check_requires_grad_and_nested(sdp_params params, bool debug) {
  // If we fail both checks then we return false
  if (check_for_nested_inputs(params) && !check_requires_grad(params, false)){
@@ -1,5 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,190 +0,0 @@
import torch
from functorch.compile import memory_efficient_fusion

import benchmark_helper


device = "cuda"
dtype = torch.float16

# LightSeq pattern 1
class DropoutResBias:
    @staticmethod
    def fn(input, bias, residual):
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p=0.7, training=True)
        c = b + residual
        return c

    @staticmethod
    def args():
        batch_size, seq_len, hidden_size = 32, 196, 1024
        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual)
        return args


class DropoutResBiasScalar:
    @staticmethod
    def fn(input, bias, residual, p: float):
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p, training=True)
        c = b + residual
        return c

    @staticmethod
    def args():
        batch_size, seq_len, hidden_size = 32, 196, 1024
        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual, 0.7)
        return args


# LightSeq pattern 2
class BiasReluDropout:
    @staticmethod
    def fn(input, bias):
        a = torch.add(input, bias)
        b = torch.nn.functional.relu(a)
        c = torch.nn.functional.dropout(b, p=0.6, training=True)
        return c

    @staticmethod
    def args():
        batch_size = 32
        seq_len = 196
        intermediate_size = 4096
        input = torch.randn(
            batch_size,
            seq_len,
            intermediate_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(
            intermediate_size, requires_grad=True, device=device, dtype=dtype
        )
        args = (input, bias)
        return args


class BiasDropoutResLayerNorm:
    @staticmethod
    def fn(input, bias, residual):
        hidden_size = 1024
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p=0.7, training=True)
        c = b + residual
        d = torch.nn.functional.layer_norm(c, normalized_shape=(hidden_size,))
        return d

    @staticmethod
    def args():
        batch_size = 32
        seq_len = 196
        hidden_size = 1024

        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual)
        return args


class LayerNormSigmoid:
    @staticmethod
    def fn(inp):
        hidden_size = 512
        a = torch.nn.functional.layer_norm(inp, normalized_shape=(hidden_size,))
        b = torch.sigmoid(a)
        return b

    @staticmethod
    def args():
        batch_size = 8192
        hidden_size = 512
        inp = torch.randn(
            batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype
        )
        args = (inp,)
        return args


for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]:
    # Clear the compile cache

    # Get the function and inputs
    obj = cl()
    fn = obj.fn
    args = obj.args()

    # Find the static args
    static_argnums = []
    for idx, arg in enumerate(args):
        if not isinstance(arg, torch.Tensor):
            static_argnums.append(idx)

    # Get the optimized function
    opt_fn = memory_efficient_fusion(fn, static_argnums)

    # Profile cuda kernels
    benchmark_helper.profile_cuda_kernels(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.profile_cuda_kernels(opt_fn, args, "AOTAutograd")

    # Time it with Torch Timer
    benchmark_helper.time_with_torch_timer(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.time_with_torch_timer(opt_fn, args, "AOTAutograd")

    # Time it with manual Timer
    benchmark_helper.time_with_manual_timer(fn, args, "Eager")
    with torch.jit.fuser("fuser2"):
        benchmark_helper.time_with_manual_timer(opt_fn, args, "AOTAutograd")
@@ -1,148 +0,0 @@
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.benchmark import Timer
import time


def profile_cuda_kernels(fn, args, string_id="Model time"):
    print("################################################")
    print(f"#### Profiling for {string_id} starts #########")
    print("################################################")
    warmup = 50
    old_args = args[:]
    n_repeats = 1
    n_layers = 1
    ref = fn(*old_args)
    gO = torch.rand_like(ref)
    for _ in range(0, warmup // n_layers):
        args = list(old_args[:])
        ref = fn(*args)
        ref.backward(gO)

    torch.cuda.synchronize()

    # Forward profile
    def fwd_run():
        for _ in range(0, n_repeats // n_layers):
            args = list(old_args[:])
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg.grad = None
            ref = fn(*args)

    print(f"###### Forward profile for {string_id} starts #####")
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("baseline"):
            fwd_run()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
    print(f"###### Forward profile for {string_id} ends #####")

    # Backward profile
    def bwd_run():
        for _ in range(0, n_repeats // n_layers):
            args = list(old_args[:])
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg.grad = None
            ref = fn(*args)

            print(f"###### Backward profile for {string_id} starts #####")
            torch.cuda.synchronize()
            with profile(
                activities=[ProfilerActivity.CUDA], record_shapes=True
            ) as prof:
                with record_function("baseline"):
                    ref.backward(gO)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
            torch.cuda.synchronize()
        print(f"###### Backward profile for {string_id} ends #####")

    bwd_run()
    print("################################################")
    print(f"#### Profiling for {string_id} ends #########")
    print("################################################\n\n\n\n")


def time_with_torch_timer(fn, args, string_id, kwargs=None):
    if kwargs is None:
        kwargs = {}
    print("################################################")
    print(f"#### Torch Timer for {string_id} starts #########")
    print("################################################")
    ref = fn(*args, **kwargs)
    gO = torch.rand_like(ref)
    env = {"args": args, "gO": gO, "kwargs": kwargs, "fn": fn}
    grad_none = {"for x in args: x.grad=None"}
    fn_call = "fn(*args, **kwargs)"
    # Measure end-to-end fwd time
    timer = Timer(stmt=f"{fn_call}", globals=env)
    fwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3)
    timer_blocked = timer.blocked_autorange()
    print(f"Forward = {fwd_latency}")

    # Measure end-to-end fwd bwd
    timer = Timer(
        stmt=f"{grad_none}; fwd = {fn_call}; fwd.backward(gO)",
        globals=env,
    )
    fwd_bwd_latency = round(timer.timeit(1000).mean * 10 ** 6, 3)
    timer_blocked = timer.blocked_autorange()
    # print(f"Forward + sum + Backward = {fwd_sum_bwd_latency}")

    bwd_latency = round(fwd_bwd_latency - fwd_latency, 3)
    print(f"Backward = {bwd_latency}")

    print("################################################")
    print(f"#### Torch Timer for {string_id} ends ###############")
    print("################################################\n\n\n\n")


def time_with_manual_timer(fn, args, string_id):
    print("################################################")
    print(f"#### Manual Timer for {string_id} starts #########")
    print("################################################")
    warmup = 50
    repeats = 1000
    old_args = args[:]
    ref = fn(*old_args)
    gO = torch.rand_like(ref)
    for _ in range(0, warmup):
        args = list(old_args[:])

        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg.grad = None
        ref = fn(*args)
        ref.backward(gO)

    torch.cuda.synchronize()

    fwd_times = []
    bwd_times = []
    for _ in range(0, repeats):
        args = list(old_args[:])
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg.grad = None
        fwd_start = time.time()
        ref = fn(*args)
        torch.cuda.synchronize()
        fwd_end = time.time()

        bwd_start = time.time()
        ref.backward(gO)
        torch.cuda.synchronize()
        bwd_end = time.time()

        fwd_times.append(fwd_end - fwd_start)
        bwd_times.append(bwd_end - bwd_start)
    avg_fwd = round(sum(fwd_times) / repeats * 10 ** 6, 2)
    avg_bwd = round(sum(bwd_times) / repeats * 10 ** 6, 2)
    avg_total = round(avg_fwd + avg_bwd, 2)

    print(f"Forward = {avg_fwd}")
    print(f"Backward = {avg_bwd}")

    print("################################################")
    print(f"#### Manual Timer for {string_id} ends #########")
    print("################################################\n\n\n")
@@ -1,65 +0,0 @@
import torch
from functorch.compile import memory_efficient_pointwise_fusion

import benchmark_helper

# ALL comments regarding the patetrns


def bias_gelu_dropout(input, bias):
    a = torch.add(input, bias)
    b = torch.nn.functional.gelu(a)
    c = torch.nn.functional.dropout(b, p=0.6, training=True)
    return c


def aot_fn(input, bias):
    a = torch.add(input, bias)
    b = a * 0.5 * (1.0 + torch.tanh(0.79788456 * a * (1 + 0.044715 * a * a)))
    c = torch.nn.functional.dropout(b, p=0.6, training=True)
    return c


fn = bias_gelu_dropout


# Set inputs
device = "cuda"
dtype = torch.float16
batch_size = 32
seq_len = 196
intermediate_size = 4096
# batch_size = 2
# seq_len = 4
# intermediate_size = 3
input = torch.randn(
    batch_size,
    seq_len,
    intermediate_size,
    requires_grad=True,
    device=device,
    dtype=dtype,
)
bias = torch.randn(intermediate_size, requires_grad=True, device=device, dtype=dtype)


# Get the optimized function
opt_fn = memory_efficient_pointwise_fusion(
    aot_fn, compiler_name="torchscript_nvfuser"
)


# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.profile_cuda_kernels(opt_fn, (input, bias), "AOTAutograd")


# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_torch_timer(opt_fn, (input, bias), "AOTAutograd")

# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, (input, bias), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_manual_timer(opt_fn, (input, bias), "AOTAutograd")
@@ -357,6 +357,7 @@ ALLOW_LIST = [
    ("aten::_nested_view_from_buffer_copy.out", datetime.date(2023, 5, 1)),
    ("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)),
    ("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)),
    ("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)),
    # These ops were moved to python under the c10d_functional namespace
    ("aten::wait_tensor", datetime.date(9999, 1, 30)),
    ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
@@ -148,7 +148,7 @@ class TestFlopCounter(TestCase):
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_custom(self):
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out: 5})
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
        with mode:
            a = T(4, 5)
            a + a
@@ -1352,7 +1352,7 @@ class TestSDPA(NNTestCase):
            assert torch._fused_sdp_choice(q, k, v) == SDPBackend.MATH

        if PLATFORM_SUPPORTS_FUSED_SDPA:
            batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
            batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
            shape = (batch_size, seq_len, num_heads, head_dim)
            device = "cuda"
            make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float16, packed=True)
@@ -2932,8 +2932,6 @@ def aot_function(
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    num_params_buffers: int = 0,
    hasher_type=None,  # deprecated
    static_argnums: Optional[Tuple[int]] = None,  # deprecated
    keep_inference_input_mutations: bool = False,
    inference_compiler: Optional[Callable] = None,
    *,

@@ -2992,10 +2990,6 @@ def aot_function(
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)
    """
    if static_argnums is not None:
        raise RuntimeError(
            "static_argnums has been deprecated - manually wrap your function or use torchdynamo."
        )

    if bw_compiler is None:
        bw_compiler = fw_compiler

@@ -3127,8 +3121,6 @@ def aot_module_simplified(
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    hasher_type=None,
    static_argnums=None,
    keep_inference_input_mutations=False,
    inference_compiler: Optional[Callable] = None,
) -> nn.Module:

@@ -3175,6 +3167,7 @@ def aot_module_simplified(
    with stateless._reparametrize_module(
        mod, pytree.tree_unflatten(args[:params_len], params_spec)
    ):

        if isinstance(mod, torch.fx.GraphModule):
            with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                warnings.filterwarnings(

@@ -3193,7 +3186,6 @@ def aot_module_simplified(
                )
        return out

    assert static_argnums is None
    if bw_compiler is None:
        bw_compiler = fw_compiler
    if inference_compiler is None:
@@ -5,7 +5,7 @@ import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Union
import sympy

import torch

@@ -188,8 +188,8 @@ def simple_ts_compile(fx_g, _):
    return f


def nnc_jit(f, static_argnums=None):
    return aot_function(f, simple_ts_compile, static_argnums=static_argnums)
def nnc_jit(f):
    return aot_function(f, simple_ts_compile)


aten = torch.ops.aten

@@ -229,7 +229,6 @@ def print_compile(fx_g, _):

def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    static_argnums: Optional[Tuple[int]] = None,
    **kwargs,
):
    """

@@ -245,8 +244,6 @@ def memory_efficient_fusion(
    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one ore more arguments. Must return one or more Tensors.
        static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark
            the arguments of the function as static.
        **kwargs: Any other overrides you want to make to the settings

    Returns:

@@ -261,7 +258,6 @@ def memory_efficient_fusion(
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
        "static_argnums": static_argnums,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
@@ -94,7 +94,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])
    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    func_dict = {s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused"]}
    buf_to_fx_node = {}

@@ -135,7 +135,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)
        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
@@ -1,5 +1,5 @@
from functools import wraps
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional

import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import (

@@ -19,8 +19,6 @@ def patched_aot_function(
    partition_fn: Callable[..., object] = default_partition,
    decompositions: Optional[Dict[object, object]] = None,
    num_params_buffers: int = 0,
    hasher_type: object = None,  # deprecated
    static_argnums: Optional[Tuple[int]] = None,  # deprecated
    keep_inference_input_mutations: bool = False,
    pre_compile_fn: Optional[Callable[..., object]] = None,
) -> Callable[..., object]:

@@ -98,11 +96,6 @@ def patched_aot_function(
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)
    """
    if static_argnums is not None:
        raise RuntimeError(
            "static_argnums has been deprecated - manually wrap your function or use torchdynamo."
        )

    if bw_compiler is None:
        bw_compiler = fw_compiler
@@ -14,7 +14,7 @@ def get_shape(i):
        return i.shape
    return i

def mm_flop(a_shape, b_shape, out=None) -> int:
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for matmul.
    """

@@ -26,13 +26,13 @@ def mm_flop(a_shape, b_shape, out=None) -> int:
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k

def addmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for addmm
    """
    return mm_flop(a_shape, b_shape)

def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int:
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the bmm operation.
    """

@@ -46,7 +46,7 @@ def bmm_flop(a_shape, b_shape, out=None, **kwargs) -> int:
    flop = b * m * n * 2 * k
    return flop

def baddbmm_flop(self_shape, a_shape, b_shape, out=None, **kwargs) -> int:
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """
    Count flops for the baddbmm operation.
    """

@@ -83,11 +83,11 @@ def conv_flop_count(
    flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
    return flop

def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out=None, **kwargs) -> int:
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for convolution.
    """
    return conv_flop_count(x_shape, w_shape, out, transposed=transposed)
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])

@@ -104,14 +104,14 @@ def conv_backward_flop(
        _output_padding,
        _groups,
        output_mask,
        out) -> int:
        out_shape) -> int:
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(out[0])
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(out[1])
        grad_weight_shape = get_shape(out_shape[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, transposed)

    return flop_count

@@ -134,7 +134,7 @@ def sdpa_flop_count(query_shape, key_shape, value_shape):


def sdpa_flop(query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for self-attention.
    """

@@ -169,7 +169,7 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape
    return total_flops


def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out=None, **kwargs) -> int:
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """
    Count flops for self-attention backward.
    """

@@ -306,6 +306,9 @@ class FlopCounterMode(TorchDispatchMode):

        return PopState.apply

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Returns the flop counts as a dictionary of dictionaries. The outer
        dictionary is keyed by module name, and the inner dictionary is keyed by

@@ -326,7 +329,7 @@ class FlopCounterMode(TorchDispatchMode):
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = sum(self.flop_counts['Global'].values())
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

@@ -394,7 +397,7 @@ class FlopCounterMode(TorchDispatchMode):
        if func_packet in self.flop_mapping:
            flop_count_func = self.flop_mapping[func_packet]
            args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
            flop_count = flop_count_func(*args, **kwargs, out=out_shape)  # type: ignore[operator]
            flop_count = flop_count_func(*args, **kwargs, out_shape=out_shape)  # type: ignore[operator]
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count