mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Enable BF16 NCHW Mixed batchnorm on MIOpen if ROCm >= 6.4 (#154611)
This PR enables MIOpen for BF16 NCHW mixed batchnorm if the MIOpen version is >= 3.4 (ROCm >= 6.4). CUDAHooks::versionMIOpen() was added to detect the MIOpen version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154611 Approved by: https://github.com/jeffdaily, https://github.com/jithunnair-amd
This commit is contained in:
parent
085f270a00
commit
f402eed4d9
|
|
@ -331,6 +331,16 @@ long CUDAHooks::versionCuDNN() const {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Report the MIOpen version packed into a single comparable integer:
// x.y.z is encoded as x*10000 + y*100 + z (e.g. MIOpen 3.4.0 -> 30400).
// Errors out when ATen was not built with ROCm, since no MIOpen exists.
long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED()
  // The version macros come from MIOpen's own headers (miopen/version.h),
  // so this is the compile-time version of the library we linked against.
  constexpr long packed_version =
      MIOPEN_VERSION_MAJOR * 10000L +
      MIOPEN_VERSION_MINOR * 100L +
      MIOPEN_VERSION_PATCH;
  return packed_version;
#else
  TORCH_CHECK(false, "Cannot query MIOpen version if ATen_cuda is not built with ROCm");
#endif
}
|
||||||
|
|
||||||
long CUDAHooks::versionCUDART() const {
|
long CUDAHooks::versionCUDART() const {
|
||||||
#ifdef CUDART_VERSION
|
#ifdef CUDART_VERSION
|
||||||
return CUDART_VERSION;
|
return CUDART_VERSION;
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||||
bool hasCUDART() const override;
|
bool hasCUDART() const override;
|
||||||
long versionCUDART() const override;
|
long versionCUDART() const override;
|
||||||
long versionCuDNN() const override;
|
long versionCuDNN() const override;
|
||||||
|
long versionMIOpen() const override;
|
||||||
std::string showConfig() const override;
|
std::string showConfig() const override;
|
||||||
double batchnormMinEpsilonCuDNN() const override;
|
double batchnormMinEpsilonCuDNN() const override;
|
||||||
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
|
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default stub used when the ATen CUDA/ROCm backend library is not loaded;
// unconditionally raises. CUDAHooks overrides this with the real query.
virtual long versionMIOpen() const {
  TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
}
|
||||||
|
|
||||||
virtual long versionCUDART() const {
|
virtual long versionCUDART() const {
|
||||||
TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
|
TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -521,17 +521,17 @@ BatchNormBackend _select_batch_norm_backend(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
input.is_cuda()
|
detail::getCUDAHooks().compiledWithMIOpen()
|
||||||
|
&& cudnn_enabled
|
||||||
|
&& input.is_cuda()
|
||||||
&& input.dim() <= MIOPEN_DIM_MAX
|
&& input.dim() <= MIOPEN_DIM_MAX
|
||||||
|
&& input.dim() >= 3
|
||||||
&& input.scalar_type() != at::kDouble
|
&& input.scalar_type() != at::kDouble
|
||||||
&& input.scalar_type() != at::kBFloat16
|
&& (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
|
||||||
&& (weight.scalar_type() != at::kHalf)
|
&& weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input
|
||||||
&& weight.defined() && bias.defined()
|
&& weight.defined() && bias.defined()
|
||||||
&& ((running_mean.defined() && running_var.defined())
|
&& ((running_mean.defined() && running_var.defined())
|
||||||
|| (!running_mean.defined() && !running_var.defined() && training))
|
|| (!running_mean.defined() && !running_var.defined() && training))
|
||||||
&& (input.dim() >= 3)
|
|
||||||
&& detail::getCUDAHooks().compiledWithMIOpen()
|
|
||||||
&& cudnn_enabled
|
|
||||||
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
|
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
|
||||||
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
|
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
|
||||||
checkAllDefined(c, {running_mean, running_var});
|
checkAllDefined(c, {running_mean, running_var});
|
||||||
}
|
}
|
||||||
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
|
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
|
||||||
if (input->scalar_type() != ScalarType::Half) {
|
if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
|
||||||
|
checkScalarType(c, weight, ScalarType::Float);
|
||||||
|
} else {
|
||||||
checkAllSameType(c, {input, weight});
|
checkAllSameType(c, {input, weight});
|
||||||
}
|
}
|
||||||
checkAllSameType(c, {weight, bias, running_mean, running_var});
|
checkAllSameType(c, {weight, bias, running_mean, running_var});
|
||||||
|
|
@ -186,7 +188,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
|
||||||
|
|
||||||
checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
|
checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
|
||||||
checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
|
checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
|
||||||
if (input->scalar_type() == ScalarType::Half) {
|
if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
|
||||||
checkScalarType(c, weight, ScalarType::Float);
|
checkScalarType(c, weight, ScalarType::Float);
|
||||||
} else {
|
} else {
|
||||||
checkAllSameType(c, {input, weight});
|
checkAllSameType(c, {input, weight});
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user