diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 5bf93243650..1496b6ee551 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -78,6 +78,14 @@ void Context::alertNotDeterministic(c10::string_view const& caller) { } } +bool Context::allowTF32CuDNN() const { + return allow_tf32_cudnn; +} + +void Context::setAllowTF32CuDNN(bool b) { + allow_tf32_cudnn = b; +} + static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG"; static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" }; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 3a41138a49c..19960f3b6f1 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -115,6 +115,8 @@ class CAFFE2_API Context { bool deterministic() const; void setDeterministic(bool); void alertNotDeterministic(c10::string_view const& caller); + bool allowTF32CuDNN() const; + void setAllowTF32CuDNN(bool); bool allowTF32CuBLAS() const; void setAllowTF32CuBLAS(bool); void alertCuBLASConfigNotDeterministic(); @@ -146,6 +148,7 @@ class CAFFE2_API Context { bool deterministic_cudnn = false; bool _deterministic = false; bool benchmark_cudnn = false; + bool allow_tf32_cudnn = true; bool allow_tf32_cublas = true; bool enabled_mkldnn = true; #ifdef C10_MOBILE diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index c7180f9d558..efe3a994bb8 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -255,7 +255,8 @@ TORCH_LIBRARY_IMPL(_, Autocast, m) { } TORCH_LIBRARY_IMPL(aten, Autocast, m) { - KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16) + KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16) + KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16) KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), fp16) KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16) KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16) @@ -267,8 +268,10 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp16) KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) - 
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) + KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) + KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16) + KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16) + KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16) KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), fp16) KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16) KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16) diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index bd7f1115556..04e02749170 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -164,7 +164,7 @@ struct TORCH_CUDA_API ConvolutionDescriptor &cudnnCreateConvolutionDescriptor, &cudnnDestroyConvolutionDescriptor> { - void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) { + void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) { cudnnDataType_t mathType = dataType; if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT; AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, @@ -172,9 +172,13 @@ struct TORCH_CUDA_API ConvolutionDescriptor AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups)); // See Note [behavior of cudnnFind and cudnnGet] AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH)); - if(dataType == CUDNN_DATA_HALF) + if(dataType == CUDNN_DATA_HALF) { AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH)); - + } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) { +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 + AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH)); +#endif + } } }; @@ -235,7 +239,7 @@ struct TORCH_CUDA_API RNNDescriptor DropoutDescriptor dropout_desc_; void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional, - cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo) { + cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) { dropout_desc_ = std::move(dropout_desc); AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6( handle, @@ -252,7 +256,13 @@ struct TORCH_CUDA_API RNNDescriptor if (prop->major >= 7) { if (input_type == CUDNN_DATA_HALF) { cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH); - } else { + } +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 + else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) { + cudnnSetRNNMatrixMathType(mut_desc(), 
CUDNN_FMA_MATH); + } +#endif + else { // Technically, as the default it's not necessary to explicitly // set this. cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 212dd15eef3..e3881d21ea2 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -30,6 +30,7 @@ struct ConvParams { bool benchmark; bool deterministic; bool cudnn_enabled; + bool allow_tf32; bool is_strided() const; bool is_dilated() const; @@ -582,7 +583,7 @@ at::Tensor convolution( auto& ctx = at::globalContext(); return at::_convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, - ctx.benchmarkCuDNN(), ctx.deterministicCuDNN() || ctx.deterministic(), ctx.userEnabledCuDNN()); + ctx.benchmarkCuDNN(), ctx.deterministicCuDNN() || ctx.deterministic(), ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN()); } at::Tensor convolution_overrideable( @@ -596,7 +597,7 @@ at::Tensor _convolution( const Tensor& input_r, const Tensor& weight_r, const Tensor& bias_r, IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, bool transposed_, IntArrayRef output_padding_, int64_t groups_, - bool benchmark, bool deterministic, bool cudnn_enabled) { + bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { const bool input_is_mkldnn = input_r.is_mkldnn(); auto input = input_r; @@ -618,6 +619,7 @@ at::Tensor _convolution( params.benchmark = benchmark; params.deterministic = deterministic; params.cudnn_enabled = cudnn_enabled; + params.allow_tf32 = allow_tf32; check_shape_forward(input, weight_sizes, bias, params); @@ -664,7 +666,7 @@ at::Tensor _convolution( if (params.use_cudnn_depthwise(input, weight)) { output = at::cudnn_convolution( input.contiguous(cudnn_memory_format), weight, - padding, stride, dilation, params.groups, params.benchmark, params.deterministic); + padding, stride, dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32); if (bias.defined()) { output.add_(reshape_bias(input.dim(), bias)); } @@ -687,14 +689,14 @@ at::Tensor _convolution( if (params.transposed) { output = at::cudnn_convolution_transpose( input.contiguous(cudnn_memory_format), weight, - params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic); + params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32); if (bias.defined()) { output.add_(reshape_bias(input.dim(), bias)); } } else { output = at::cudnn_convolution( input.contiguous(cudnn_memory_format), weight, - params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic); + params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32); if (bias.defined()) { output.add_(reshape_bias(input.dim(), bias)); } @@ -793,6 +795,15 @@ at::Tensor _convolution( return output; } +at::Tensor _convolution( + const Tensor& input_r, const Tensor& weight_r, const Tensor& bias_r, + IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, IntArrayRef output_padding_, int64_t groups_, + bool benchmark, bool deterministic, bool cudnn_enabled) +{ + return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, 
at::globalContext().allowTF32CuDNN()); + } + // A generic function for convolution implementations which don't // natively implement groups (e.g., not CuDNN). at::Tensor _convolution_nogroup( @@ -886,7 +897,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const Tensor& gO_r, const Tensor& weight_r, const Tensor& input, IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, bool transposed_, IntArrayRef output_padding_, int64_t groups_, - bool benchmark, bool deterministic, bool cudnn_enabled, + bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, std::array<bool, 3> output_mask) { auto ggW = ggW_r; @@ -909,6 +920,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( params.benchmark = benchmark; params.deterministic = deterministic; params.cudnn_enabled = cudnn_enabled; + params.allow_tf32 = allow_tf32; // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb Tensor ggO; @@ -917,14 +929,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( if (weight.is_cuda()) { weight = weight.contiguous(); } - ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); + ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32); } if (ggW.defined()) { if (ggW.is_cuda()) { ggW = ggW.contiguous(); } - auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); + auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32); if (ggO.defined()) { ggO = ggO + ggW_term; } else { @@ -979,9 +991,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( // Compute conv if (params.transposed) { gw_conv_params.transposed = false; - gWt = at::_convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled); + gWt = at::_convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32); } else { - gWt = at::_convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled); + gWt = at::_convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32); } } else { std::vector<Tensor> gWt_list(groups); @@ -995,9 +1007,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( // Compute conv if (params.transposed) { gw_conv_params.transposed = 
false; - gWt_list[g] = at::_convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled); + gWt_list[g] = at::_convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32); } else { - gWt_list[g] = at::_convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled); + gWt_list[g] = at::_convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32); } } @@ -1033,7 +1045,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( if (gO.is_cuda()) { gO = gO.contiguous(); } - gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled); + gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32); // narrow gI to only relevant portion // we do it this way because negative output_padding is not supported @@ -1087,7 +1099,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( gOt = gOt.contiguous(); } - gIt = at::_convolution(ggWt, gOt, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled); + gIt = at::_convolution(ggWt, gOt, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32); } else { std::vector<Tensor> gIt_list(params.groups); for (int g = 0; g < groups; ++g) { @@ -1097,7 +1109,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( gOt_g = gOt_g.contiguous(); } - gIt_list[g] = at::_convolution(ggWt_g, gOt_g, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled); + gIt_list[g] = at::_convolution(ggWt_g, gOt_g, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32); } gIt = at::cat(gIt_list, 0); diff 
--git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp index add08142578..4ddd533ec8f 100644 --- a/aten/src/ATen/native/cudnn/Conv.cpp +++ b/aten/src/ATen/native/cudnn/Conv.cpp @@ -17,56 +17,56 @@ namespace at { namespace native { at::Tensor cudnn_convolution( const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) { + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_backward_input( IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_backward_weight( IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); } std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward( const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array<bool,2> output_mask) { + bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) { AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_transpose( const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) { + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_transpose_backward_input( const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) { + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_transpose_backward_weight( IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); } std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backward( const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array<bool,2> output_mask) { + bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) { AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); } @@ -215,6 +215,7 @@ struct ConvolutionParams int dilation[max_dim]; int64_t 
groups; bool deterministic; + bool allow_tf32; // NB: transposed purposely omitted: transposed just swaps // forward and backward, so you can reuse the benchmark entry, }; @@ -228,7 +229,7 @@ void setConvolutionParams( ConvolutionParams* params, const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic) { + int64_t groups, bool deterministic, bool allow_tf32) { cudnnDataType_t dataType = getCudnnDataType(input); memset(params, 0, sizeof(ConvolutionParams)); @@ -250,6 +251,7 @@ void setConvolutionParams( // CuDNN, but it doesn't seem worth the effort to actually do this. params->groups = groups; params->deterministic = deterministic; + params->allow_tf32 = allow_tf32; } // Convenience struct for passing around descriptors and data @@ -658,6 +660,11 @@ public: perfResults[0].mathType = CUDNN_TENSOR_OP_MATH; } else { perfResults[0].mathType = CUDNN_DEFAULT_MATH; +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 + if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) { + perfResults[0].mathType = CUDNN_FMA_MATH; + } +#endif } search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory)); return perfResults; @@ -744,7 +751,7 @@ static inline void split_batch_dim_to_32bit_out( const at::Tensor& input, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, + bool benchmark, bool deterministic, bool allow_tf32, int64_t max_worksize, func_t func_32bit) { constexpr int64_t int_max = std::numeric_limits::max(); const int64_t ni = input.numel(); @@ -752,7 +759,7 @@ static inline void split_batch_dim_to_32bit_out( // Assume the shape of the tensor is (N, C, D1, D2, ...) // if N * C * D1 * D2 * ... <= int_max, then no need to split at all if (ni <= int_max && no <= int_max) { - func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic); + func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return; } // else, if C * D1 * D2 * ... 
<= int_max, then we just need to split across the N dimension @@ -770,7 +777,7 @@ static inline void split_batch_dim_to_32bit_out( int64_t split_size_ = std::min(split_size, n - start); Tensor input_ = input.narrow(0, start, split_size_); Tensor output_ = output.narrow(0, start, split_size_); - func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic); + func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } return; } @@ -789,6 +796,16 @@ static inline void split_batch_dim_to_32bit_out( } +#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 +#define ASSERT_CORRECT_PRECISION(math_type) \ +if (args.params.dataType == CUDNN_DATA_FLOAT) { \ + TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \ +} +#else +#define ASSERT_CORRECT_PRECISION(math_type) +#endif // CUDNN_VERSION >= 8000 + + // --------------------------------------------------------------------- // // Convolution forward / Transposed convolution backward @@ -808,17 +825,17 @@ static inline void split_batch_dim_to_32bit_out( void raw_cudnn_convolution_forward_out_32bit( const Tensor& output, const Tensor& input, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { auto dataType = getCudnnDataType(input); ConvolutionArgs args{ input, output, weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32); args.idesc.set(input); args.wdesc.set(weight, 0, input.suggest_memory_format()==at::MemoryFormat::ChannelsLast); args.odesc.set(output); - args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups); + args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32); // TODO: when we do legacy group convolution support, we'll repeatedly // reinitialize the workspace for each convolution we do. 
This is @@ -832,6 +849,7 @@ void raw_cudnn_convolution_forward_out_32bit( // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] + ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType); AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType)); Constant one(dataType, 1); @@ -850,15 +868,15 @@ void raw_cudnn_convolution_forward_out_32bit( void raw_cudnn_convolution_forward_out( const Tensor& output, const Tensor& input, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit); + bool benchmark, bool deterministic, bool allow_tf32) { + split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit); } Tensor cudnn_convolution_forward( CheckedFrom c, const TensorArg& input, const TensorArg& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { checkAllSameType(c, {input, weight}); checkAllSameGPU(c, {input, weight}); @@ -888,7 +906,7 @@ Tensor cudnn_convolution_forward( raw_cudnn_convolution_forward_out( *output, input_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return *output; } @@ -896,13 +914,13 @@ Tensor cudnn_convolution_forward( Tensor cudnn_convolution( const Tensor& input_t, const Tensor& weight_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }; CheckedFrom c = "cudnn_convolution"; auto output_t = cudnn_convolution_forward( - c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); + c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return output_t; } @@ -911,28 +929,28 @@ Tensor cudnn_convolution( Tensor cudnn_convolution_transpose_backward_input( const Tensor& grad_output_t, const Tensor& weight_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { TensorArg grad_output { grad_output_t, "grad_output", 1 }, weight { weight_t, "weight", 2 }; return cudnn_convolution_forward( "cudnn_convolution_transpose_backward_input", - grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); + grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } std::tuple cudnn_convolution_transpose_backward( const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array output_mask) { + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { Tensor grad_output = 
grad_output_t.contiguous(input.suggest_memory_format()); Tensor grad_input, grad_weight; if (output_mask[0]) { - grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); + grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } if (output_mask[1]) { - grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); + grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } return std::tuple{grad_input, grad_weight}; @@ -949,16 +967,16 @@ void raw_cudnn_convolution_backward_input_out_32bit( const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { auto dataType = getCudnnDataType(grad_output); ConvolutionArgs args{ grad_input, grad_output, weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic, allow_tf32); args.idesc.set(grad_input); args.wdesc.set(weight, 0, grad_output.suggest_memory_format()==at::MemoryFormat::ChannelsLast); args.odesc.set(grad_output); - args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups); + args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32); AlgoIterator(args, benchmark).try_all( [&](const cudnnConvolutionBwdDataAlgoPerf_t &bwdDataAlgPerf){ @@ -967,6 +985,7 @@ void raw_cudnn_convolution_backward_input_out_32bit( // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] + ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType); AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType)); Constant one(dataType, 1); @@ -987,8 +1006,8 @@ void raw_cudnn_convolution_backward_input_out( const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit); + bool benchmark, bool deterministic, bool allow_tf32) { + split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit); } // NOTE [ Backward vs transpose convolutions ] @@ -1007,7 +1026,7 @@ Tensor cudnn_convolution_backward_input( CheckedFrom c, IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { 
checkAllSameType(c, {grad_output, weight}); checkAllSameGPU(c, {grad_output, weight}); @@ -1030,7 +1049,7 @@ Tensor cudnn_convolution_backward_input( raw_cudnn_convolution_backward_input_out( *grad_input, grad_output_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return *grad_input; } @@ -1039,31 +1058,31 @@ Tensor cudnn_convolution_transpose_forward( CheckedFrom c, const TensorArg& grad_output, const TensorArg& weight, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), padding, output_padding, stride, dilation, groups); return cudnn_convolution_backward_input(c, input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } Tensor cudnn_convolution_backward_input( IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { TensorArg grad_output{ grad_output_t, "grad_output", 1 }, weight{ weight_t, "weight", 2 }; return cudnn_convolution_backward_input( "cudnn_convolution_backward_input", input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } std::tuple cudnn_convolution_backward( const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array output_mask) { + bool benchmark, bool deterministic, bool allow_tf32, std::array output_mask) { Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); @@ -1077,10 +1096,10 @@ std::tuple cudnn_convolution_backward( } } else { if (output_mask[0]) { - grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); + grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } if (output_mask[1]) { - grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); + grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } } @@ -1090,13 +1109,13 @@ std::tuple cudnn_convolution_backward( Tensor cudnn_convolution_transpose( const Tensor& input_t, const Tensor& weight_t, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) + int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }; CheckedFrom c = "cudnn_convolution_transpose"; auto output_t = cudnn_convolution_transpose_forward( - c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + c, input, weight, 
padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return output_t; } @@ -1109,17 +1128,17 @@ Tensor cudnn_convolution_transpose( void raw_cudnn_convolution_backward_weight_out_32bit( const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { auto dataType = getCudnnDataType(input); ConvolutionArgs args{ input, grad_output, grad_weight }; args.handle = getCudnnHandle(); - setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic); + setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic, allow_tf32); args.idesc.set(input); args.wdesc.set(grad_weight, 0, input.suggest_memory_format()==at::MemoryFormat::ChannelsLast); args.odesc.set(grad_output); - args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups); + args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32); AlgoIterator(args, benchmark).try_all( [&](const cudnnConvolutionBwdFilterAlgoPerf_t &bwdFilterAlgPerf){ @@ -1128,6 +1147,7 @@ void raw_cudnn_convolution_backward_weight_out_32bit( // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out // whether to use Tensor core kernels or not // See Note [behavior of cudnnFind and cudnnGet] + ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType); AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType)); Constant one(dataType, 1); @@ -1146,14 +1166,14 @@ void raw_cudnn_convolution_backward_weight_out_32bit( void raw_cudnn_convolution_backward_weight_out( const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { + bool benchmark, bool deterministic, bool allow_tf32) { constexpr int64_t int_max = std::numeric_limits::max(); const int64_t ni = input.numel(); const int64_t no = grad_output.numel(); // Assume the shape of the tensor is (N, C, D1, D2, ...) // if N * C * D1 * D2 * ... <= int_max, then no need to split at all if (ni <= int_max && no <= int_max) { - raw_cudnn_convolution_backward_weight_out_32bit(grad_weight, grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); + raw_cudnn_convolution_backward_weight_out_32bit(grad_weight, grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return; } // else, if C * D1 * D2 * ... 
<= int_max, then we just need to split across the N dimension @@ -1172,7 +1192,7 @@ void raw_cudnn_convolution_backward_weight_out( Tensor input_ = input.narrow(0, start, split_size_); Tensor grad_output_ = grad_output.narrow(0, start, split_size_); Tensor grad_weight_ = at::empty_like(grad_weight); - raw_cudnn_convolution_backward_weight_out_32bit(grad_weight_, grad_output_, input_, padding, stride, dilation, groups, benchmark, deterministic); + raw_cudnn_convolution_backward_weight_out_32bit(grad_weight_, grad_output_, input_, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); grad_weight.add_(grad_weight_); } return; @@ -1195,7 +1215,7 @@ Tensor cudnn_convolution_backward_weight( CheckedFrom c, IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; @@ -1221,7 +1241,7 @@ Tensor cudnn_convolution_backward_weight( raw_cudnn_convolution_backward_weight_out( *grad_weight, *grad_output_contig, *input, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); return grad_weight_t; } @@ -1231,12 +1251,12 @@ Tensor cudnn_convolution_backward_weight( const Tensor& grad_output_t, const Tensor& input_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { return cudnn_convolution_backward_weight( "cudnn_convolution_backward_weight", weight_size, grad_output_t, input_t, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } Tensor cudnn_convolution_transpose_backward_weight( @@ -1244,12 +1264,12 @@ Tensor cudnn_convolution_transpose_backward_weight( const Tensor& grad_output_t, const Tensor& input_t, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) + bool benchmark, bool deterministic, bool allow_tf32) { return cudnn_convolution_backward_weight( "cudnn_convolution_backward_weight", weight_size, input_t, grad_output_t, - padding, stride, dilation, groups, benchmark, deterministic); + padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); } }} // namespace at::native @@ -1271,6 +1291,15 @@ Tensor cudnn_convolution_deprecated( return output; } +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor cudnn_convolution_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + // TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future Tensor cudnn_convolution_transpose_deprecated( const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */, @@ -1284,4 +1313,13 @@ Tensor cudnn_convolution_transpose_deprecated( return output; } +// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future +Tensor 
cudnn_convolution_transpose_deprecated2( + const Tensor& input_t, const Tensor& weight_t, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, + int64_t groups, bool benchmark, bool deterministic) +{ + return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN()); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index cb6e5d9e2a1..5be7d6eea8e 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -147,7 +147,7 @@ namespace { RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { RNNDescriptor rnn_desc; - rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo); + rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN()); return rnn_desc; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 326af4d7d64..dcf349324e2 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -928,13 +928,16 @@ - func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) use_c10_dispatcher: full -- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor +- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + use_c10_dispatcher: full + +- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor use_c10_dispatcher: full - func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding) -> Tensor use_c10_dispatcher: full -- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool[3] output_mask) -> (Tensor, Tensor, Tensor) +- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, bool[3] output_mask) -> (Tensor, Tensor, Tensor) use_c10_dispatcher: full - func: conv1d(Tensor input, Tensor weight, Tensor? 
bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor @@ -1029,22 +1032,27 @@ dispatch: CUDA: cudnn_convolution_deprecated -- func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution.deprecated2(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: full + dispatch: + CUDA: cudnn_convolution_deprecated2 + +- func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution -- func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_input -- func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) +- func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward -- func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_weight @@ -1054,24 +1062,29 @@ dispatch: CUDA: cudnn_convolution_transpose_deprecated -- func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution_transpose.deprecated2(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: full + dispatch: + CUDA: cudnn_convolution_transpose_deprecated2 + +- func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards -- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) +- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor 
weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward -- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_input -- func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_weight diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index f0e14a69aee..2e841d0ad82 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -140,7 +140,7 @@ constexpr uint64_t kProducedFileFormatVersion = 0x3L; // should be increased too. The relationship is: // kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion // >= kProducedFileFormatVersion -constexpr uint64_t kProducedBytecodeVersion = 0x3L; +constexpr uint64_t kProducedBytecodeVersion = 0x4L; static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion."); diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 49395bb166f..230426be869 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -68,14 +68,19 @@ TF32 tensor cores are designed to achieve better performance on matmul and convolutions on `torch.float32` tensors by truncating input data to have 10 bits of mantissa, and accumulating results with FP32 precision, maintaining FP32 dynamic range. -matmul and convolutions are controlled separately, and their corresponding flag can be accessed at: +matmuls and convolutions are controlled separately, and their corresponding flags can be accessed at: .. code:: python # The flag below controls whether to allow TF32 on matmul. This flag defaults to True. torch.backends.cuda.matmul.allow_tf32 = True - # The allow_tf32 flag for convolutions is not implemented yet + # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. + torch.backends.cudnn.allow_tf32 = True + +Note that besides matmuls and convolutions themselves, functions and nn modules that internally use +matmuls or convolutions are also affected. These include `nn.Linear`, `nn.Conv*`, cdist, tensordot, +affine grid and grid sample, adaptive log softmax, GRU and LSTM. To get an idea of the precision and speed, see the example code below: @@ -107,7 +112,7 @@ is needed, users can disable TF32 by: .. 
code:: python torch.backends.cuda.matmul.allow_tf32 = False - # disabling of TF32 for cuDNN is not implemented yet + torch.backends.cudnn.allow_tf32 = False For more information about TF32, see: diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 37480f88521..2fc7ba6fa2a 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -62,6 +62,16 @@ allow_list = [ ("aten::atan2", datetime.date(2020, 7, 30)), ("aten::copy_", datetime.date(2020, 7, 30)), ("aten::sort", datetime.date(2020, 7, 30)), + ('aten::_convolution', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_transpose', datetime.date(2020, 10, 15)), + ('aten::_convolution_double_backward', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_backward_input', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_backward', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_backward_weight', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_transpose_backward', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_transpose_backward_input', datetime.date(2020, 10, 15)), + ('aten::cudnn_convolution_transpose_backward_weight', datetime.date(2020, 10, 15)), ("aten::_cudnn_init_dropout_state", datetime.date(2020, 7, 30)), ("aten::sparse_coo_tensor", datetime.date(2020, 7, 30)), ("aten::_sparse_coo_tensor_with_dims", datetime.date(2020, 7, 30)), diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index b6e8ce41ddb..c0787b82009 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -83,7 +83,7 @@ void testLiteInterpreterConv() { m.register_parameter("bias", torch::ones({20}), false); m.define(R"( def forward(self, input): - return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True) + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True) )"); inputs.push_back(torch::ones({1, 1, 28, 28})); diff --git a/test/test_cuda.py b/test/test_cuda.py index a0bf583769b..c4c53b85cdc 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -528,13 +528,19 @@ class TestCuda(TestCase): q_copy[1].fill_(10) self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) - def test_allow_tf32_get_set(self): + def test_cublas_allow_tf32_get_set(self): orig = torch.backends.cuda.matmul.allow_tf32 self.assertEqual(torch._C._get_cublas_allow_tf32(), orig) torch.backends.cuda.matmul.allow_tf32 = not orig self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig) torch.backends.cuda.matmul.allow_tf32 = orig + def test_cudnn_allow_tf32_get_set(self): + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False): + self.assertFalse(torch.backends.cudnn.allow_tf32) + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): + self.assertTrue(torch.backends.cudnn.allow_tf32) + def test_type_conversions(self): x = torch.randn(5, 5) self.assertIsInstance(x.float(), torch.FloatTensor) diff --git a/test/test_nn.py b/test/test_nn.py index 48f3f62459e..81b082a9821 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -51,6 +51,7 @@ from hypothesis import given import torch.testing._internal.hypothesis_utils as hu from 
torch.testing._internal.common_utils import _assertGradAndGradgradChecks from torch.testing._internal.common_utils import dtype2prec_DONTUSE +from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -63,9 +64,6 @@ if TEST_SCIPY: if TEST_NUMPY: import numpy as np -NO_HALF_TENSORTYPES = [torch.float, - torch.double] - DOUBLE_TENSORTYPES = [torch.double] @@ -5194,73 +5192,6 @@ class TestNN(NNTestCase): output_cpu = rnn(input.cpu(), hx) self.assertEqual(output_cuda, output_cpu) - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - @repeat_test_for_types(NO_HALF_TENSORTYPES) - def test_cuda_rnn_fused(self, dtype=torch.float): - - def copy_rnn(rnn1, rnn2): - for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): - for x, y in zip(x_layer, y_layer): - x.data.copy_(y.data) - - def check_rnn_grads(rnn1, rnn2): - for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): - for x, y in zip(x_layer, y_layer): - self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0) - - input_size = 10 - hidden_size = 6 - num_layers = 2 - seq_length = 7 - batch = 6 - input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) - grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype) - hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype) - grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype) - with torch.backends.cudnn.flags(enabled=False): - for module in (nn.GRU, nn.LSTM): - for bias in (True, False): - rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype) - rnn_cuda = module(input_size, hidden_size, num_layers, bias=bias).to("cuda", dtype) - copy_rnn(rnn, rnn_cuda) - - is_lstm = isinstance(rnn, nn.LSTM) - if is_lstm: - hx = (hx_val.clone().requires_grad_(True), - hx_val.clone().add(1).requires_grad_(True)) - hx_cuda = (hx_val.clone().cuda().requires_grad_(True), - hx_val.clone().cuda().add(1).requires_grad_(True)) - else: - hx = hx_val.clone().requires_grad_(True) - hx_cuda = hx_val.clone().cuda().requires_grad_(True) - - inp = input_val.clone().requires_grad_(True) - inp_cu = input_val.clone().cuda().requires_grad_(True) - output1, hy1 = rnn(inp, hx) - output2, hy2 = rnn_cuda(inp_cu, hx_cuda) - if is_lstm: - torch.autograd.backward( - [output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1] - ) - torch.autograd.backward( - [output2, hy2[0], hy2[1]], - [grad_output.cuda(), grad_hy.cuda(), (grad_hy + 1).cuda()] - ) - else: - torch.autograd.backward([output1, hy1], [grad_output, grad_hy]) - torch.autograd.backward([output2, hy2], [grad_output.cuda(), grad_hy.cuda()]) - - self.assertEqual(output1, output2) - self.assertEqual(hy1, hy2) - - check_rnn_grads(rnn, rnn_cuda) - self.assertEqual(inp.grad.data, inp_cu.grad.data) - if is_lstm: - self.assertEqual(hx[0].grad.data, hx_cuda[0].grad.data) - self.assertEqual(hx[1].grad.data, hx_cuda[1].grad.data) - else: - self.assertEqual(hx.grad.data, hx_cuda.grad.data) - def test_transformer_args_check(self): model_name = 'Transformer' d_model = 128 @@ -7199,236 +7130,6 @@ class TestNN(NNTestCase): self.assertEqual(out_cpu, out_cuda) self.assertEqual(input_cpu.grad, input_gpu.grad) - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_affine_2d_rotate0(self): - # scipy before 1.0.0 do not support homogeneous coordinate - # 
scipy.ndimage.affine_transform, so we need to skip. - for device in device_(): - input_size = [1, 1, 3, 3] - input_ary = np.array(np.random.random(input_size), dtype=np.float32) - output_size = [1, 1, 5, 5] - angle_rad = 0. - - transform_tensor, transform_ary, offset = \ - _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) - - scipy_ary = scipy.ndimage.affine_transform( - input_ary[0, 0], - transform_ary, - offset=offset, - output_shape=output_size[2:], - order=1, - mode='nearest', - prefilter=False) - - affine_tensor = torch.nn.functional.affine_grid( - transform_tensor, - torch.Size(output_size), - align_corners=True - ) - - gridsample_ary = torch.nn.functional.grid_sample( - torch.tensor(input_ary, device=device).to(device), - affine_tensor, - padding_mode='border', - align_corners=True - ).to('cpu').numpy() - - assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6 - assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6 - - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_affine_2d_rotate90(self): - # scipy before 1.0.0 do not support homogeneous coordinate - # scipy.ndimage.affine_transform, so we need to skip. - for device, input_size2dsq, output_size2dsq in \ - itertools.product(device_(), input_size2dsq_(), output_size2dsq_()): - input_size = input_size2dsq - input_ary = np.array(np.random.random(input_size), dtype=np.float32) - output_size = output_size2dsq - angle_rad = 0.25 * math.pi * 2 - - transform_tensor, transform_ary, offset = \ - _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) - - scipy_ary = scipy.ndimage.affine_transform( - input_ary[0, 0], - transform_ary, - offset=offset, - output_shape=output_size[2:], - order=1, - mode='nearest', - prefilter=True) - - if input_size2dsq == output_size2dsq: - assert np.abs(scipy_ary.mean() - input_ary.mean()) < 1e-6 - assert np.abs(scipy_ary[0, 0] - input_ary[0, 0, 0, -1]).max() < 1e-6 - assert np.abs(scipy_ary[0, -1] - input_ary[0, 0, -1, -1]).max() < 1e-6 - assert np.abs(scipy_ary[-1, -1] - input_ary[0, 0, -1, 0]).max() < 1e-6 - assert np.abs(scipy_ary[-1, 0] - input_ary[0, 0, 0, 0]).max() < 1e-6 - - affine_tensor = torch.nn.functional.affine_grid( - transform_tensor, - torch.Size(output_size), - align_corners=True - ) - - gridsample_ary = torch.nn.functional.grid_sample( - torch.tensor(input_ary, device=device).to(device), - affine_tensor, - padding_mode='border', - align_corners=True - ).to('cpu').numpy() - - assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6 - assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6 - - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_affine_2d_rotate45(self): - # scipy before 1.0.0 do not support homogeneous coordinate - # scipy.ndimage.affine_transform, so we need to skip. 
- for device in device_(): - input_size = [1, 1, 3, 3] - input_ary = np.array(np.zeros(input_size), dtype=np.float32) - input_ary[0, 0, 0, :] = 0.5 - input_ary[0, 0, 2, 2] = 1.0 - output_size = [1, 1, 3, 3] - angle_rad = 0.125 * math.pi * 2 - - transform_tensor, transform_ary, offset = \ - _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) - - scipy_ary = scipy.ndimage.affine_transform( - input_ary[0, 0], - transform_ary, - offset=offset, - output_shape=output_size[2:], - order=1, - mode='nearest', - prefilter=False) - - affine_tensor = torch.nn.functional.affine_grid( - transform_tensor, - torch.Size(output_size), - align_corners=True - ) - - gridsample_ary = torch.nn.functional.grid_sample( - torch.tensor(input_ary, device=device).to(device), - affine_tensor, - padding_mode='border', - align_corners=True - ).to('cpu').numpy() - - assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6 - - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_affine_2d_rotateRandom(self): - # scipy before 1.0.0 do not support homogeneous coordinate - # scipy.ndimage.affine_transform, so we need to skip. - for device, angle_rad, input_size2d, output_size2d in \ - itertools.product(device_(), angle_rad_(), input_size2d_(), output_size2d_()): - - input_size = input_size2d - input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3) - output_size = output_size2d - - input_ary[0, 0, 0, 0] = 2 - input_ary[0, 0, 0, -1] = 4 - input_ary[0, 0, -1, 0] = 6 - input_ary[0, 0, -1, -1] = 8 - - transform_tensor, transform_ary, grid_ary = \ - _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) - - scipy_ary = scipy.ndimage.affine_transform( - input_ary[0, 0], - transform_ary, - output_shape=output_size[2:], - order=1, - mode='nearest', - prefilter=False) - - affine_tensor = torch.nn.functional.affine_grid( - transform_tensor, - torch.Size(output_size), - align_corners=True - ) - - gridsample_ary = torch.nn.functional.grid_sample( - torch.tensor(input_ary, device=device).to(device), - affine_tensor, - padding_mode='border', - align_corners=True - ).to('cpu').numpy() - - affine_tensor = affine_tensor.to('cpu') - - for r in range(affine_tensor.size(1)): - for c in range(affine_tensor.size(2)): - grid_out = np.dot(grid_ary, [r, c, 1]) - self.assertEqual(affine_tensor[0, r, c], grid_out[:2]) - - assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5 - - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_affine_3d_rotateRandom(self): - # scipy before 1.0.0 do not support homogeneous coordinate - # scipy.ndimage.affine_transform, so we need to skip. 
- for device, angle_rad, axis_vector, input_size3d, output_size3d in \ - itertools.product(device_(), angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()): - input_size = input_size3d - input_ary = np.array(np.random.random(input_size), dtype=np.float32) - output_size = output_size3d - - input_ary[0, 0, 0, 0, 0] = 2 - input_ary[0, 0, 0, 0, -1] = 3 - input_ary[0, 0, 0, -1, 0] = 4 - input_ary[0, 0, 0, -1, -1] = 5 - input_ary[0, 0, -1, 0, 0] = 6 - input_ary[0, 0, -1, 0, -1] = 7 - input_ary[0, 0, -1, -1, 0] = 8 - input_ary[0, 0, -1, -1, -1] = 9 - - transform_tensor, transform_ary, grid_ary = \ - _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector) - - scipy_ary = scipy.ndimage.affine_transform( - input_ary[0, 0], - transform_ary, - output_shape=output_size[2:], - order=1, - mode='nearest', - prefilter=False) - - affine_tensor = torch.nn.functional.affine_grid( - transform_tensor, - torch.Size(output_size), - align_corners=True - ) - - gridsample_ary = torch.nn.functional.grid_sample( - torch.tensor(input_ary, device=device).to(device), - affine_tensor, - padding_mode='border', - align_corners=True - ).to('cpu').numpy() - - affine_tensor = affine_tensor.to('cpu') - - for i in range(affine_tensor.size(1)): - for r in range(affine_tensor.size(2)): - for c in range(affine_tensor.size(3)): - grid_out = np.dot(grid_ary, [i, r, c, 1]) - self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3]) - - assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5 - def test_channel_shuffle(self): # 3D tensor x = torch.tensor( @@ -8915,8 +8616,22 @@ def add_test(test, decorator=None): kwargs['extra_args'] = test.extra_args if 'dtype' in get_function_arglist(test.test_cuda): - add(cuda_test_name + '_float', lambda self, - test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs)) + if tf32_is_not_fp32() and test.with_tf32: + + def with_tf32_off(self, test=test, kwargs=kwargs): + with tf32_off(): + test.test_cuda(self, dtype=torch.float, **kwargs) + + add(cuda_test_name + '_fp32', with_tf32_off) + + def with_tf32_on(self, test=test, kwargs=kwargs): + with tf32_on(self, test.tf32_precision): + test.test_cuda(self, dtype=torch.float, **kwargs) + + add(cuda_test_name + '_tf32', with_tf32_on) + else: + add(cuda_test_name + '_float', lambda self, + test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs)) add(cuda_test_name + '_double', lambda self, test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs)) @@ -8931,7 +8646,21 @@ def add_test(test, decorator=None): add(cuda_test_name + '_bfloat16', test_bfloat16) else: - add(cuda_test_name, lambda self, test=test, kwargs=kwargs: test.test_cuda(self, **kwargs)) + if tf32_is_not_fp32() and test.with_tf32: + + def with_tf32_off(self, test=test, kwargs=kwargs): + with tf32_off(): + test.test_cuda(self, **kwargs) + + add(cuda_test_name + '_fp32', with_tf32_off) + + def with_tf32_on(self, test=test, kwargs=kwargs): + with tf32_on(self, test.tf32_precision): + test.test_cuda(self, **kwargs) + + add(cuda_test_name + '_tf32', with_tf32_on) + else: + add(cuda_test_name, lambda self, test=test, kwargs=kwargs: test.test_cuda(self, **kwargs)) for test_params in module_tests + new_module_tests: # TODO: CUDA is not implemented yet @@ -9071,7 +8800,9 @@ class _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss): add_test(NewModuleTest( constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]), input_size=(4, 16), - fullname='AdaptiveLogSoftmax')) + 
fullname='AdaptiveLogSoftmax', + with_tf32=True, + tf32_precision=0.005)) # The following are helpers for TestNN.test_affine_* @@ -9482,6 +9213,239 @@ class TestNNDeviceType(NNTestCase): self.assertEqual(p.grad, torch.zeros_like(p.grad)) self.assertEqual(inp.grad, torch.zeros_like(inp)) + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + @tf32_on_and_off() + def test_affine_2d_rotate0(self, device): + # scipy before 1.0.0 do not support homogeneous coordinate + # scipy.ndimage.affine_transform, so we need to skip. + input_size = [1, 1, 3, 3] + input_ary = np.array(np.random.random(input_size), dtype=np.float32) + output_size = [1, 1, 5, 5] + angle_rad = 0. + + transform_tensor, transform_ary, offset = \ + _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) + + scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( + input_ary[0, 0], + transform_ary, + offset=offset, + output_shape=output_size[2:], + order=1, + mode='nearest', + prefilter=False)) + + affine_tensor = torch.nn.functional.affine_grid( + transform_tensor, + torch.Size(output_size), + align_corners=True + ) + + gridsample_ary = torch.nn.functional.grid_sample( + torch.tensor(input_ary, device=device).to(device), + affine_tensor, + padding_mode='border', + align_corners=True + ).to('cpu') + + self.assertEqual(scipy_ary.mean(), gridsample_ary.mean()) + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + @tf32_on_and_off(0.001) + def test_affine_2d_rotate90(self, device): + # scipy before 1.0.0 do not support homogeneous coordinate + # scipy.ndimage.affine_transform, so we need to skip. + for input_size2dsq, output_size2dsq in \ + itertools.product(input_size2dsq_(), output_size2dsq_()): + input_size = input_size2dsq + input_ary = np.array(np.random.random(input_size), dtype=np.float32) + output_size = output_size2dsq + angle_rad = 0.25 * math.pi * 2 + + transform_tensor, transform_ary, offset = \ + _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) + + scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( + input_ary[0, 0], + transform_ary, + offset=offset, + output_shape=output_size[2:], + order=1, + mode='nearest', + prefilter=True)) + + if input_size2dsq == output_size2dsq: + self.assertEqual(scipy_ary.mean(), input_ary.mean()) + self.assertEqual(scipy_ary[0, 0], input_ary[0, 0, 0, -1]) + self.assertEqual(scipy_ary[0, -1], input_ary[0, 0, -1, -1]) + self.assertEqual(scipy_ary[-1, -1], input_ary[0, 0, -1, 0]) + self.assertEqual(scipy_ary[-1, 0], input_ary[0, 0, 0, 0]) + + affine_tensor = torch.nn.functional.affine_grid( + transform_tensor, + torch.Size(output_size), + align_corners=True + ) + + gridsample_ary = torch.nn.functional.grid_sample( + torch.tensor(input_ary, device=device).to(device), + affine_tensor, + padding_mode='border', + align_corners=True + ).to('cpu') + + self.assertEqual(scipy_ary.mean(), gridsample_ary.mean()) + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + @tf32_on_and_off(0.005) + def test_affine_2d_rotate45(self, device): + # scipy before 1.0.0 do not support homogeneous coordinate + # scipy.ndimage.affine_transform, so we need to skip. 
+ input_size = [1, 1, 3, 3] + input_ary = np.array(np.zeros(input_size), dtype=np.float32) + input_ary[0, 0, 0, :] = 0.5 + input_ary[0, 0, 2, 2] = 1.0 + output_size = [1, 1, 3, 3] + angle_rad = 0.125 * math.pi * 2 + + transform_tensor, transform_ary, offset = \ + _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) + + scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( + input_ary[0, 0], + transform_ary, + offset=offset, + output_shape=output_size[2:], + order=1, + mode='nearest', + prefilter=False)) + + affine_tensor = torch.nn.functional.affine_grid( + transform_tensor, + torch.Size(output_size), + align_corners=True + ) + + gridsample_ary = torch.nn.functional.grid_sample( + torch.tensor(input_ary, device=device).to(device), + affine_tensor, + padding_mode='border', + align_corners=True + ).to('cpu') + + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + @tf32_on_and_off(0.005) + def test_affine_2d_rotateRandom(self, device): + # scipy before 1.0.0 do not support homogeneous coordinate + # scipy.ndimage.affine_transform, so we need to skip. + for angle_rad, input_size2d, output_size2d in \ + itertools.product(angle_rad_(), input_size2d_(), output_size2d_()): + + input_size = input_size2d + input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3) + output_size = output_size2d + + input_ary[0, 0, 0, 0] = 2 + input_ary[0, 0, 0, -1] = 4 + input_ary[0, 0, -1, 0] = 6 + input_ary[0, 0, -1, -1] = 8 + + transform_tensor, transform_ary, grid_ary = \ + _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad) + + scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( + input_ary[0, 0], + transform_ary, + output_shape=output_size[2:], + order=1, + mode='nearest', + prefilter=False)) + + affine_tensor = torch.nn.functional.affine_grid( + transform_tensor, + torch.Size(output_size), + align_corners=True + ) + + gridsample_ary = torch.nn.functional.grid_sample( + torch.tensor(input_ary, device=device).to(device), + affine_tensor, + padding_mode='border', + align_corners=True + ).to('cpu') + + affine_tensor = affine_tensor.to('cpu') + + for r in range(affine_tensor.size(1)): + for c in range(affine_tensor.size(2)): + grid_out = np.dot(grid_ary, [r, c, 1]) + self.assertEqual(affine_tensor[0, r, c], grid_out[:2]) + + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + @tf32_on_and_off(0.005) + def test_affine_3d_rotateRandom(self, device): + # scipy before 1.0.0 do not support homogeneous coordinate + # scipy.ndimage.affine_transform, so we need to skip. 
+ for angle_rad, axis_vector, input_size3d, output_size3d in \ + itertools.product(angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()): + input_size = input_size3d + input_ary = np.array(np.random.random(input_size), dtype=np.float32) + output_size = output_size3d + + input_ary[0, 0, 0, 0, 0] = 2 + input_ary[0, 0, 0, 0, -1] = 3 + input_ary[0, 0, 0, -1, 0] = 4 + input_ary[0, 0, 0, -1, -1] = 5 + input_ary[0, 0, -1, 0, 0] = 6 + input_ary[0, 0, -1, 0, -1] = 7 + input_ary[0, 0, -1, -1, 0] = 8 + input_ary[0, 0, -1, -1, -1] = 9 + + transform_tensor, transform_ary, grid_ary = \ + _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector) + + scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform( + input_ary[0, 0], + transform_ary, + output_shape=output_size[2:], + order=1, + mode='nearest', + prefilter=False)) + + affine_tensor = torch.nn.functional.affine_grid( + transform_tensor, + torch.Size(output_size), + align_corners=True + ) + + gridsample_ary = torch.nn.functional.grid_sample( + torch.tensor(input_ary, device=device).to(device), + affine_tensor, + padding_mode='border', + align_corners=True + ).to('cpu') + + affine_tensor = affine_tensor.to('cpu') + + for i in range(affine_tensor.size(1)): + for r in range(affine_tensor.size(2)): + for c in range(affine_tensor.size(3)): + grid_out = np.dot(grid_ary, [i, r, c, 1]) + self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3]) + + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) + def test_Dropout(self, device): input = torch.Tensor(1000) self._test_dropout(nn.Dropout, device, input) @@ -9594,6 +9558,74 @@ class TestNNDeviceType(NNTestCase): inp = torch.randn(3, 0, 10, 10, device=device) mod(inp) + @onlyCUDA + @dtypes(torch.float, torch.double) + @tf32_on_and_off(0.005) + def test_rnn_fused(self, device, dtype): + + def copy_rnn(rnn1, rnn2): + for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): + for x, y in zip(x_layer, y_layer): + x.data.copy_(y.data) + + def check_rnn_grads(rnn1, rnn2): + for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights): + for x, y in zip(x_layer, y_layer): + self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0) + + input_size = 10 + hidden_size = 6 + num_layers = 2 + seq_length = 7 + batch = 6 + input_val = torch.randn(seq_length, batch, input_size, dtype=dtype) + grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype) + hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype) + grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype) + with torch.backends.cudnn.flags(enabled=False, allow_tf32=None): + for module in (nn.GRU, nn.LSTM): + for bias in (True, False): + rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype) + rnn_device = module(input_size, hidden_size, num_layers, bias=bias).to(device, dtype) + copy_rnn(rnn, rnn_device) + + is_lstm = isinstance(rnn, nn.LSTM) + if is_lstm: + hx = (hx_val.clone().requires_grad_(True), + hx_val.clone().add(1).requires_grad_(True)) + hx_device = (hx_val.clone().to(device).requires_grad_(True), + hx_val.clone().to(device).add(1).requires_grad_(True)) + else: + hx = hx_val.clone().requires_grad_(True) + hx_device = hx_val.clone().to(device).requires_grad_(True) + + inp = input_val.clone().requires_grad_(True) + inp_cu = input_val.clone().to(device).requires_grad_(True) + output1, hy1 = rnn(inp, hx) + output2, hy2 = rnn_device(inp_cu, hx_device) + if is_lstm: + torch.autograd.backward( + [output1, hy1[0], hy1[1]], [grad_output, grad_hy, 
grad_hy + 1] + ) + torch.autograd.backward( + [output2, hy2[0], hy2[1]], + [grad_output.to(device), grad_hy.to(device), (grad_hy + 1).to(device)] + ) + else: + torch.autograd.backward([output1, hy1], [grad_output, grad_hy]) + torch.autograd.backward([output2, hy2], [grad_output.to(device), grad_hy.to(device)]) + + self.assertEqual(output1, output2) + self.assertEqual(hy1, hy2) + + check_rnn_grads(rnn, rnn_device) + self.assertEqual(inp.grad, inp_cu.grad) + if is_lstm: + self.assertEqual(hx[0].grad, hx_device[0].grad) + self.assertEqual(hx[1].grad, hx_device[1].grad) + else: + self.assertEqual(hx.grad, hx_device.grad) + def test_BatchNorm_empty(self, device): mod = torch.nn.BatchNorm2d(3).to(device) inp = torch.randn(0, 3, 2, 2, device=device) @@ -10209,6 +10241,7 @@ class TestNNDeviceType(NNTestCase): self.assertEqual(out1, out2) @onlyCUDA + @tf32_on_and_off(0.005) def test_grid_sample_large(self, device): def issue_35202(): input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True) @@ -10235,7 +10268,7 @@ class TestNNDeviceType(NNTestCase): result.backward(torch.ones_like(result)) expected_grad = torch.ones_like(image) expected_grad[0, 0, 1, 1, 1] = 0 - self.assertTrue(torch.allclose(image.grad, expected_grad, atol=1e-3)) + self.assertEqual(image.grad, expected_grad, atol=0.005, rtol=0) issue_24823_1(torch.half) issue_24823_1(torch.float) issue_24823_1(torch.double) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e5778360437..cdd6326b81e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1405,50 +1405,50 @@ input, weight, bias: "grad.defined() ? convolution_backward_overrideable(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" - name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, false, output_padding, groups, false, false, false, grad_input_mask) + grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, false, output_padding, groups, false, false, false, false, grad_input_mask) - name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? 
slow_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, empty_like(grad, at::MemoryFormat::Contiguous), empty_like(grad, at::MemoryFormat::Contiguous), grad_input_mask) : std::tuple()" - name: slow_conv_transpose2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, false, grad_input_mask) - name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? slow_conv_transpose3d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, empty_like(grad, at::MemoryFormat::Preserve), empty_like(grad, at::MemoryFormat::Preserve), grad_input_mask) : std::tuple()" - name: slow_conv_transpose3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, false, grad_input_mask) - name: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) self, weight, bias: "grad.defined() ? thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) : std::tuple()" - name: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, false, grad_input_mask) - name: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor self, weight: "grad.defined() ? 
thnn_conv_depthwise2d_backward(grad.contiguous(), self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple()" bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1) - name: thnn_conv_depthwise2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[2] output_mask) -> (Tensor grad_input, Tensor grad_weight) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, false, grad_input_mask) - name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) self, weight, bias: "grad.defined() ? slow_conv3d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) : std::tuple()" - name: slow_conv3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, false, grad_input_mask) - name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? slow_conv_dilated2d_backward(grad, self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple()" - name: slow_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, false, grad_input_mask) - name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? 
slow_conv_dilated3d_backward(grad, self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple()" - name: slow_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, false, grad_input_mask) - name: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor self: col2im_backward(grad, kernel_size, dilation, padding, stride) @@ -1661,17 +1661,17 @@ - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) -- name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight: "grad.defined() ? cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple()" +- name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "grad.defined() ? cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, grad_input_mask) : std::tuple()" -- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask) +- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, allow_tf32, grad_input_mask) -- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight: "grad.defined() ? 
cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple()" +- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "grad.defined() ? cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, grad_input_mask) : std::tuple()" -- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask) +- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, allow_tf32, grad_input_mask) # The above backward definitions are equivalent to the definitions below. Why do we bundle # everything up? It's because it's more convenient to define double backwards @@ -1734,19 +1734,19 @@ self, weight, bias: "grad.defined() ? miopen_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple()" - name: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, false, grad_input_mask) - name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor self, weight, bias: "grad.defined() ? 
miopen_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple()" - name: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, false, grad_input_mask) - name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor self, weight, bias: "grad.defined() ? miopen_depthwise_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple()" - name: miopen_depthwise_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, benchmark, deterministic, true, false, grad_input_mask) - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" @@ -1769,7 +1769,7 @@ self, weight, bias: "grad.defined() ? 
mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple()" - name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, false, false, false, grad_input_mask) + grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, false, false, false, false, grad_input_mask) # fft - name: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9ccb895833a..f336eb71524 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -193,6 +193,8 @@ def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN def _get_deterministic() -> _bool: ... # THPModule_deterministic def _set_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministic +def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN +def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e213c8a0cae..a18c215983e 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -83,26 +83,32 @@ def is_acceptable(tensor): return True -def set_flags(_enabled, _benchmark, _deterministic): +def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None): orig_flags = (torch._C._get_cudnn_enabled(), torch._C._get_cudnn_benchmark(), - torch._C._get_cudnn_deterministic()) - torch._C._set_cudnn_enabled(_enabled) - torch._C._set_cudnn_benchmark(_benchmark) - torch._C._set_cudnn_deterministic(_deterministic) + torch._C._get_cudnn_deterministic(), + torch._C._get_cudnn_allow_tf32()) + if _enabled is not None: + torch._C._set_cudnn_enabled(_enabled) + if _benchmark is not None: + torch._C._set_cudnn_benchmark(_benchmark) + if _deterministic is not None: + torch._C._set_cudnn_deterministic(_deterministic) + if _allow_tf32 is not None: + torch._C._set_cudnn_allow_tf32(_allow_tf32) return orig_flags @contextmanager -def flags(enabled=False, benchmark=False, deterministic=False): +def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True): with __allow_nonbracketed_mutation(): - orig_flags = set_flags(enabled, benchmark, deterministic) + orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32) try: yield finally: # recover the previous values with __allow_nonbracketed_mutation(): - set_flags(orig_flags[0], orig_flags[1], orig_flags[2]) + set_flags(*orig_flags) # The magic here is to allow us to intercept code like this: @@ -116,6 +122,7 @@ class CudnnModule(PropModule): enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = 
ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic) benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark) + allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32) # This is the sys.modules replacement trick, see # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index aecb336f364..ed4aa21a8f7 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -388,6 +388,20 @@ PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data) END_HANDLE_TH_ERRORS } +PyObject *THPModule_setAllowTF32CuDNN(PyObject *_unused, PyObject *arg) +{ + THPUtils_assert(PyBool_Check(arg), "set_allow_tf32_cublas expects a bool, " + "but got %s", THPUtils_typename(arg)); + at::globalContext().setAllowTF32CuDNN(arg == Py_True); + Py_RETURN_NONE; +} + +PyObject *THPModule_allowTF32CuDNN(PyObject *_unused, PyObject *noargs) +{ + if (at::globalContext().allowTF32CuDNN()) Py_RETURN_TRUE; + else Py_RETURN_FALSE; +} + PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg) { THPUtils_assert(PyBool_Check(arg), "set_enabled_cudnn expects a bool, " @@ -577,6 +591,8 @@ static PyMethodDef TorchMethods[] = { {"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr}, {"_get_mkldnn_enabled", (PyCFunction)THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, {"_set_mkldnn_enabled", (PyCFunction)THPModule_setUserEnabledMkldnn, METH_O, nullptr}, + {"_get_cudnn_allow_tf32", (PyCFunction)THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_allow_tf32", (PyCFunction)THPModule_setAllowTF32CuDNN, METH_O, nullptr}, {"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_benchmark", (PyCFunction)THPModule_setBenchmarkCuDNN, METH_O, nullptr}, {"_get_cudnn_deterministic", (PyCFunction)THPModule_deterministicCuDNN, METH_NOARGS, nullptr}, diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp index c6ffe39dbab..7d344632838 100644 --- a/torch/csrc/jit/passes/fold_conv_bn.cpp +++ b/torch/csrc/jit/passes/fold_conv_bn.cpp @@ -32,29 +32,41 @@ void replaceConvBiasWithGetAttr(Module& module) { // Thus assumes that tracing will have always gotten rid of aten::conv2d or // aten::conv3d. If it did not, BN folding will fail. 
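The pattern matcher below now has to recognise two aten::_convolution schemas: the current 13-argument form, whose trailing bool is allow_tf32, and the deprecated 12-argument form kept for graphs serialized before this change. A minimal Python sketch of the new call shape (it mirrors the lite-interpreter test updated earlier in this diff; tensor shapes are arbitrary and purely illustrative):

    import torch

    inp = torch.ones(1, 1, 28, 28)
    weight = torch.ones(20, 1, 5, 5)
    bias = torch.ones(20)

    # Current schema: the 13th argument is allow_tf32.
    out = torch._convolution(inp, weight, bias, [1, 1], [0, 0], [1, 1],
                             False, [0, 0], 1, False, False, True, True)
    # The deprecated schema is the same call with the trailing allow_tf32
    # flag omitted; older serialized graphs still produce it, which is why
    # a second pattern is registered below.
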
const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): + %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, + %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32) + return (%conv_out) )"); + const PatternInfo& pattern_convolution_deprecated = + PatternInfo::parse_from_str(R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled) return (%conv_out) )"); - const Graph& pattern_convolution_graph = *pattern_convolution.pattern_graph; - const auto& convolution_vmap = pattern_convolution.vmap; + auto replace_pattern = [&](const PatternInfo& pattern_convolution) { + const Graph& pattern_convolution_graph = *pattern_convolution.pattern_graph; + const auto& convolution_vmap = pattern_convolution.vmap; - const auto& matches = findPatternMatches(pattern_convolution_graph, *graph); - for (const auto& match : matches) { - // We come here only if the bias was not present in the module. - // In that case, the corresponding graph will not have getAttr("bias") - // Insert that in the graph. - // And change _convolution to take the new value. - auto conv_node = - match.values_map.at(convolution_vmap.at("conv_out"))->node(); - WithInsertPoint ins(conv_node); - Value* bias_attr_val = graph->insertGetAttr(graph->inputs()[0], "bias") - ->setType(TensorType::get()); - constexpr size_t conv_bias_index = 2; - conv_node->replaceInput(conv_bias_index, bias_attr_val); - } + const auto& matches = findPatternMatches(pattern_convolution_graph, *graph); + for (const auto& match : matches) { + // We come here only if the bias was not present in the module. + // In that case, the corresponding graph will not have getAttr("bias") + // Insert that in the graph. + // And change _convolution to take the new value. 
+ auto conv_node = + match.values_map.at(convolution_vmap.at("conv_out"))->node(); + WithInsertPoint ins(conv_node); + Value* bias_attr_val = graph->insertGetAttr(graph->inputs()[0], "bias") + ->setType(TensorType::get()); + constexpr size_t conv_bias_index = 2; + conv_node->replaceInput(conv_bias_index, bias_attr_val); + } + }; + replace_pattern(pattern_convolution); + replace_pattern(pattern_convolution_deprecated); } void addBiasForConvIfNone(Module& module, const std::string& pattern_name) { diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp index 90caf4690b8..11e963af77b 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -63,18 +63,25 @@ std::unordered_map getConvParams( void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { // TODO: remove constant prop in the pass ConstantPropagation(graph); - std::string convolution = R"( + std::string convolution_deprecated = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, %deterministic:bool, %cudnn_enabled:bool): %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled) return (%r) )"; + std::string convolution = R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): + %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation, + %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32) + return (%r) )"; std::string conv2d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, - %deterministic:bool, %cudnn_enabled:bool): + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; @@ -88,14 +95,14 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { std::string conv1d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, - %deterministic:bool, %cudnn_enabled:bool): + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; std::string conv3d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, - %deterministic:bool, %cudnn_enabled:bool): + %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool): %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; @@ -167,9 +174,11 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { SubgraphRewriter rewriter_conv1d; rewriter_conv1d.RegisterRewritePattern(convolution, conv1d); + rewriter_conv1d.RegisterRewritePattern(convolution_deprecated, conv1d); rewriter_conv1d.runOnGraph(graph, filter_conv1d); SubgraphRewriter rewriter_conv2d; rewriter_conv2d.RegisterRewritePattern(convolution, conv2d); + rewriter_conv2d.RegisterRewritePattern(convolution_deprecated, conv2d); rewriter_conv2d.runOnGraph(graph, filter_conv2d); SubgraphRewriter rewriter_conv2d_transpose; rewriter_conv2d_transpose.RegisterRewritePattern( @@ -177,6 
+186,7 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose); SubgraphRewriter rewriter_conv3d; rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); + rewriter_conv3d.RegisterRewritePattern(convolution_deprecated, conv3d); rewriter_conv3d.runOnGraph(graph, filter_conv3d); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 478fdf6af00..cf9a93b7f52 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1160,7 +1160,8 @@ class ShapePropagator { "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", - "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", + "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", // deprecated _convolution + "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor", diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 8d81877180b..0c43d149a6c 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1109,9 +1109,9 @@ def log_softmax(g, input, dim, dtype=None): return return_op -@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i') +@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i', 'i') def _convolution(g, input, weight, bias, stride, padding, dilation, - transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled): + transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32): weight_size = weight.type().sizes() args = [input, weight] @@ -1145,32 +1145,32 @@ def _convolution(g, input, weight, bias, stride, padding, dilation, @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i') def conv1d(g, input, weight, bias, stride, padding, dilation, groups): - return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None) @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i') def conv2d(g, input, weight, bias, stride, padding, dilation, groups): - return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), 
groups, None, None, None, None) @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i') def conv3d(g, input, weight, bias, stride, padding, dilation, groups): - return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None) @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is') def conv_transpose1d(g, input, weight, bias, stride, padding, output_padding, groups, dilation): - return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None) @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is') def conv_transpose2d(g, input, weight, bias, stride, padding, output_padding, groups, dilation): - return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None) @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is') def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, groups, dilation): - return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None) + return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None) @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 87293078385..d8cdb892c96 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -61,8 +61,12 @@ class AutocastTestLists(object): # The remaining lists organize ops that autocast treats explicitly. 
self.torch_fp16 = [ + # deprecated _convolution ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1, False, True, True)), + # the current _convolution + ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, + (0, 0), 1, False, True, True, True)), ("_convolution_nogroup", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0))), ("conv1d", conv_args_fp32[0]), ("conv2d", conv_args_fp32[1]), @@ -72,13 +76,17 @@ class AutocastTestLists(object): ("conv_transpose2d", conv_args_fp32[1], TEST_WITH_ROCM), ("conv_transpose3d", conv_args_fp32[2], TEST_WITH_ROCM), ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)), - # cudnn_convolutions with bias + # deprecated cudnn_convolutions with bias ("cudnn_convolution", conv_args_fp32[1] + bias_fp32 + ((0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM), ("cudnn_convolution_transpose", conv_args_fp32[1] + bias_fp32 + ((0, 0), (0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM), - # cudnn_convolutions with no bias + # deprecated cudnn_convolutions with no allow_tf32 flag ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM), ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM), + # the current cudnn_convolutions + ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM), + ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1), + (1, 1), 1, False, True, True), TEST_WITH_ROCM), ("prelu", pointwise0_fp32 + element0_fp32), ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32), diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 4e0a51ddaf0..506f3bdb505 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -5,6 +5,7 @@ import torch import torch.cuda from torch.testing._internal.common_utils import TEST_NUMBA import inspect +import contextlib TEST_CUDA = torch.cuda.is_available() @@ -49,6 +50,31 @@ def tf32_is_not_fp32(): return True +@contextlib.contextmanager +def tf32_off(): + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + try: + torch.backends.cuda.matmul.allow_tf32 = False + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + + +@contextlib.contextmanager +def tf32_on(self, tf32_precision=1e-5): + old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 + old_precison = self.precision + try: + torch.backends.cuda.matmul.allow_tf32 = True + self.precision = tf32_precision + with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul + self.precision = old_precison + + # This is a wrapper that wraps a test to run this test twice, one with # allow_tf32=True, another with allow_tf32=False. When running with # allow_tf32=True, it will use reduced precision as pecified by the @@ -64,33 +90,22 @@ def tf32_is_not_fp32(): # TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced # precision to check values. 
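A rough usage sketch for the reworked helper: the tf32_off()/tf32_on() context managers above compose under the tf32_on_and_off decorator, running the test body once with TF32 disabled and once with TF32 enabled at the relaxed comparison precision. The test class, test body, and 0.01 tolerance below are hypothetical and only illustrative:

    import torch
    from torch.testing._internal.common_cuda import tf32_on_and_off
    from torch.testing._internal.common_device_type import instantiate_device_type_tests
    from torch.testing._internal.common_utils import TestCase, run_tests

    class TestTF32Sketch(TestCase):  # hypothetical test class
        @tf32_on_and_off(0.01)
        def test_small_matmul(self, device):
            a = torch.randn(8, 8, device=device)
            b = torch.randn(8, 8, device=device)
            ref = (a.cpu() @ b.cpu()).to(device)
            # assertEqual picks up self.precision, which tf32_on() relaxes
            # to the decorator's tolerance while TF32 is enabled.
            self.assertEqual(a @ b, ref)

    instantiate_device_type_tests(TestTF32Sketch, globals())

    if __name__ == '__main__':
        run_tests()

On hardware where tf32_is_not_fp32() is false (or on CPU), the decorator simply runs the test once, unchanged.
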
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 4e0a51ddaf0..506f3bdb505 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -5,6 +5,7 @@ import torch
 import torch.cuda
 from torch.testing._internal.common_utils import TEST_NUMBA
 import inspect
+import contextlib
 
 TEST_CUDA = torch.cuda.is_available()
 
@@ -49,6 +50,31 @@ def tf32_is_not_fp32():
     return True
 
 
+@contextlib.contextmanager
+def tf32_off():
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    old_precison = self.precision
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        self.precision = tf32_precision
+        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+        self.precision = old_precison
+
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as pecified by the
@@ -64,33 +90,22 @@ def tf32_is_not_fp32():
 # TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
 # precision to check values.
 def tf32_on_and_off(tf32_precision=1e-5):
-    def call_with_tf32_on_and_off(self, function_call):
-        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
-        old_precison = self.precision
-
-        def with_tf32_disabled():
-            torch.backends.cuda.matmul.allow_tf32 = False
+    def with_tf32_disabled(self, function_call):
+        with tf32_off():
             function_call()
 
-        def with_tf32_enabled():
-            torch.backends.cuda.matmul.allow_tf32 = True
-            self.precision = tf32_precision
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
             function_call()
 
-        try:
-            with_tf32_disabled()
-            with_tf32_enabled()
-        finally:
-            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
-            self.precision = old_precison
-
     def wrapper(f):
         nargs = len(inspect.signature(f).parameters)
         if nargs == 2:
             @functools.wraps(f)
             def wrapped(self, device):
                 if self.device_type == 'cuda' and tf32_is_not_fp32():
-                    call_with_tf32_on_and_off(self, lambda: f(self, device))
+                    with_tf32_disabled(self, lambda: f(self, device))
+                    with_tf32_enabled(self, lambda: f(self, device))
                 else:
                     f(self, device)
         else:
@@ -99,7 +114,8 @@ def tf32_on_and_off(tf32_precision=1e-5):
             @functools.wraps(f)
             def wrapped(self, device, dtype):
                 if self.device_type == 'cuda' and dtype in {torch.float32, torch.complex64} and tf32_is_not_fp32():
-                    call_with_tf32_on_and_off(self, lambda: f(self, device, dtype))
+                    with_tf32_disabled(self, lambda: f(self, device, dtype))
+                    with_tf32_enabled(self, lambda: f(self, device, dtype))
                 else:
                     f(self, device, dtype)
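With the refactor above, the TF32 toggling lives in the reusable tf32_off()/tf32_on() context managers, and tf32_on_and_off simply runs the wrapped test body once under each. tf32_on(self, ...) also relaxes self.precision, so it is tied to the device-generic test classes that define that attribute; tf32_off() has no such dependency and can wrap any block that needs exact FP32 numerics. A usage sketch (illustrative; assumes this patch is applied and a CUDA device is present):

    # Illustrative sketch, not part of the diff: tf32_off() disables the cuBLAS
    # matmul switch and, through the extended cudnn.flags(), the new cuDNN switch,
    # restoring both on exit.
    import torch
    from torch.testing._internal.common_cuda import tf32_off

    if torch.cuda.is_available():
        a = torch.randn(64, 64, device='cuda')
        b = torch.randn(64, 64, device='cuda')
        with tf32_off():
            ref = a @ b      # plain FP32 (FMA) math
        fast = a @ b         # may use TF32 on Ampere, depending on the global flags
        print((ref - fast).abs().max().item())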
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 85faea19576..06cc8aa0029 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -106,6 +106,8 @@ module_tests = [
         cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
         input_size=(4, 10),
         reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Linear',
@@ -113,7 +115,9 @@ module_tests = [
         cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
         input_size=(4, 10),
         desc='no_bias',
-        reference_fn=lambda i, p, _: torch.mm(i, p[0].t())
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Threshold',
@@ -1595,6 +1599,7 @@ new_module_tests = [
         cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
         input_size=(2, 4, 10),
         cudnn=True,
+        with_tf32=True,
     ),
     dict(
         module_name='Conv1d',
@@ -1603,6 +1608,8 @@ new_module_tests = [
         input_size=(2, 4, 10),
         cudnn=True,
         desc='stride',
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',
@@ -1611,6 +1618,7 @@ new_module_tests = [
         input_size=(2, 4, 10),
         cudnn=True,
         desc='pad1',
+        with_tf32=True,
     ),
     dict(
         module_name='Conv1d',
@@ -1619,6 +1627,7 @@ new_module_tests = [
         input_size=(2, 4, 10),
         cudnn=True,
         desc='pad2',
+        with_tf32=True,
     ),
     dict(
         module_name='Conv1d',
@@ -1627,6 +1636,7 @@ new_module_tests = [
         input_size=(1, 4, 1),
         cudnn=True,
         desc='pad1size1',
+        with_tf32=True,
     ),
     dict(
         module_name='Conv1d',
@@ -1635,6 +1645,7 @@ new_module_tests = [
         input_size=(1, 4, 1),
         cudnn=True,
         desc='pad2size1',
+        with_tf32=True,
     ),
     dict(
         module_name='Conv1d',
@@ -1644,12 +1655,14 @@ new_module_tests = [
         cudnn=True,
         desc='zero_batch',
         test_cuda=(not TEST_WITH_ROCM),
+        with_tf32=True,
     ),
     dict(
         fullname='Conv1d_dilated',
         constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
         cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
         input_size=(2, 4, 10),
+        with_tf32=True,
     ),
     dict(
         fullname='Conv1d_groups',
@@ -1657,6 +1670,7 @@ new_module_tests = [
         cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
         input_size=(2, 4, 6),
         cudnn=True,
+        with_tf32=True,
     ),
     dict(
         fullname='ConvTranspose1d',
@@ -1664,6 +1678,8 @@ new_module_tests = [
         cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
         cudnn=True,
         input_size=(1, 3, 7),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='ConvTranspose1d',
@@ -1673,6 +1689,8 @@ new_module_tests = [
         input_size=(1, 3, 6),
         cudnn=True,
         desc='no_bias',
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='ConvTranspose1d',
@@ -1682,6 +1700,7 @@ new_module_tests = [
         input_size=(1, 3, 6),
         cudnn=True,
         desc='dilated',
+        with_tf32=True,
     ),
     dict(
         fullname='ConvTranspose1d_groups',
@@ -1690,6 +1709,8 @@ new_module_tests = [
                                 .stride(3).padding(1).output_padding(1).groups(2)''',
         cudnn=True,
         input_size=(2, 4, 7),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='MaxPool1d',
@@ -1711,6 +1732,7 @@ new_module_tests = [
         input_size=(2, 3, 7, 5),
         cudnn=True,
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         module_name='Conv2d',
@@ -1720,6 +1742,8 @@ new_module_tests = [
         cudnn=True,
         desc='strided',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv2d',
@@ -1729,6 +1753,8 @@ new_module_tests = [
         cudnn=True,
         desc='padding',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv2d',
@@ -1738,6 +1764,7 @@ new_module_tests = [
         cudnn=True,
         desc='dilated',
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         module_name='Conv2d',
@@ -1748,6 +1775,7 @@ new_module_tests = [
         cudnn=True,
         desc='no_bias',
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         module_name='Conv2d',
@@ -1758,6 +1786,7 @@ new_module_tests = [
         desc='zero_batch',
         check_with_long_tensor=True,
         test_cuda=(not TEST_WITH_ROCM),
+        with_tf32=True,
     ),
     dict(
         fullname='Conv2d_groups',
@@ -1766,6 +1795,7 @@ new_module_tests = [
         input_size=(2, 4, 6, 5),
         cudnn=True,
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         fullname='Conv2d_groups_thnn',
@@ -1773,6 +1803,7 @@ new_module_tests = [
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
         input_size=(2, 4, 6, 5),
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         module_name='ConvTranspose2d',
@@ -1782,6 +1813,8 @@ new_module_tests = [
         cudnn=True,
         input_size=(1, 3, 7, 6),
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='ConvTranspose2d',
@@ -1797,6 +1830,8 @@ new_module_tests = [
         cudnn=True,
         desc='dilated',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='ConvTranspose2d',
@@ -1807,6 +1842,8 @@ new_module_tests = [
         cudnn=True,
         desc='no_bias',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='ConvTranspose2d_groups',
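The with_tf32=True / tf32_precision=... keys added throughout this file mark module tests that should also run with TF32 enabled, comparing against a looser tolerance. TF32 keeps the FP32 exponent range but rounds matmul/convolution operands to 10 mantissa bits, so each product carries a relative error of roughly 2^-11 (about 5e-4) that accumulates over the reduction, which is why tolerances around 0.005 appear on the heavier convolutions. A rough measurement sketch (illustrative only; meaningful on Ampere or newer, and assuming the cuDNN switch is exposed as torch.backends.cudnn.allow_tf32, mirroring torch.backends.cuda.matmul.allow_tf32):

    # Illustrative sketch, not part of the diff: size of the TF32-vs-FP32
    # discrepancy these tolerances are meant to absorb.
    import torch
    import torch.nn.functional as F

    if torch.cuda.is_available():
        x = torch.randn(2, 4, 16, 16, device='cuda')
        w = torch.randn(5, 4, 3, 3, device='cuda')
        old = torch.backends.cudnn.allow_tf32
        try:
            torch.backends.cudnn.allow_tf32 = False
            ref = F.conv2d(x, w)
            torch.backends.cudnn.allow_tf32 = True
            tf32 = F.conv2d(x, w)    # TF32 kernels are only selected on Ampere+
        finally:
            torch.backends.cudnn.allow_tf32 = old
        print(((ref - tf32).abs() / ref.abs().clamp_min(1e-6)).max().item())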
@@ -1815,36 +1852,47 @@ new_module_tests = [
         input_size=(1, 2, 4, 5),
         cudnn=True,
         check_with_long_tensor=True,
+        with_tf32=True,
     ),
     dict(
         fullname='Conv2d_depthwise',
         constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
         input_size=(2, 4, 6, 6),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv2d_depthwise_with_multiplier',
         constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
         input_size=(2, 4, 6, 6),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv2d_depthwise_strided',
         constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
         input_size=(2, 4, 6, 6),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv2d_depthwise_padded',
         constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
         input_size=(2, 4, 6, 6),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv2d_depthwise_dilated',
         constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
         cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
         input_size=(2, 4, 5, 5),
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='MaxPool2d',
@@ -2067,6 +2115,8 @@ new_module_tests = [
         input_size=(1, 2, 4, 5, 4),
         cudnn=True,
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv3d',
@@ -2077,6 +2127,8 @@ new_module_tests = [
         cudnn=True,
         desc='no_bias',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.05,
     ),
     dict(
         module_name='Conv3d',
@@ -2086,6 +2138,8 @@ new_module_tests = [
         cudnn=True,
         desc='stride',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv3d',
@@ -2095,6 +2149,8 @@ new_module_tests = [
         cudnn=True,
         desc='stride_padding',
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.01,
     ),
     dict(
         module_name='Conv3d',
@@ -2105,6 +2161,7 @@ new_module_tests = [
         check_with_long_tensor=True,
         desc='zero_batch',
         test_cuda=(not TEST_WITH_ROCM),
+        with_tf32=True,
     ),
     dict(
         fullname='Conv3d_groups',
@@ -2113,18 +2170,22 @@ new_module_tests = [
         input_size=(1, 2, 4, 5, 4),
         cudnn=True,
         check_with_long_tensor=True,
+        with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv3d_dilated',
         constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
         cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
         input_size=(2, 3, 5, 5, 5),
+        with_tf32=True,
     ),
     dict(
         fullname='Conv3d_dilated_strided',
         constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
         cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
         input_size=(2, 3, 5, 5, 5),
+        with_tf32=True,
     ),
     dict(
         module_name='ConvTranspose3d',
@@ -2132,6 +2193,7 @@ new_module_tests = [
         cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
         cudnn=True,
         input_size=(1, 2, 4, 5, 4),
+        with_tf32=True,
     ),
     dict(
         module_name='ConvTranspose3d',
@@ -2141,6 +2203,7 @@ new_module_tests = [
         cudnn=True,
         input_size=(1, 2, 4, 5, 4),
         desc='dilated',
+        with_tf32=True,
     ),
     dict(
         module_name='MaxPool3d',
@@ -3458,6 +3521,8 @@ for padding_mode, cpp_padding_mode in zip(
             output_size=output_size,
             cudnn=True,
             desc='{}_stride2_pad2'.format(padding_mode),
+            with_tf32=True,
+            tf32_precision=0.05
         ),
     )
@@ -4876,6 +4941,8 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
         self.check_inplace = kwargs.get('check_inplace', False)
         self.check_gradgrad = kwargs.get('check_gradgrad', True)
         self.skip_double = kwargs.get('skip_double', False)
+        self.with_tf32 = kwargs.get('with_tf32', False)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
         self.test_cpu = kwargs.get('test_cpu', True)
 
     def _do_test(self, test_case, module, input):