[cuDNN][conv] Re-enable cuDNN for 3D convolutions (fixed in 9.15+) (#166480)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166480
Approved by: https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
Eddie Yan 2025-10-30 20:47:20 +00:00 committed by PyTorch MergeBot
parent 80ba6e458f
commit df71b70727

View File

@ -410,8 +410,8 @@ struct ConvParams {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8
if (cudnn_version >= 90800) {
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
weight.dim() == 5) {