diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index c4a425fe359..247fdb2537c 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -331,6 +331,16 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }
 
+long CUDAHooks::versionMIOpen() const {
+#if AT_ROCM_ENABLED()
+  return MIOPEN_VERSION_MAJOR * 10000 +
+         MIOPEN_VERSION_MINOR * 100 +
+         MIOPEN_VERSION_PATCH;
+#else
+  TORCH_CHECK(false, "Cannot query MIOpen version if ATen_cuda is not built with ROCm");
+#endif
+}
+
 long CUDAHooks::versionCUDART() const {
 #ifdef CUDART_VERSION
   return CUDART_VERSION;
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index 2b4c1113632..b0dac7a71e8 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -46,6 +46,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
+  long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
   int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index c356ff57aa5..f99e03d156c 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -162,6 +162,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
+  virtual long versionMIOpen() const {
+    TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
+  }
+
   virtual long versionCUDART() const {
     TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
   }
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index fb4ce917bf1..ecad7d7f341 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -521,17 +521,17 @@ BatchNormBackend _select_batch_norm_backend(
   }
 
   if (
-      input.is_cuda()
+      detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
+      && input.dim() >= 3
       && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
+      && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
+      && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
   ) {
diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
index 6375f49386b..af69dfc76e5 100644
--- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
+++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -79,7 +79,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->scalar_type() != ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
+    checkScalarType(c, weight, ScalarType::Float);
+  } else {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +188,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
-  if (input->scalar_type() == ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
     checkScalarType(c, weight, ScalarType::Float);
   } else {
     checkAllSameType(c, {input, weight});