Remove windows check for cmake to build Fused kernels (#91909)
# Summary

Add support for the fused attention kernels (FlashAttention and memory-efficient attention) on Windows. Previously this was not possible because the required fixes needed C++17, but the PyTorch C++ standard has since been updated. This PR:

- Changes invocations of `unsigned long` to a fixed-width integer type
- Adds the `#define FP16_SWITCH(COND, ...)` macro that has been added to the flash_attention main branch
- Changes some of the macros used within the memory-efficient attention code to work around the `__VA_ARGS__` expansion discrepancy between clang/gcc and MSVC. An alternative would be setting the global flag `/Zc:preprocessor`
- Selectively applies `/Zc:lambda` to only the memory-efficient attention sources, since applying it globally caused the quantization files to fail to compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91909
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent f0d09572b0
commit a3715efd8b
@@ -717,7 +717,7 @@ include(cmake/Dependencies.cmake)
 cmake_dependent_option(
   USE_FLASH_ATTENTION
   "Whether to build the flash_attention kernel for scaled dot product attention" ON
-  "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+  "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 if(USE_FLASH_ATTENTION)
   ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
 ENDIF()
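The `ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)` call turns the CMake option into a compile-time flag that C++ translation units can use to guard the fused-kernel paths. A minimal sketch of how such a definition is typically consumed (the function here is a hypothetical placeholder, not a real PyTorch symbol):

```cpp
// Illustrative only: how a -DUSE_FLASH_ATTENTION compile definition is
// commonly consumed. scaled_dot_product_dispatch() is a made-up example,
// not the actual PyTorch dispatch code.
#include <cstdio>

void scaled_dot_product_dispatch() {
#ifdef USE_FLASH_ATTENTION
  std::printf("fused FlashAttention kernel path available\n");
#else
  std::printf("fused kernels not built; using the math fallback\n");
#endif
}

int main() {
  scaled_dot_product_dispatch();
  return 0;
}
```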
@@ -43,6 +43,7 @@ set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS)
 set(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory")
 set(ATEN_INSTALL_LIB_SUBDIR "lib" CACHE PATH "ATen install library subdirectory")
 set(ATEN_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "ATen install include subdirectory")
+set(MEM_EFF_ATTENTION_CUDA_SOURCES)

 if(USE_CUDA)
   list(APPEND ATen_CUDA_INCLUDE ${CUDA_INCLUDE_DIRS})
@@ -125,3 +126,4 @@ set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
+set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
@@ -172,6 +172,7 @@ if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_cu})
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_kernels_cu})
   list(APPEND native_transformers_cuda_cpp ${mem_eff_attention_cuda_cpp})
+  list(APPEND MEM_EFF_ATTENTION_CUDA_SOURCES ${native_transformers_cuda_cu} ${mem_eff_attention_cuda_cu} ${mem_eff_attention_cuda_kernels_cu})
 endif()

 # XNNPACK
@@ -621,3 +622,4 @@ set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
 set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
+set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
@@ -10,7 +10,6 @@
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/cuda/sdp_utils.h>

-#include <iostream>
 #ifdef USE_FLASH_ATTENTION
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
 #endif
@@ -308,10 +308,12 @@ class PredicatedTileIteratorPrefetch {
       CUTLASS_PRAGMA_UNROLL
       for (int column = 0; column < ThreadMap::Iterations::kColumn;
            ++column) {
-        unsigned long addr =
-            (unsigned long)((void*)&memory_pointer
-                                [column * ThreadMap::Delta::kColumn /
-                                 kElementsPerAccess]);
+        // on windows using unsigned long here gives the error
+        // error: asm operand type size(4) does not match
+        // type/size implied by constraint 'l'
+        uint64_t addr = (uint64_t)(
+            (void*)&memory_pointer
+                [column * ThreadMap::Delta::kColumn / kElementsPerAccess]);
         asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
       }
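The switch from `unsigned long` to `uint64_t` comes down to platform data models: MSVC on Windows uses LLP64, where `unsigned long` is 32 bits, while Linux gcc/clang use LP64, where it is 64 bits. The PTX `l` constraint expects a 64-bit operand, which is what produces the "asm operand type size(4)" error quoted above. A small host-side sketch of the width difference (an illustration, not PyTorch code):

```cpp
// Why a fixed-width type is safer for a 64-bit asm operand:
// LP64 (Linux/macOS, gcc/clang): sizeof(unsigned long) == 8
// LLP64 (Windows, MSVC):         sizeof(unsigned long) == 4
// sizeof(uint64_t) is 8 everywhere, so it always matches PTX's 'l' constraint.
#include <cstdint>
#include <cstdio>

int main() {
  std::printf("sizeof(unsigned long) = %zu\n", sizeof(unsigned long));
  std::printf("sizeof(uint64_t)      = %zu\n", sizeof(std::uint64_t));
  static_assert(sizeof(std::uint64_t) == 8, "uint64_t is 64-bit on every platform");
  return 0;
}
```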
@@ -79,7 +79,7 @@ struct AttentionKernel {
       cutlass::sizeof_bits<scalar_t>::value == 16;
   static constexpr bool kKeepOutputInRF = kSingleValueIteration;
   static constexpr bool kNeedsOutputAccumulatorBuffer =
-      !kKeepOutputInRF && !std::is_same<output_accum_t, output_t>::value;
+      !kKeepOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;

   static_assert(kQueriesPerBlock % 32 == 0, "");
   static_assert(kKeysPerBlock % 32 == 0, "");
@@ -863,15 +863,19 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
       int(__CUDA_ARCH_OR_ZERO__)); \
   _ATTENTION_KERNEL_FORWARD_END();

+// On windows we don't build with /Zc:preprocessor
+// See: https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
+#define EXPAND( x ) x
+
 // All kernels are disabled by default
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__))
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__))
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__))
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__))

 // Enable the right one based on __CUDA_ARCH__
 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 500
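The `EXPAND` indirection works around MSVC's traditional preprocessor, which (without `/Zc:preprocessor`) hands `__VA_ARGS__` to a nested macro as a single token, so the inner macro sees one argument instead of several; forcing an extra rescan through `EXPAND(x) x` makes the arguments split correctly on all compilers. A minimal, self-contained sketch of the failure mode and the fix (the `TAKE_TWO`/`FORWARD` names are illustrative, not the kernel-instantiation macros above):

```cpp
// Demonstrates the MSVC legacy-preprocessor __VA_ARGS__ forwarding quirk
// and the EXPAND rescan trick. Illustrative macro names only.
#include <cstdio>

#define EXPAND(x) x

// An inner macro that expects exactly two arguments.
#define TAKE_TWO(a, b) std::printf("a=%d b=%d\n", (a), (b))

// Without EXPAND, MSVC's legacy preprocessor forwards __VA_ARGS__ to
// TAKE_TWO as one token ("1, 2"), so TAKE_TWO appears to get a single
// argument and the build fails. gcc/clang split the arguments as expected.
// #define FORWARD(...) TAKE_TWO(__VA_ARGS__)

// With EXPAND, the extra rescan splits the arguments on every compiler.
#define FORWARD(...) EXPAND(TAKE_TWO(__VA_ARGS__))

int main() {
  FORWARD(1, 2);  // prints "a=1 b=2" on MSVC, gcc, and clang
  return 0;
}
```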
@@ -879,17 +883,17 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
 #elif __CUDA_ARCH__ < 700
 #undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__))
 #elif __CUDA_ARCH__ < 750
 #undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__))
 #elif __CUDA_ARCH__ < 800
 #undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__))
 #elif __CUDA_ARCH__ >= 800
 #undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
 #define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
-  INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
+  EXPAND(INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__))
 #endif
@@ -1814,6 +1814,20 @@ if(BUILD_TEST)
   endif()
 endif()

+if(MSVC)
+  # This is used to enable the conforming lambda processor in MSVC
+  # Which allows us to capture constexpr in lambdas
+  # Note that this will be turned on by default for std=c++20 and above
+  # This should be applied globally when https://github.com/pytorch/pytorch/issues/92600 is fixed
+  foreach(tmp ${MEM_EFF_ATTENTION_CUDA_SOURCES})
+    # MEM_EFF_ATTENTION_CUDA is populated in pytorch/aten/src/ATen/CMakeLists.txt
+    # We iterate over these files, updating paths and adding the compile flag
+    FILE(RELATIVE_PATH tmp_path "${PROJECT_SOURCE_DIR}" "${tmp}")
+    SET(tmp_path "../${tmp_path}")
+    set_source_files_properties(${tmp_path} PROPERTIES COMPILE_FLAGS "-Xcompiler /Zc:lambda")
+  endforeach()
+endif()
+
 # Note: we only install the caffe2 python files if BUILD_CAFFE2_OPS is ON
 # This is because the build rules here written in such a way that they always
 # appear to need to be re-run generating >600 pieces of work during the pytorch
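`/Zc:lambda` enables MSVC's conforming lambda processor (forwarded to cl.exe through nvcc's `-Xcompiler`). The legacy processor rejects some standard-conforming lambdas, notably ones that use an enclosing `constexpr` value inside the lambda in a constant expression without capturing it. A rough sketch of the kind of code that needs the flag (my own example under that assumption, not code from the mem-efficient attention sources):

```cpp
// Sketch of a lambda pattern that MSVC's legacy lambda processor can reject
// but the conforming processor (/Zc:lambda, default in C++20 mode) accepts.
// Illustrative example only, not code from the PR.
#include <array>
#include <cstdio>

int main() {
  constexpr int kElementsPerAccess = 4;

  // The constexpr local is used inside the lambda as a template argument
  // (a constant expression), so it does not need to be captured.
  auto make_buffer = [] {
    std::array<float, kElementsPerAccess> buf{};
    return buf;
  };

  auto buf = make_buffer();
  std::printf("buffer size = %zu\n", buf.size());
  return 0;
}
```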
@@ -10,7 +10,7 @@ import itertools
 import unittest

 from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \
-    IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like, IS_WINDOWS
+    IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like
 import torch
 from torch import Tensor
 import functools
@@ -383,8 +383,9 @@ class TestOperators(TestCase):

         # RuntimeError: Tensor must have a last dimension with stride 1
         xfail('view_as_complex'),
-        decorate('nn.functional.scaled_dot_product_attention',
-                 decorator=expectedFailureIf(not IS_WINDOWS), device_type='cuda'),
+        # query: last dimension must be contiguous
+        # Fused attention kernels require last dim to be contiguous
+        xfail('nn.functional.scaled_dot_product_attention', device_type='cuda'),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -572,9 +573,8 @@ class TestOperators(TestCase):
         # expects last dim to have stride=1
         xfail('view_as_complex'),
         # RuntimeError: query: last dimension must be contiguous
-        # NOTE: This passes on Windows!
-        decorate('nn.functional.scaled_dot_product_attention',
-                 decorator=unittest.skipIf(not IS_WINDOWS, "expects contiguous inputs")),
+        # The fused attention kernels require the last dim to be contiguous
+        xfail('nn.functional.scaled_dot_product_attention', device_type="cuda"),
         # BUG
         # AssertionError: Tensor-likes are not close!
         xfail('as_strided'),
@@ -23,7 +23,6 @@ from torch.testing._internal.common_utils import (
     freeze_rng_state,
     TEST_WITH_CROSSREF,
     TEST_WITH_ROCM,
-    IS_WINDOWS,
     slowTest,
     set_default_dtype,
     gradcheck
@@ -36,7 +35,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater
 if TEST_FAIRSEQ:
     import fairseq.models.transformer as fairseq_transformer

-PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM and not IS_WINDOWS
+PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):