Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00

Remove support for CUDA versions below 12.3 in XLA.

PiperOrigin-RevId: 650373808

This commit is contained in:
parent a4a5690b33
commit de2c871507
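The change is mechanical throughout the diff below: with CUDA 12.3 now the minimum toolkit XLA builds against, preprocessor branches that only existed to keep older toolkits compiling collapse to their modern side, and compile-time test skips for CUDA < 12 disappear. A minimal standalone sketch of the before/after shape of such a guard (the function names and return values here are illustrative, not taken from the XLA sources):

#include <cuda.h>  // defines CUDA_VERSION, e.g. 12030 for CUDA 12.3

// Before: both branches are kept so the file still builds on toolkits older
// than the feature being guarded.
inline int PerThreadBufferFlagBefore() {
#if CUDA_VERSION >= 12030
  return 1;  // the CUPTI per-thread activity buffer is available
#else
  return 0;  // fall back for older toolkits
#endif
}

// After: CUDA >= 12.3 is guaranteed at build time, so only the modern branch
// survives and the #if/#else scaffolding is deleted.
inline int PerThreadBufferFlagAfter() { return 1; }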
@@ -76,11 +76,6 @@ inline void LogIfError(const absl::Status &status) {
   LOG(ERROR) << status.message();
 }

-// CUPTI_ERROR_INSUFFICIENT_PRIVILEGES is introduced at CUDA 10.1.
-#if CUDA_VERSION <= 10000
-#define CUPTI_ERROR_INSUFFICIENT_PRIVILEGES 35
-#endif
-
 #define RETURN_IF_CUPTI_ERROR(expr)  \
   do {                               \
     CUptiResult status = expr;       \
@@ -673,9 +668,7 @@ static void SetCallbackEventUponApiExit(CuptiTracerEvent &event,
                                         uint64_t start_tsc, uint64_t end_tsc) {
   switch (cbid) {
     case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel:
-#if CUDA_VERSION >= 11080  // CUDA 11.8
     case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx:
-#endif  // CUDA_VERSION >= 11080
     case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel:
     case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice:
       SetKernelEventUponApiExit(event, device_id, cbdata, start_tsc, end_tsc);
@@ -57,17 +57,11 @@ CUptiResult CuptiWrapper::ActivityRegisterCallbacks(
 }

 CUptiResult CuptiWrapper::ActivityUsePerThreadBuffer() {
-#if CUDA_VERSION >= 12030
   uint8_t use_per_thread_activity_buffer = 1;
   size_t value_size = sizeof(use_per_thread_activity_buffer);
   return cuptiActivitySetAttribute(
       CUPTI_ACTIVITY_ATTR_PER_THREAD_ACTIVITY_BUFFER, &value_size,
       &use_per_thread_activity_buffer);
-#else
-  // cuptiActivitySetAttribute returns CUPTI_ERROR_INVALID_PARAMETER if invoked
-  // with an invalid first parameter.
-  return CUPTI_ERROR_INVALID_PARAMETER;
-#endif
 }

 CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, uint32_t* deviceId) {
@@ -81,52 +81,50 @@ absl::Status GpuTracer::DoStart() {
   }

   options_.cbids_selected = {
-      // KERNEL
-      CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel,
-#if CUDA_VERSION >= 11080  // CUDA 11.8
-      CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx,
-#endif  // CUDA_VERSION >= 11080
-      // MEMCPY
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoH_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoHAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoD_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoA_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoA_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2D_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DUnaligned_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy3D_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpy3DAsync_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoA_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoAAsync_v2,
-      // MemAlloc
-      CUPTI_DRIVER_TRACE_CBID_cuMemAlloc_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemAllocPitch_v2,
-      // MemFree
-      CUPTI_DRIVER_TRACE_CBID_cuMemFree_v2,
-      // Memset
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD8_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD16_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD32_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32_v2,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD8Async,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD16Async,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD32Async,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8Async,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16Async,
-      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32Async,
-      // GENERIC
-      CUPTI_DRIVER_TRACE_CBID_cuStreamSynchronize,
+      // KERNEL
+      CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel,
+      CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx,
+      // MEMCPY
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoH_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoHAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoD_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoA_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyAtoA_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2D_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DUnaligned_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy2DAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy3D_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpy3DAsync_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoA_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoAAsync_v2,
+      // MemAlloc
+      CUPTI_DRIVER_TRACE_CBID_cuMemAlloc_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemAllocPitch_v2,
+      // MemFree
+      CUPTI_DRIVER_TRACE_CBID_cuMemFree_v2,
+      // Memset
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD8_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD16_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD32_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32_v2,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD8Async,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD16Async,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD32Async,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D8Async,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D16Async,
+      CUPTI_DRIVER_TRACE_CBID_cuMemsetD2D32Async,
+      // GENERIC
+      CUPTI_DRIVER_TRACE_CBID_cuStreamSynchronize,
   };

   bool trace_concurrent_kernels = false;

@@ -141,10 +139,7 @@ absl::Status GpuTracer::DoStart() {
   options_.activities_selected.push_back(CUPTI_ACTIVITY_KIND_OVERHEAD);
   options_.activities_selected.push_back(CUPTI_ACTIVITY_KIND_MEMSET);

-  // CUDA/CUPTI 10 have issues (leaks and crashes) with CuptiFinalize.
-#if CUDA_VERSION >= 11000
   options_.cupti_finalize = true;
-#endif

   CuptiTracerCollectorOptions collector_options;
   collector_options.num_gpus = cupti_tracer_->NumGpus();
@@ -798,7 +798,7 @@ StreamExecutorGpuClient::Load(std::unique_ptr<PjRtExecutable> executable) {

 namespace {

-#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020
+#if defined(GOOGLE_CUDA)

 absl::StatusOr<std::unique_ptr<se::GpuCudaMallocAsyncAllocator>>
 CreateCudaAsyncAllocator(const LocalDeviceState& device, double memory_fraction,

@@ -843,14 +843,14 @@ CreateCudaAsyncAllocator(const LocalDeviceState& device, double memory_fraction,
   return allocator;
 }

-#else  // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020
+#else  // defined(GOOGLE_CUDA)
 absl::StatusOr<std::unique_ptr<tsl::Allocator>> CreateCudaAsyncAllocator(
     const LocalDeviceState& device, double memory_fraction, bool reserve_memory,
     bool create_new_pool, bool sync_mode, bool compute_stats = true) {
   return FailedPrecondition("CUDA async allocator requires CUDA >= 11.2");
 }

-#endif  // defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020
+#endif  // defined(GOOGLE_CUDA)

 // Builds a LocalDeviceState for each GPU present.
 absl::StatusOr<std::map<int, std::unique_ptr<LocalDeviceState>>>

@@ -942,7 +942,7 @@ GetStreamExecutorGpuDeviceAllocator(
         static_cast<int>(se::MemoryType::kHost));
   }

-#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020
+#if defined(GOOGLE_CUDA)
   const auto& debug_options = xla::GetDebugOptionsFromFlags();
   if (debug_options.xla_gpu_temp_buffer_use_separate_color()) {
     // Add memory allocator to allocate memory buffers with persistent temp
@@ -298,11 +298,11 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest {
 };

 #if GOOGLE_CUDA
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900)
-#define MAYBE_SKIP_TEST(CAUSE)                                           \
-  do {                                                                   \
-    if (absl::string_view(CAUSE) == "F8")                                \
-      GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; \
+#if (CUDNN_VERSION < 8900)
+#define MAYBE_SKIP_TEST(CAUSE)                               \
+  do {                                                       \
+    if (absl::string_view(CAUSE) == "F8")                    \
+      GTEST_SKIP() << "FP8 convolutions require cuDNN 8.9."; \
   } while (0)
 #else
 #define MAYBE_SKIP_TEST(CAUSE)
@@ -1627,10 +1627,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
       comp->parent()->config().debug_options();
   const se::dnn::VersionInfo cudnn_version =
       GetDnnVersionInfoOrDefault(stream_executor_, cudnn_version_);
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
-  // CUDA needs to be >= 12.0 for cuDNN to work with all supported hardware.
-  // Some cuDNN versions work with CUDA 11, but it is impractical for us to
-  // test those combinations so just disable them.
+#if !defined(GOOGLE_CUDA)
   return false;
 #endif
   if (!debug_options.xla_gpu_enable_cudnn_fmha() ||
@@ -86,12 +86,7 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase {
   CudnnFusedMhaRewriterTestHloTest()
       : HloTestBase(/*verifier_layout_sensitive=*/false,
                     /*allow_mixed_precision_in_hlo_verifier=*/false,
-                    /*instruction_can_change_layout_func=*/{}) {
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
-    skip_reason_ = "cuDNN fused MHA requires CUDA 12 or later.";
-    return;
-#endif
-  }
+                    /*instruction_can_change_layout_func=*/{}) {}

  protected:
  size_t CountFusedAttentionCall(HloModule* module, bool is_backward = false) {
@@ -56,8 +56,8 @@ class CudnnNormRewriterTest : public GpuCodegenTest {
 // The following tests evaluate LayerNormXDY configurations, with X the rank of
 // the input and Y the dimensions that are normalized.
 TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -125,8 +125,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm2D1) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -194,8 +194,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -263,8 +263,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -333,8 +333,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -403,8 +403,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -473,8 +473,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -543,8 +543,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -600,8 +600,8 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -677,8 +677,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -754,8 +754,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -832,8 +832,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -910,8 +910,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1025,8 +1025,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1140,8 +1140,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1258,8 +1258,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1376,8 +1376,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) {
 }

 TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1496,8 +1496,8 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) {
 // TODO(b/343124533) Reenable when fixed
 TEST_F(CudnnNormRewriterTest,
        DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&

@@ -1617,8 +1617,8 @@ TEST_F(CudnnNormRewriterTest,
 // TODO(b/343124533) Reenable when fixed
 TEST_F(CudnnNormRewriterTest,
        DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) {
-#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905)
-  GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5.";
+#if (CUDNN_VERSION < 8905)
+  GTEST_SKIP() << "Layer norm kernels require cuDNN 8.9.5.";
 #endif
   if (!(GetCudaComputeCapability().major ==
         se::CudaComputeCapability::AMPERE) &&
@@ -210,12 +210,10 @@ absl::StatusOr<CustomKernel> GetCutlassGemmKernel(
       return Load<F32xF32ToF32<Default>>(std::move(name), m, n, k, indices,
                                          slices, device);
     case PrimitiveType::BF16:
-#if CUDA_VERSION >= 12000
       if (cuda_cc.IsAtLeastHopper()) {
         return Load<Bf16xBf16ToBf16<Sm90>>(std::move(name), m, n, k, indices,
                                            slices, device);
       }
-#endif
       if (cuda_cc.IsAtLeastAmpere()) {
         return Load<Bf16xBf16ToBf16<Sm80>>(std::move(name), m, n, k, indices,
                                            slices, device);
@@ -97,10 +97,7 @@ static bool IsAtLeastCuda12300() {
 #if defined(TENSORFLOW_USE_ROCM)
   return false;
 #endif
-#if CUDA_VERSION >= 12030
   return true;
-#endif
-  return false;
 }

 // Give a short aliases to execution threads.
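In the test hunks that follow, only the compile-time CUDA-version skips are deleted; the ROCm-version skips that share the same preamble are kept unchanged. A small self-contained illustration of the surviving pattern (the fixture and test names here are made up for the sketch, not taken from the XLA test suites):

#include <gtest/gtest.h>

class Fp8GemmSketchTest : public ::testing::Test {};  // stand-in for the XLA fixture

TEST_F(Fp8GemmSketchTest, SkipsOnlyOnOldRocm) {
#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
  // When compiled against ROCm older than 6.0, the test reports itself as
  // skipped instead of failing.
  GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
#endif
  // No CUDA_VERSION check is needed any more: CUDA >= 12.3 is guaranteed at
  // build time, so the F8 path is always compiled in for CUDA builds.
  EXPECT_TRUE(true);
}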
@@ -4855,10 +4855,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteOnPreAdaWithF32Output) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -4891,10 +4887,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -4950,10 +4942,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) {

 // Do not fuse FP8 matrix bias.
 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60200
   GTEST_SKIP() << "F8 gemm rewrite for D to be fp8 with Matrix Bias is only "
                   "supported in ROCm 6.2 and above.";

@@ -5007,10 +4995,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5068,10 +5052,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5135,10 +5115,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5179,10 +5155,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) {
 // Test case where F8 inputs are converted to F32 before the dot, but without
 // any scaling.
 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5232,10 +5204,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5306,10 +5274,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        UnscaledABUnscaledDUnaryOpsWithConvertF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5371,10 +5335,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5441,10 +5401,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5517,10 +5473,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDSelectNonzeroConstantF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5556,10 +5508,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5617,10 +5565,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5682,10 +5626,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5749,10 +5689,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5851,10 +5787,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDApproxGeluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5950,10 +5882,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -5988,10 +5916,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6059,10 +5983,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6132,10 +6052,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6196,10 +6112,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6253,10 +6165,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6310,10 +6218,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) {
 // Do not fuse output scaling without type conversion when a matrix bias was
 // fused.
 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 GEMM rewrite requires CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6371,10 +6275,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6445,10 +6345,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6496,10 +6392,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6571,10 +6463,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6660,10 +6548,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6742,10 +6626,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6811,10 +6691,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDVectorBiasThenReluActivationF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6880,10 +6756,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -6966,10 +6838,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        Rank3ScaledABUnscaledDVectorBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7060,10 +6928,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7142,10 +7006,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        Rank3ScaledABUnscaledDMatrixBiasPaddedF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7235,10 +7095,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 // of dimensions.
 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDMatrixBiasWithSliceF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7307,10 +7163,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7376,10 +7228,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7442,10 +7290,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDWithCollectivePermuteF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7508,10 +7352,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDMatrixBiasThenVectorBiasF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7581,10 +7421,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7666,10 +7502,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABScaledDWithDAmaxF8WithF16Intermediates) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7756,10 +7588,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABScaledDReluActivationWithDAmaxF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7842,10 +7670,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif  // CUDA_VERSION < 12000
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7882,10 +7706,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDPrecisionF8) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -7958,10 +7778,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {

 TEST_P(ParameterizedFp8GemmRewriteTest,
        ScaledABUnscaledDF8ParameterizedBatched) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -8033,10 +7849,6 @@ ENTRY f {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000

@@ -8071,10 +7883,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
 }

 TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) {
-#if GOOGLE_CUDA && CUDA_VERSION < 12000
-  GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
-#endif
-
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 60000
   GTEST_SKIP() << "F8 gemm rewrite is only supported in ROCm 6.0 and above.";
 #endif  // TF_ROCM_VERSION < 60000
@@ -61,10 +61,6 @@ namespace {
 class MultiHeadedAttentionTest : public GpuCodegenTest {
  public:
   MultiHeadedAttentionTest() {
-#if !defined(GOOGLE_CUDA) || CUDA_VERSION < 12000
-    skip_reason_ = "cuDNN Fused MHA requires CUDA 12 or later.";
-    return;
-#endif
     stream_executor::CudaComputeCapability cc = GetCudaComputeCapability();
     // Enforce capability minor == 0 because hardware with a non-zero minor
     // number typically has insufficient shared memory for cuDNN FMHA.
@@ -207,25 +207,16 @@ bool CUDABlas::Init() {
     return false;
   }

-#if CUDA_VERSION >= 11000
   if (!blas_lt_.Init().ok()) {
     LOG(ERROR) << kCublasNotInitializedExplanation;
     return false;
   }
-#endif  // CUDA_VERSION >= 11000

   return true;
 }

 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
-    : parent_(CHECK_NOTNULL(parent)),
-      blas_(nullptr)
-#if CUDA_VERSION >= 11000
-      ,
-      blas_lt_(parent)
-#endif
-{
-}
+    : parent_(CHECK_NOTNULL(parent)), blas_(nullptr), blas_lt_(parent) {}

 CUDABlas::~CUDABlas() {
   if (blas_ != nullptr) {

@@ -306,12 +297,10 @@ struct CUDADataType<Eigen::half> {
   static constexpr cudaDataType_t type = CUDA_R_16F;  // NOLINT
 };

-#if CUDA_VERSION >= 11000
 template <>
 struct CUDADataType<Eigen::bfloat16> {
   static constexpr cudaDataType_t type = CUDA_R_16BF;  // NOLINT
 };
-#endif  // CUDA_VERSION >= 11000

 template <>
 struct CUDADataType<std::complex<Eigen::half>> {

@@ -551,18 +540,12 @@ absl::Status CUDABlas::DoBlasGemm(
     const NumericOptions &numeric_options, blas::CallContext context) {
   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;

-#if CUDA_VERSION < 11000
-  if (dtype == blas::DataType::kHalf) {
-    math_type = CUBLAS_TENSOR_OP_MATH;
-  }
-#else
   if (dtype == blas::DataType::kFloat) {
     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
     if (!numeric_options.allow_tf32) {
       math_type = CUBLAS_DEFAULT_MATH;
     }
   }
-#endif

   // TODO(cheshire): Return an error instead.
   // TODO(cheshire): Why are these checked only for `half` and `float`?

@@ -607,7 +590,6 @@ absl::Status CUDABlas::DoBlasGemm(
           b.opaque(), CUDA_R_16F, ldb, static_cast<const float *>(beta),
           c->opaque(), CUDA_R_16F, ldc);
     }
-#if CUDA_VERSION > 11000
     case blas::DataType::kBF16: {
       return DoBlasInternalImpl(
           cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,

@@ -616,7 +598,6 @@ absl::Status CUDABlas::DoBlasGemm(
           b.opaque(), CUDA_R_16BF, ldb, static_cast<const float *>(beta),
           c->opaque(), CUDA_R_16BF, ldc);
     }
-#endif
     case dnn::kFloat:
       return DoBlasInternalImpl(
           cublasSgemm, stream, true /* = pointer_mode_host */, math_type,

@@ -693,11 +674,6 @@ static absl::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
         " uses tensor ops, but tensor ops are not available in sm", cc.major,
         "X devices."));
   } else if (type_a == blas::DataType::kFloat) {
-#if CUDA_VERSION < 11000
-    return absl::InternalError(
-        "Algorithm ", algorithm,
-        " uses tensor ops, but tensor ops are not available for fp32");
-#else
     if (cc.major < 8) {
       return absl::InternalError(absl::StrCat(
           "Algorithm ", algorithm,

@@ -705,11 +681,8 @@ static absl::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
           cc.major, "X devices for float input types."));
     }
     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
-#endif
   } else if (type_a == blas::DataType::kHalf) {
-#if CUDA_VERSION < 11000
-    math_type = CUBLAS_TENSOR_OP_MATH;
-#endif
+    math_type = CUBLAS_DEFAULT_MATH;
   } else {
     return absl::InternalError(
         absl::StrCat("Algorithm ", algorithm,

@@ -791,7 +764,6 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
                         output_profile_result != nullptr));
   cudaDataType_t cuda_in_type = AsCudaDataType(type_a);

-#if CUDA_VERSION >= 11000
   // Workaround CUDA bug where batched GEMM is erroneously marked as
   // unsupported by manually unbatching it on Pascal.
   if (cuda_in_type == CUDA_R_16BF &&

@@ -833,7 +805,6 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
         PopulateProfileFromTimer(timer, algorithm, output_profile_result));
     return absl::OkStatus();
   }
-#endif

   TF_RETURN_IF_ERROR(DoBlasInternalImpl(
       AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,

@@ -982,18 +953,10 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal(
   cublasMath_t math_type;
   cublasGemmAlgo_t algo;

-#if CUDA_VERSION >= 11000
   bool is_16bit = data_type == CUDA_R_16F || data_type == CUDA_R_16BF;
-#else
-  bool is_16bit = data_type == CUDA_R_16F;
-#endif  // CUDA_VERSION >= 11000

   if (is_16bit) {
-#if CUDA_VERSION < 11000
-    math_type = CUBLAS_TENSOR_OP_MATH;
-#else
     math_type = CUBLAS_DEFAULT_MATH;
-#endif
     algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
 #if CUBLAS_VER_MAJOR >= 11
   } else if (data_type == CUDA_R_32F) {

@@ -1171,18 +1134,11 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched(
     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
     const NumericOptions &numeric_options, blas::CallContext context) {
   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
-#if CUDA_VERSION < 11000
-  if (dtype == dnn::kHalf) {
-    math_type = CUBLAS_TENSOR_OP_MATH;
-  }
-#else
   if (dtype == dnn::kFloat && numeric_options.allow_tf32) {
     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
   }
-#endif

   switch (dtype) {
-#if CUDA_VERSION >= 11000
     case dnn::kBF16: {
       CudaComputeCapability cc = stream->GetCudaComputeCapability();
       if (cc.IsAtLeast(7)) {

@@ -1217,7 +1173,6 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched(
       }
       return absl::OkStatus();
     }
-#endif
     case dnn::kHalf: {
       CudaComputeCapability cc = stream->GetCudaComputeCapability();
       if (cc.major >= 5) {
@@ -124,7 +124,6 @@ absl::StatusOr<cublasLtEpilogue_t> AsCublasLtEpilogue(
       return CUBLASLT_EPILOGUE_BIAS;
     case gpu::BlasLt::Epilogue::kBiasThenReLU:
       return CUBLASLT_EPILOGUE_RELU_BIAS;
-#if CUDA_VERSION >= 11040
     case gpu::BlasLt::Epilogue::kGELU:
       return CUBLASLT_EPILOGUE_GELU;
     case gpu::BlasLt::Epilogue::kGELUWithAux:

@@ -133,13 +132,6 @@ absl::StatusOr<cublasLtEpilogue_t> AsCublasLtEpilogue(
       return CUBLASLT_EPILOGUE_GELU_BIAS;
     case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux:
       return CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
-#else
-    case gpu::BlasLt::Epilogue::kGELU:
-    case gpu::BlasLt::Epilogue::kGELUWithAux:
-    case gpu::BlasLt::Epilogue::kBiasThenGELU:
-    case gpu::BlasLt::Epilogue::kBiasThenGELUWithAux:
-      return absl::InternalError("GELU epilogues require cublasLt >= 11.4");
-#endif
   }
 }

@@ -437,7 +429,6 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
     TF_RETURN_IF_ERROR(SetAttr(
         op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias.opaque()));
   }
-#if CUDA_VERSION >= 11080
   if (a_scale != nullptr) {
     TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                                CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,

@@ -463,16 +454,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
                                CUBLASLT_MATMUL_DESC_AMAX_D_POINTER,
                                d_amax.opaque()));
   }
-#else
-  if (a_scale != nullptr || b_scale != nullptr || c_scale != nullptr ||
-      d_scale != nullptr || d_amax != nullptr) {
-    return absl::InternalError(
-        "A/B/C/D scales and amax require cublasLt >= 11.8");
-  }
-#endif

   if (aux != nullptr) {
-#if CUDA_VERSION >= 11040
     TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                aux.opaque()));

@@ -495,10 +478,6 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
     TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                                CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE,
                                output_batch_stride));
-#else
-    return absl::InternalError(
-        "Auxiliary inputs / outputs require cublasLt >= 11.4");
-#endif
   }

   gpu::ScopedActivateExecutorContext sac{blas_lt_ref_.parent_};

@@ -529,7 +508,6 @@ namespace {
 template <cudaDataType_t CudaT>
 struct CudaToNativeT;

-#if CUDA_VERSION >= 11080
 template <>
 struct CudaToNativeT<CUDA_R_8F_E4M3> {
   using type = tsl::float8_e4m3fn;

@@ -538,7 +516,6 @@ template <>
 struct CudaToNativeT<CUDA_R_8F_E5M2> {
   using type = tsl::float8_e5m2;
 };
-#endif

 template <>
 struct CudaToNativeT<CUDA_R_16BF> {

@@ -592,7 +569,6 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(
                     profile_result);  \
   }

-#if CUDA_VERSION >= 11080
   // FP8 compatible type combinations (see cuBLASLt documentation):
   TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF, CUDA_R_16BF)
   TYPED_MATMUL(float, CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, CUDA_R_16BF,

@@ -625,7 +601,6 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(
                CUDA_R_8F_E5M2)
   TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_16F, CUDA_R_16F)
   TYPED_MATMUL(float, CUDA_R_8F_E5M2, CUDA_R_8F_E4M3, CUDA_R_32F, CUDA_R_32F)
-#endif

   // Other data types:
   TYPED_MATMUL(float, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF)

@@ -27,11 +27,7 @@ namespace stream_executor {
namespace cuda {

const char* ToString(cublasStatus_t status) {
#if CUDA_VERSION >= 11050  // `GetStatusString` was added in 11.4 update 2.
return cublasGetStatusString(status);
#else
return "cublas error";
#endif  // CUDA_VERSION >= 11050
}

absl::Status ToStatus(cublasStatus_t status, const char* prefix) {
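
With CUDA 12.3 as the minimum, cublasGetStatusString (added in CUDA 11.4 update 2) is always available, so the string-conversion helper loses its fallback branch. A sketch of the post-cleanup shape:

// --- sketch (not part of the diff) ---
#include <cublas_v2.h>

const char* ToString(cublasStatus_t status) {
  return cublasGetStatusString(status);  // no "cublas error" fallback needed
}
// --- end sketch ---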

@@ -43,12 +39,10 @@ absl::Status ToStatus(cublasStatus_t status, const char* prefix) {

cudaDataType_t AsCudaDataType(blas::DataType type) {
switch (type) {
#if CUDA_VERSION >= 11080
case blas::DataType::kF8E5M2:
return CUDA_R_8F_E5M2;
case blas::DataType::kF8E4M3FN:
return CUDA_R_8F_E4M3;
#endif
case blas::DataType::kHalf:
return CUDA_R_16F;
case blas::DataType::kBF16:
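
The FP8 cases in AsCudaDataType lose their CUDA 11.8 guard for the same reason. A sketch of the unguarded mapping; the local DataType enum is a stand-in for blas::DataType, while the CUDA_R_* values are real toolkit enumerators:

// --- sketch (not part of the diff) ---
#include <library_types.h>  // cudaDataType_t and the CUDA_R_* enumerators

enum class DataType { kF8E5M2, kF8E4M3FN, kHalf, kBF16 };  // stand-in enum

cudaDataType_t AsCudaDataType(DataType type) {
  switch (type) {
    case DataType::kF8E5M2:   return CUDA_R_8F_E5M2;
    case DataType::kF8E4M3FN: return CUDA_R_8F_E4M3;
    case DataType::kHalf:     return CUDA_R_16F;
    case DataType::kBF16:     return CUDA_R_16BF;
  }
  return CUDA_R_32F;  // unreachable
}
// --- end sketch ---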

@@ -533,7 +533,6 @@ static std::string_view StreamCaptureModeToString(
break;
}

#if CUDA_VERSION >= 12030
VLOG(2) << "Beginning stream " << stream << " capture in "
<< StreamCaptureModeToString(mode) << " mode to graph " << graph;
RETURN_IF_CUDA_RES_ERROR(

@@ -543,10 +542,6 @@ static std::string_view StreamCaptureModeToString(
/*numDependencies=*/0, cu_mode),
"Failed to begin stream capture to graph");
return absl::OkStatus();
#else
return absl::UnimplementedError(
"StreamBeginCaptureToGraph is not implemented");
#endif  // CUDA_VERSION >= 12030
}

/* static */ absl::Status GpuDriver::StreamEndCapture(CUstream stream,
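
cuStreamBeginCaptureToGraph is a CUDA 12.3 driver API, which is why the UnimplementedError branch above becomes dead code. A bare-bones sketch of the call the macro wraps (argument layout assumed from the CUDA 12.3 driver API; the capture mode here is an arbitrary choice):

// --- sketch (not part of the diff) ---
#include <cuda.h>

// Begin capturing work submitted to `stream` directly into `graph`.
CUresult BeginCaptureToGraph(CUstream stream, CUgraph graph) {
  return cuStreamBeginCaptureToGraph(stream, graph,
                                     /*dependencies=*/nullptr,
                                     /*dependencyData=*/nullptr,
                                     /*numDependencies=*/0,
                                     CU_STREAM_CAPTURE_MODE_THREAD_LOCAL);
}
// --- end sketch ---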

@@ -567,7 +562,6 @@ static std::string_view StreamCaptureModeToString(
<< "use_node_priority=" << flags.use_node_prirotiy << ", "
<< "upload=" << flags.upload << ")";

#if CUDA_VERSION >= 12000
uint64_t cu_flags = 0;
if (flags.auto_free_on_launch)
cu_flags |= CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;

@@ -579,10 +573,6 @@ static std::string_view StreamCaptureModeToString(

RETURN_IF_CUDA_RES_ERROR(cuGraphInstantiate(exec, graph, cu_flags),
"Failed to instantiate CUDA graph");
#else
RETURN_IF_CUDA_RES_ERROR(cuGraphInstantiate(exec, graph, nullptr, nullptr, 0),
"Failed to instantiate CUDA graph");
#endif  // CUDA_VERSION >= 12000

return absl::OkStatus();
}
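
Since CUDA 12.0 the driver exposes the flags-based cuGraphInstantiate, so the legacy five-argument call in the #else branch goes away. A sketch of the surviving path under that assumption:

// --- sketch (not part of the diff) ---
#include <cuda.h>

CUresult InstantiateGraph(CUgraphExec* exec, CUgraph graph, bool auto_free_on_launch) {
  unsigned long long flags = 0;
  if (auto_free_on_launch) {
    flags |= CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
  }
  return cuGraphInstantiate(exec, graph, flags);
}
// --- end sketch ---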

@@ -612,7 +602,6 @@ static std::string_view StreamCaptureModeToString(
CUgraphExec exec, CUgraph graph, GraphExecUpdateResultInfo* result) {
VLOG(2) << "Update CUDA graph executable " << exec << " with graph " << graph;

#if CUDA_VERSION >= 12000
CUgraphExecUpdateResultInfo cu_result;
memset(&cu_result, 0, sizeof(cu_result));
CUresult err_code = cuGraphExecUpdate(exec, graph, &cu_result);

@@ -623,11 +612,6 @@ static std::string_view StreamCaptureModeToString(
if (cu_result.errorNode) {
result->error_node = cu_result.errorNode;
}
#else
CUgraphExecUpdateResult cu_result;
CUresult err_code = cuGraphExecUpdate(exec, graph, nullptr, &cu_result);
auto cu_result_enum = cu_result;
#endif  // CUDA_VERSION >= 12000

switch (cu_result_enum) {
case CU_GRAPH_EXEC_UPDATE_SUCCESS:

@@ -651,14 +635,12 @@ static std::string_view StreamCaptureModeToString(
case CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED:
result->result = GraphExecUpdateResult::kNotSupported;
break;
#if CUDA_VERSION >= 12000
case CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE:
result->result = GraphExecUpdateResult::kUnsupportedFunctionChange;
break;
case CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED:
result->result = GraphExecUpdateResult::kAttributesChanged;
break;
#endif  // CUDA_VERSION >= 12000
default:
return absl::InternalError("Unknown graph update result");
}
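
Before CUDA 12.0, cuGraphExecUpdate reported the result code and failing node through two separate out-parameters; CUDA 12 folds both into CUgraphExecUpdateResultInfo, which is the only variant left after this change. A small sketch of the 12.x calling convention:

// --- sketch (not part of the diff) ---
#include <cstring>
#include <cuda.h>

CUresult UpdateGraphExec(CUgraphExec exec, CUgraph graph,
                         CUgraphExecUpdateResult* result_out,
                         CUgraphNode* error_node_out) {
  CUgraphExecUpdateResultInfo info;
  std::memset(&info, 0, sizeof(info));
  CUresult err = cuGraphExecUpdate(exec, graph, &info);
  *result_out = info.result;         // e.g. CU_GRAPH_EXEC_UPDATE_SUCCESS
  *error_node_out = info.errorNode;  // null unless the update failed on a node
  return err;
}
// --- end sketch ---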

@@ -687,7 +669,6 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) {
return GraphNodeType::kGraph;
case CU_GRAPH_NODE_TYPE_EMPTY:
return GraphNodeType::kEmpty;
#if CUDA_VERSION >= 12000
case CU_GRAPH_NODE_TYPE_WAIT_EVENT:
return GraphNodeType::kWaitEvent;
case CU_GRAPH_NODE_TYPE_EVENT_RECORD:

@@ -702,7 +683,6 @@ GpuDriver::GraphNodeGetType(CUgraphNode node) {
return GraphNodeType::kMemFree;
case CU_GRAPH_NODE_TYPE_BATCH_MEM_OP:
return GraphNodeType::kBatchMemOp;
#endif  // CUDA_VERSION >= 12000
default:
return absl::InternalError("Unknown graph node type");
}

@@ -738,7 +718,6 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) {

/* static */ absl::StatusOr<std::string> GpuDriver::GraphDebugDotPrint(
CUgraph graph, const char* path, bool return_printed_graph) {
#if CUDA_VERSION >= 12000
VLOG(2) << "Print CUDA graph " << graph << " debug dot file to " << path;

int flags = CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE;

@@ -753,7 +732,6 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) {
LOG(WARNING) << "failed to read gpu graph debug file " << path;
}
}
#endif  // CUDA_VERSION >= 12000

return std::string(path);
}

@@ -784,15 +762,10 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) {
<< "; default_launch_value: " << default_launch_value
<< "; flags: " << flags;

#if CUDA_VERSION >= 12030
RETURN_IF_CUDA_RES_ERROR(
cuGraphConditionalHandleCreate(handle, graph, context->context(),
default_launch_value, flags),
"Failed to create conditional handle for a CUDA graph");
#else
return absl::UnimplementedError(
"CUDA graph conditional nodes are not implemented");
#endif  // CUDA_VERSION >= 12030
return absl::OkStatus();
}
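
Conditional graph nodes arrived in CUDA 12.3, which is why the UnimplementedError branch can go. A sketch of the handle-creation call, with the argument order taken from the hunk above:

// --- sketch (not part of the diff) ---
#include <cuda.h>

CUresult CreateConditionalHandle(CUgraphConditionalHandle* handle, CUgraph graph,
                                 CUcontext context, unsigned default_launch_value,
                                 unsigned flags) {
  return cuGraphConditionalHandleCreate(handle, graph, context,
                                        default_launch_value, flags);
}
// --- end sketch ---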

@@ -810,7 +783,6 @@ static std::string ConditionalTypeToString(
GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph,
absl::Span<const CUgraphNode> deps,
const GpuGraphNodeParams& params) {
#if CUDA_VERSION >= 12030
// Add conditional node to a graph.
if (auto* conditional = std::get_if<GpuGraphConditionalNodeParams>(&params)) {
VLOG(2) << "Add conditional node to a graph " << graph

@@ -844,7 +816,6 @@ GpuDriver::GraphAddNode(CUgraphNode* node, CUgraph graph,
VLOG(2) << "Created conditional CUDA graph " << result.graph;
return result;
}
#endif  // CUDA_VERSION >= 12030

return absl::UnimplementedError("unsupported node type");
}

@@ -967,19 +938,12 @@ static CUmemLocationType ToCudaLocationType(
return CU_MEM_LOCATION_TYPE_INVALID;
case GpuDriver::MemLocationType::kDevice:
return CU_MEM_LOCATION_TYPE_DEVICE;
#if CUDA_VERSION >= 12030
case GpuDriver::MemLocationType::kHost:
return CU_MEM_LOCATION_TYPE_HOST;
case GpuDriver::MemLocationType::kHostNuma:
return CU_MEM_LOCATION_TYPE_HOST_NUMA;
case GpuDriver::MemLocationType::kHostNumaCurrent:
return CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT;
#else
case GpuDriver::MemLocationType::kHost:
case GpuDriver::MemLocationType::kHostNuma:
case GpuDriver::MemLocationType::kHostNumaCurrent:
return CU_MEM_LOCATION_TYPE_INVALID;
#endif  // CUDA_VERSION >= 12030
}
}

@@ -1014,9 +978,7 @@ static CUmemAllocationType ToCudaAllocationType(
mem_pool_props.allocType = ToCudaAllocationType(allocation_type);
mem_pool_props.handleTypes = CU_MEM_HANDLE_TYPE_NONE;
mem_pool_props.location = mem_location;
#if CUDA_VERSION >= 12030
mem_pool_props.maxSize = max_pool_size;
#endif  // CUDA_VERSION >= 12030
// cuda graph requires reserved space initialized to 0
memset(mem_pool_props.reserved, 0, sizeof(mem_pool_props.reserved));

@@ -372,7 +372,6 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
return false;
}

#if CUDA_VERSION >= 10010
// Workaround a cuFFT bug, which mutates the input buffer when it shouldn't.
// See b/155276727 and go/nvbugs/2959622.
// TODO(b/155276727): refine the bounding condition.

@@ -395,7 +394,6 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
// execution just because the allocation for the incorrect case fails.
}
}
#endif

cuda::ScopedActivateExecutorContext sac(parent_);
auto ret =

@@ -57,10 +57,8 @@ static std::string_view ToString(nvPTXCompileResult status) {
return "COMPILER_INVOCATION_INCOMPLETE";
case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
return "UNSUPPORTED_PTX_VERSION";
#if CUDA_VERSION > 12000
case NVPTXCOMPILE_ERROR_UNSUPPORTED_DEVSIDE_SYNC:
return "UNSUPPORTED_DEVSIDE_SYNC";
#endif
default:
return "UNKNOWN";
}

@@ -100,10 +100,7 @@ static bool IsAtLeastCuda12300() {
#if defined(TENSORFLOW_USE_ROCM)
return false;
#endif
#if CUDA_VERSION >= 12030
return true;
#endif
return false;
}

TEST(GpuCommandBufferTest, LaunchSingleKernel) {
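
With the 12.3 floor, the CUDA_VERSION check in this test helper is always true on CUDA builds; a sketch of the simplified form under that assumption:

// --- sketch (not part of the diff) ---
static bool IsAtLeastCuda12300() {
#if defined(TENSORFLOW_USE_ROCM)
  return false;  // ROCm builds have no CUDA toolkit version to check
#else
  return true;   // XLA now requires CUDA >= 12.3
#endif
}
// --- end sketch ---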

@@ -162,9 +159,6 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) {
TEST(CudaCommandBufferTest, TraceSingleKernel) {
#if defined(TENSORFLOW_USE_ROCM)
GTEST_SKIP() << "Not supported on ROCM";
#endif
#if CUDA_VERSION < 12030
GTEST_SKIP() << "Command buffer tracing is not supported";
#endif
Platform* platform = GpuPlatform();
StreamExecutor* executor = platform->ExecutorForDevice(0).value();

@@ -78,12 +78,7 @@ using GpuDoubleComplexType = cuDoubleComplex;
using GpuGraphHandle = CUgraph;
using GpuGraphExecHandle = CUgraphExec;
using GpuGraphNodeHandle = CUgraphNode;

#if CUDA_VERSION >= 12030
using GpuGraphConditionalHandle = CUgraphConditionalHandle;
#else
using GpuGraphConditionalHandle = UnsupportedGpuFeature;
#endif  // #if CUDA_VERSION >= 12030

#endif

third_party/xla/xla/tsl/cuda/cupti_stub.cc (vendored, 2 changes)

@@ -24,7 +24,7 @@ limitations under the License.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
-#if defined(PLATFORM_GOOGLE) && (CUDA_VERSION > 10000)
+#if defined(PLATFORM_GOOGLE)
return nullptr;
#else
static auto handle = []() -> void* {