mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[cuDNN][conv] Re-enable cuDNN for 3D convolutions (fixed in 9.15+) (#166480)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166480 Approved by: https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
parent
80ba6e458f
commit
df71b70727
|
|
@ -410,8 +410,8 @@ struct ConvParams {
|
|||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
// broken on cuDNN 9.8
|
||||
if (cudnn_version >= 90800) {
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
|
||||
weight.dim() == 5) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user