pytorch/test/test_flop_counter.py
drisspg ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( and this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates _scaled_dot_product_flash_attention from Flash Attention version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao.

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel purely functional and appease autograd (see the sketch after this list).
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
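
Roughly, the padding/slicing relocation looks like the following Python sketch (the real change lives in the C++ composite implicit SDPA, so the function name and padding multiple here are illustrative, not the actual implementation):

```python
import torch
import torch.nn.functional as F

def sdpa_with_outer_padding(q, k, v, multiple_of=8):
    # Hypothetical sketch: pad head_dim up to a supported size outside the kernel,
    # run attention with the original softmax scale, then slice the padding back off.
    # Zero-padding q/k only adds zero terms to q @ k.T, and the padded v columns are
    # dropped by the final slice, so the result is unchanged.
    head_dim = q.shape[-1]
    scale = head_dim ** -0.5  # keep the scale of the unpadded head_dim
    pad = (-head_dim) % multiple_of
    if pad:
        q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, scale=scale)  # stand-in for the kernel call
    return out[..., :head_dim]
```

Keeping the pad/slice at this level means the kernel only ever sees supported head_dims and stays purely functional for autograd.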

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files changed are in the flash_attn/ folder. The only files of interest here, IMO, are:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py (this has been incorporated upstream into the flash-attention GitHub repo; a sketch of the idea follows below)
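
For context, a hypothetical sketch of what a kernel-instantiation generator along these lines can look like (the template text, dtype names, and head_dim list below are illustrative, not the contents of the real generate_kernels.py):

```python
# Illustrative sketch of a kernel-instantiation generator.
import itertools
from pathlib import Path

TEMPLATE = """\
// This file is auto-generated. Do not edit by hand.
#include "flash_fwd_launch_template.h"

template void run_mha_fwd_<cutlass::{dtype}, {head_dim}>(Flash_fwd_params &params, cudaStream_t stream);
"""

def write_forward_kernels(out_dir="kernels"):
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for dtype, head_dim in itertools.product(["half_t", "bfloat16_t"], [32, 64, 96, 128, 160, 192, 224, 256]):
        name = f"flash_fwd_hdim{head_dim}_{dtype}.cu"
        (Path(out_dir) / name).write_text(TEMPLATE.format(dtype=dtype, head_dim=head_dim))

if __name__ == "__main__":
    write_forward_kernels()
```

Splitting each instantiation into its own .cu file keeps individual translation units small and lets the build parallelize across them.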

There are also a number of files related to avoiding OOMs in CI/CD; these are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730, but since this will require changing the semantics of the math backend to call repeat_interleave, I think this should be done as a follow-up (see the sketch after this list).
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to take a look. Note: compiling with clang currently errors on the cute headers.
- [x] Update tests to exercise the above codepath
- [x] Still need to disable backward when seq_len % 128 != 0 (Tri beat me to it in a4f148b6ab)
- [x] Add a determinism warning to the backward pass; Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from the kernel to the top-level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
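
For the multi_query/attention item above, the "changes to semantics of math to call repeat_interleave" amount to expanding the shared key/value heads before attention. A minimal, hypothetical sketch (the function name is illustrative, not the PR's fast path):

```python
import torch
import torch.nn.functional as F

def mqa_sdpa_math_fallback(q, k, v):
    # Hypothetical sketch: repeat the shared key/value heads so the math backend
    # sees matching head counts for q, k, and v.
    n_rep = q.shape[1] // k.shape[1]  # query heads per key/value head
    if n_rep > 1:
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```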

### Results
#### Forward only
The TFLOPs reported here are from an underclocked A100.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep; for large compute-bound sizes we see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00


# Owner(s): ["module: unknown"]

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
import torch.utils.flop_counter
import torch.nn.functional as F
import unittest
import functools

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return str(sum([v for _, v in mode.flop_counts["Global"].items()]))


def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)


@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now")
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mode = FlopCounterMode()
        mod = torch.nn.Linear(9, 10)
        with mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        mode = FlopCounterMode()
        with mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240 per addmm call; 3 calls = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)
        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in
        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)
        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in
        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        mode = FlopCounterMode()
        with mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        mode = FlopCounterMode()
        with mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(grad_output, grad_output)

        a = T(5, 5, requires_grad=True)
        mode = FlopCounterMode()
        with mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """750""")

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        mode = FlopCounterMode(resnet18)
        with mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution]
        layer1_conv_back_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution_backward]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_custom(self):
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

    def test_noop(self):
        mode = FlopCounterMode()
        with mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype, backend, with_backward=False):
            query = torch.randn(batch_size, n_heads, seq_len_q, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            key = torch.randn(batch_size, n_heads, seq_len_k, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            value = torch.randn(batch_size, n_heads, seq_len_k, head_dim_v, device='cuda', dtype=dtype, requires_grad=True)

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_q, head_dim, head_dim, dtype)

        flops = [run_uniform_flops(backend, with_backward=False) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [run_uniform_flops(backend, with_backward=True) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        # As counted here, the math backward costs 2x the forward matmul FLOPs
        # (fw + bw = 3x forward); the flash/mem_efficient backward also recomputes
        # the attention scores, adding another 0.5x forward (fw + bw = 3.5x forward).
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype)

        flops = [run_nonuniform_flops(backend, with_backward=False) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)
        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [run_nonuniform_flops(backend, with_backward=True) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
        self.assertExpectedInline(str(flops_fw_bw_flash), """939524096""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        flop_counter = FlopCounterMode(model)
        with flop_counter:
            self.assertEqual(len(model._forward_pre_hooks), 1)
            self.assertEqual(len(model._forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(model._forward_pre_hooks), 0)
        self.assertEqual(len(model._forward_hooks), 0)


if __name__ == '__main__':
    run_tests()