Remove windows check for cmake to build Fused kernels (#91909)

# Summary
Add support for the fused attention kernels (FlashAttention and memory-efficient attention) on Windows. Previously we could not build them on Windows because the required fixes depended on C++17, but the PyTorch C++ standard has since been updated.

This PR:
- Changes uses of unsigned long to a fixed-width integer type (uint64_t), since unsigned long is only 32 bits on Windows
- Adds the #define FP16_SWITCH(COND, ...) macro, which has been added to the flash_attention main branch
- Changes some macros used within the mem-efficient attention code to work around the __VA_ARGS__ expansion discrepancy between clang/gcc and MSVC (see the sketch after this list). An alternative would be setting the global flag /Zc:preprocessor
- Selectively applies /Zc:lambda to only the mem-efficient attention sources, since applying it globally caused the quantization files to fail to compile
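
A minimal sketch of the __VA_ARGS__ issue described in the third bullet, assuming a toy CALL/ADD3 macro pair (hypothetical names; only the EXPAND(x) identity macro mirrors the actual change to kernel_forward.h shown in the diff below):

```cpp
// Toy illustration of the MSVC legacy-preprocessor __VA_ARGS__ behaviour.
// CALL, ADD3, and CALL_FIXED are hypothetical; only the EXPAND pattern
// matches the macro added in kernel_forward.h.
#include <cstdio>

#define ADD3(a, b, c) ((a) + (b) + (c))

// Without /Zc:preprocessor, MSVC's legacy preprocessor forwards the
// expanded __VA_ARGS__ to the inner macro as a single argument, so ADD3
// appears to receive one argument instead of three and the build fails.
#define CALL(fn, ...) fn(__VA_ARGS__)

// Forcing an extra rescan through an identity macro makes the expansion
// agree on MSVC, gcc, and clang.
#define EXPAND(x) x
#define CALL_FIXED(fn, ...) EXPAND(fn(__VA_ARGS__))

int main() {
  std::printf("%d\n", CALL_FIXED(ADD3, 1, 2, 3));  // prints 6 everywhere
  return 0;
}
```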

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91909
Approved by: https://github.com/cpuhrsch
Driss Guessous 2023-01-25 01:21:12 +00:00 committed by PyTorch MergeBot
parent f0d09572b0
commit a3715efd8b
9 changed files with 45 additions and 23 deletions

View File

@@ -717,7 +717,7 @@ include(cmake/Dependencies.cmake)
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention" ON
"USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
"USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
if(USE_FLASH_ATTENTION)
ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
ENDIF()

View File

@@ -43,6 +43,7 @@ set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS)
set(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory")
set(ATEN_INSTALL_LIB_SUBDIR "lib" CACHE PATH "ATen install library subdirectory")
set(ATEN_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "ATen install include subdirectory")
set(MEM_EFF_ATTENTION_CUDA_SOURCES)
if(USE_CUDA)
list(APPEND ATen_CUDA_INCLUDE ${CUDA_INCLUDE_DIRS})
@@ -125,3 +126,4 @@ set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)

View File

@@ -172,6 +172,7 @@ if(USE_FLASH_ATTENTION)
list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_cu})
list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_kernels_cu})
list(APPEND native_transformers_cuda_cpp ${mem_eff_attention_cuda_cpp})
list(APPEND MEM_EFF_ATTENTION_CUDA_SOURCES ${native_transformers_cuda_cu} ${mem_eff_attention_cuda_cu} ${mem_eff_attention_cuda_kernels_cu})
endif()
# XNNPACK
@@ -621,3 +622,4 @@ set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)

View File

@@ -10,7 +10,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <iostream>
#ifdef USE_FLASH_ATTENTION
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
#endif

View File

@@ -308,10 +308,12 @@ class PredicatedTileIteratorPrefetch {
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
unsigned long addr =
(unsigned long)((void*)&memory_pointer
[column * ThreadMap::Delta::kColumn /
kElementsPerAccess]);
// on windows using unsigned long here gives the error
// error: asm operand type size(4) does not match
// type/size implied by constraint 'l'
uint64_t addr = (uint64_t)(
(void*)&memory_pointer
[column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}
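
The change above is needed because the PTX 'l' constraint expects a 64-bit register operand, while unsigned long is only 4 bytes on Windows (LLP64). A minimal standalone sketch of the same idea, assuming a hypothetical prefetch_l1 helper rather than the iterator code from the PR:

```cuda
#include <cstdint>

// Hypothetical helper, simplified from PredicatedTileIteratorPrefetch above.
__device__ void prefetch_l1(const void* ptr) {
  // uint64_t is 8 bytes on every platform, so it always satisfies the
  // 64-bit 'l' constraint; unsigned long would be 4 bytes under MSVC and
  // trigger "asm operand type size(4) does not match ... constraint 'l'".
  uint64_t addr = reinterpret_cast<uint64_t>(ptr);
  asm volatile("prefetch.global.L1 [ %0 ];" :: "l"(addr));
}
```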

View File

@@ -79,7 +79,7 @@ struct AttentionKernel {
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer =
!kKeepOutputInRF && !std::is_same<output_accum_t, output_t>::value;
!kKeepOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
@@ -863,15 +863,19 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
int(__CUDA_ARCH_OR_ZERO__)); \
_ATTENTION_KERNEL_FORWARD_END();
// On windows we don't build with /Zc:preprocessor
// See: https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
#define EXPAND( x ) x
// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__))
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__))
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__))
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__))
// Enable the right one based on __CUDA_ARCH__
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 500
@@ -879,17 +883,17 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__))
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__))
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__))
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__))
#endif

View File

@@ -1814,6 +1814,20 @@ if(BUILD_TEST)
endif()
endif()
if(MSVC)
# This is used to enable the conforming lambda processor in MSVC
# Which allows us to capture constexpr in lambdas
# Note that this will be turned on by default for std=c++20 and above
# This should be applied globally when https://github.com/pytorch/pytorch/issues/92600 is fixed
foreach(tmp ${MEM_EFF_ATTENTION_CUDA_SOURCES})
# MEM_EFF_ATTENTION_CUDA is populated in pytorch/aten/src/ATen/CMakeLists.txt
# We iterate over these files, updating paths and adding the compile flag
FILE(RELATIVE_PATH tmp_path "${PROJECT_SOURCE_DIR}" "${tmp}")
SET(tmp_path "../${tmp_path}")
set_source_files_properties(${tmp_path} PROPERTIES COMPILE_FLAGS "-Xcompiler /Zc:lambda")
endforeach()
endif()
# Note: we only install the caffe2 python files if BUILD_CAFFE2_OPS is ON
# This is because the build rules here written in such a way that they always
# appear to need to be re-run generating >600 pieces of work during the pytorch
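
The /Zc:lambda comment in the hunk above refers to MSVC's conforming lambda processor. A toy sketch, not taken from the PR, of the kind of lambda that motivates the flag:

```cpp
// Assumption: illustrative only. Conforming compilers (and MSVC with
// /Zc:lambda, which becomes the default at /std:c++20) accept a lambda
// that reads an enclosing constexpr without capturing it; MSVC's legacy
// lambda processor can reject this pattern.
int main() {
  constexpr int kElementsPerAccess = 8;
  auto round_up = [](int n) {
    return (n + kElementsPerAccess - 1) / kElementsPerAccess * kElementsPerAccess;
  };
  return round_up(13) == 16 ? 0 : 1;
}
```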

View File

@@ -10,7 +10,7 @@ import itertools
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \
IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like, IS_WINDOWS
IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like
import torch
from torch import Tensor
import functools
@@ -383,8 +383,9 @@ class TestOperators(TestCase):
# RuntimeError: Tensor must have a last dimension with stride 1
xfail('view_as_complex'),
decorate('nn.functional.scaled_dot_product_attention',
decorator=expectedFailureIf(not IS_WINDOWS), device_type='cuda'),
# query: last dimension must be contiguous
# Fused attention kernels require last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention', device_type='cuda'),
}))
@opsToleranceOverride('TestOperators', 'test_grad', (
tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -572,9 +573,8 @@ class TestOperators(TestCase):
# expects last dim to have stride=1
xfail('view_as_complex'),
# RuntimeError: query: last dimension must be contiguous
# NOTE: This passes on Windows!
decorate('nn.functional.scaled_dot_product_attention',
decorator=unittest.skipIf(not IS_WINDOWS, "expects contiguous inputs")),
# The fused attention kernels require the last dim to be contiguous
xfail('nn.functional.scaled_dot_product_attention', device_type="cuda"),
# BUG
# AssertionError: Tensor-likes are not close!
xfail('as_strided'),

View File

@@ -23,7 +23,6 @@ from torch.testing._internal.common_utils import (
freeze_rng_state,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
IS_WINDOWS,
slowTest,
set_default_dtype,
gradcheck
@@ -36,7 +35,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater
if TEST_FAIRSEQ:
import fairseq.models.transformer as fairseq_transformer
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM and not IS_WINDOWS
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):