Enable TF32 support for cuDNN (#40737)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40737
Reviewed By: mruberry
Differential Revision: D22801525
Pulled By: ngimel
fbshipit-source-id: ac7f7e728b4b3e01925337e8c9996f26a6433fd2

This commit is contained in:
parent 93fbbaab2a
commit 5e97f251a8
@@ -78,6 +78,14 @@ void Context::alertNotDeterministic(c10::string_view const& caller) {
}
}
bool Context::allowTF32CuDNN() const {
return allow_tf32_cudnn;
}
void Context::setAllowTF32CuDNN(bool b) {
allow_tf32_cudnn = b;
}
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };
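The getter/setter pair added here is what the user-facing switch ultimately toggles. A minimal usage sketch (the `torch.backends.cudnn.allow_tf32` flag is the one documented in the cuda.rst change further down; treat it as illustrative, not as part of this excerpt):

import torch

# Global switch added by this PR: controls whether cuDNN convolutions
# (and cuDNN RNNs) may use TF32 tensor cores for float32 data on Ampere GPUs.
print(torch.backends.cudnn.allow_tf32)    # True by default
torch.backends.cudnn.allow_tf32 = False   # force full-precision fp32 kernels in cuDNN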
@@ -115,6 +115,8 @@ class CAFFE2_API Context {
bool deterministic() const;
void setDeterministic(bool);
void alertNotDeterministic(c10::string_view const& caller);
bool allowTF32CuDNN() const;
void setAllowTF32CuDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
void alertCuBLASConfigNotDeterministic();

@@ -146,6 +148,7 @@ class CAFFE2_API Context {
bool deterministic_cudnn = false;
bool _deterministic = false;
bool benchmark_cudnn = false;
bool allow_tf32_cudnn = true;
bool allow_tf32_cublas = true;
bool enabled_mkldnn = true;
#ifdef C10_MOBILE
@@ -255,7 +255,8 @@ TORCH_LIBRARY_IMPL(_, Autocast, m) {
}

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), fp16)
KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16)

@@ -267,8 +268,10 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
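These registrations track the schema change (the extra trailing allow_tf32 bool) rather than changing autocast behaviour: under autocast, convolutions still run in fp16, so the TF32 switch only matters for fp32 inputs outside autocast. A small illustrative sketch:

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 32, 32, device="cuda")
w = torch.randn(16, 3, 3, 3, device="cuda")

with torch.cuda.amp.autocast():
    y = F.conv2d(x, w)   # dispatched through the fp16 autocast kernels registered above
print(y.dtype)           # torch.float16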
@@ -164,7 +164,7 @@ struct TORCH_CUDA_API ConvolutionDescriptor
&cudnnCreateConvolutionDescriptor,
&cudnnDestroyConvolutionDescriptor>
{
void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) {
cudnnDataType_t mathType = dataType;
if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,

@@ -172,9 +172,13 @@ struct TORCH_CUDA_API ConvolutionDescriptor
AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
// See Note [behavior of cudnnFind and cudnnGet]
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
if(dataType == CUDNN_DATA_HALF)
if(dataType == CUDNN_DATA_HALF) {
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
} else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
}
}
};

@@ -235,7 +239,7 @@ struct TORCH_CUDA_API RNNDescriptor
DropoutDescriptor dropout_desc_;
void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo) {
cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
dropout_desc_ = std::move(dropout_desc);
AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
handle,

@@ -252,7 +256,13 @@ struct TORCH_CUDA_API RNNDescriptor
if (prop->major >= 7) {
if (input_type == CUDNN_DATA_HALF) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
} else {
}
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
}
#endif
else {
// Technically, as the default it's not necessary to explicitly
// set this.
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
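Taken together, the descriptor changes implement a small decision table for the cuDNN math type. A Python paraphrase of that logic (not a PyTorch API, just a restatement of the C++ above):

def pick_math_type(data_type, allow_tf32, cudnn_version):
    # fp16 data always opts into tensor-core math
    if data_type == "HALF":
        return "CUDNN_TENSOR_OP_MATH"
    # fp32 with TF32 disallowed: force plain FMA kernels (cuDNN 8+ only)
    if data_type == "FLOAT" and not allow_tf32 and cudnn_version >= 8000:
        return "CUDNN_FMA_MATH"
    # otherwise keep the default, which lets cuDNN pick TF32 on Ampere
    return "CUDNN_DEFAULT_MATH"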
@@ -30,6 +30,7 @@ struct ConvParams {
bool benchmark;
bool deterministic;
bool cudnn_enabled;
bool allow_tf32;
bool is_strided() const;
bool is_dilated() const;

@@ -582,7 +583,7 @@ at::Tensor convolution(
auto& ctx = at::globalContext();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), ctx.deterministicCuDNN() || ctx.deterministic(), ctx.userEnabledCuDNN());
ctx.benchmarkCuDNN(), ctx.deterministicCuDNN() || ctx.deterministic(), ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
}
at::Tensor convolution_overrideable(

@@ -596,7 +597,7 @@ at::Tensor _convolution(
const Tensor& input_r, const Tensor& weight_r, const Tensor& bias_r,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled) {
bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
const bool input_is_mkldnn = input_r.is_mkldnn();
auto input = input_r;

@@ -618,6 +619,7 @@ at::Tensor _convolution(
params.benchmark = benchmark;
params.deterministic = deterministic;
params.cudnn_enabled = cudnn_enabled;
params.allow_tf32 = allow_tf32;
check_shape_forward(input, weight_sizes, bias, params);

@@ -664,7 +666,7 @@ at::Tensor _convolution(
if (params.use_cudnn_depthwise(input, weight)) {
output = at::cudnn_convolution(
input.contiguous(cudnn_memory_format), weight,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
padding, stride, dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}

@@ -687,14 +689,14 @@ at::Tensor _convolution(
if (params.transposed) {
output = at::cudnn_convolution_transpose(
input.contiguous(cudnn_memory_format), weight,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}
} else {
output = at::cudnn_convolution(
input.contiguous(cudnn_memory_format), weight,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
if (bias.defined()) {
output.add_(reshape_bias(input.dim(), bias));
}
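Nothing changes in the Python-level call signature of convolution ops; _convolution simply receives the flag from the global context and threads it down to the cuDNN calls above. A sketch of the user-visible effect (illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(64, 256, 56, 56, device="cuda")
w = torch.randn(256, 256, 3, 3, device="cuda")

torch.backends.cudnn.allow_tf32 = True    # cuDNN may select TF32 kernels on Ampere
y_tf32 = F.conv2d(x, w, padding=1)

torch.backends.cudnn.allow_tf32 = False   # restrict cuDNN to full-precision fp32 kernels
y_fp32 = F.conv2d(x, w, padding=1)

# Results can differ slightly: TF32 keeps fp32 dynamic range but rounds inputs to ~10 mantissa bits.
print((y_tf32 - y_fp32).abs().max())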
@@ -793,6 +795,15 @@ at::Tensor _convolution(
return output;
}
at::Tensor _convolution(
const Tensor& input_r, const Tensor& weight_r, const Tensor& bias_r,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled)
{
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
}
// A generic function for convolution implementations which don't
// natively implement groups (e.g., not CuDNN).
at::Tensor _convolution_nogroup(

@@ -886,7 +897,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
bool transposed_, IntArrayRef output_padding_, int64_t groups_,
bool benchmark, bool deterministic, bool cudnn_enabled,
bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32,
std::array<bool, 3> output_mask) {
auto ggW = ggW_r;

@@ -909,6 +920,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
params.benchmark = benchmark;
params.deterministic = deterministic;
params.cudnn_enabled = cudnn_enabled;
params.allow_tf32 = allow_tf32;
// Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
Tensor ggO;

@@ -917,14 +929,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
if (weight.is_cuda()) {
weight = weight.contiguous();
}
ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32);
}
if (ggW.defined()) {
if (ggW.is_cuda()) {
ggW = ggW.contiguous();
}
auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled, params.allow_tf32);
if (ggO.defined()) {
ggO = ggO + ggW_term;
} else {
@@ -979,9 +991,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt = at::_convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled);
gWt = at::_convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
} else {
gWt = at::_convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled);
gWt = at::_convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
}
} else {
std::vector<Tensor> gWt_list(groups);

@@ -995,9 +1007,9 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
// Compute conv
if (params.transposed) {
gw_conv_params.transposed = false;
gWt_list[g] = at::_convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled);
gWt_list[g] = at::_convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
} else {
gWt_list[g] = at::_convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled);
gWt_list[g] = at::_convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups, gw_conv_params.benchmark, gw_conv_params.deterministic, gw_conv_params.cudnn_enabled, params.allow_tf32);
}
}

@@ -1033,7 +1045,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
if (gO.is_cuda()) {
gO = gO.contiguous();
}
gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled);
gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32);
// narrow gI to only relevant portion
// we do it this way because negative output_padding is not supported

@@ -1087,7 +1099,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
gOt = gOt.contiguous();
}
gIt = at::_convolution(ggWt, gOt, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled);
gIt = at::_convolution(ggWt, gOt, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32);
} else {
std::vector<Tensor> gIt_list(params.groups);
for (int g = 0; g < groups; ++g) {

@@ -1097,7 +1109,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
gOt_g = gOt_g.contiguous();
}
gIt_list[g] = at::_convolution(ggWt_g, gOt_g, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled);
gIt_list[g] = at::_convolution(ggWt_g, gOt_g, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled, params.allow_tf32);
}
gIt = at::cat(gIt_list, 0);
@@ -17,56 +17,56 @@ namespace at { namespace native {
at::Tensor cudnn_convolution(
const at::Tensor& input, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_input(
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,2> output_mask) {
bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose(
const at::Tensor& input, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_input(
const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,2> output_mask) {
bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}
@@ -215,6 +215,7 @@ struct ConvolutionParams
int dilation[max_dim];
int64_t groups;
bool deterministic;
bool allow_tf32;
// NB: transposed purposely omitted: transposed just swaps
// forward and backward, so you can reuse the benchmark entry,
};

@@ -228,7 +229,7 @@ void setConvolutionParams(
ConvolutionParams* params,
const at::Tensor& input, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool deterministic) {
int64_t groups, bool deterministic, bool allow_tf32) {
cudnnDataType_t dataType = getCudnnDataType(input);
memset(params, 0, sizeof(ConvolutionParams));

@@ -250,6 +251,7 @@ void setConvolutionParams(
// CuDNN, but it doesn't seem worth the effort to actually do this.
params->groups = groups;
params->deterministic = deterministic;
params->allow_tf32 = allow_tf32;
}
// Convenience struct for passing around descriptors and data

@@ -658,6 +660,11 @@ public:
perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
} else {
perfResults[0].mathType = CUDNN_DEFAULT_MATH;
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
perfResults[0].mathType = CUDNN_FMA_MATH;
}
#endif
}
search::getWorkspaceSize(args, perfResults[0].algo, &(perfResults[0].memory));
return perfResults;
@@ -744,7 +751,7 @@ static inline void split_batch_dim_to_32bit_out(
const at::Tensor& input,
const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic,
bool benchmark, bool deterministic, bool allow_tf32,
int64_t max_worksize, func_t func_32bit) {
constexpr int64_t int_max = std::numeric_limits<int>::max();
const int64_t ni = input.numel();

@@ -752,7 +759,7 @@ static inline void split_batch_dim_to_32bit_out(
// Assume the shape of the tensor is (N, C, D1, D2, ...)
// if N * C * D1 * D2 * ... <= int_max, then no need to split at all
if (ni <= int_max && no <= int_max) {
func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
func_32bit(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return;
}
// else, if C * D1 * D2 * ... <= int_max, then we just need to split across the N dimension

@@ -770,7 +777,7 @@ static inline void split_batch_dim_to_32bit_out(
int64_t split_size_ = std::min<int64_t>(split_size, n - start);
Tensor input_ = input.narrow(0, start, split_size_);
Tensor output_ = output.narrow(0, start, split_size_);
func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic);
func_32bit(output_, input_, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
return;
}

@@ -789,6 +796,16 @@ static inline void split_batch_dim_to_32bit_out(
}

#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
#define ASSERT_CORRECT_PRECISION(math_type) \
if (args.params.dataType == CUDNN_DATA_FLOAT) { \
TORCH_INTERNAL_ASSERT(args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
}
#else
#define ASSERT_CORRECT_PRECISION(math_type)
#endif // CUDNN_VERSION >= 8000

// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
@@ -808,17 +825,17 @@ static inline void split_batch_dim_to_32bit_out(
void raw_cudnn_convolution_forward_out_32bit(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
auto dataType = getCudnnDataType(input);
ConvolutionArgs args{ input, output, weight };
args.handle = getCudnnHandle();
setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic);
setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
args.idesc.set(input);
args.wdesc.set(weight, 0, input.suggest_memory_format()==at::MemoryFormat::ChannelsLast);
args.odesc.set(output);
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);
// TODO: when we do legacy group convolution support, we'll repeatedly
// reinitialize the workspace for each convolution we do. This is

@@ -832,6 +849,7 @@ void raw_cudnn_convolution_forward_out_32bit(
// update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out
// whether to use Tensor core kernels or not
// See Note [behavior of cudnnFind and cudnnGet]
ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));
Constant one(dataType, 1);

@@ -850,15 +868,15 @@ void raw_cudnn_convolution_forward_out_32bit(
void raw_cudnn_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
bool benchmark, bool deterministic, bool allow_tf32) {
split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
}

Tensor cudnn_convolution_forward(
CheckedFrom c,
const TensorArg& input, const TensorArg& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
checkAllSameType(c, {input, weight});
checkAllSameGPU(c, {input, weight});

@@ -888,7 +906,7 @@ Tensor cudnn_convolution_forward(
raw_cudnn_convolution_forward_out(
*output, input_contig, weight_contig,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return *output;
}
@@ -896,13 +914,13 @@ Tensor cudnn_convolution_forward(
Tensor cudnn_convolution(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic)
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
CheckedFrom c = "cudnn_convolution";
auto output_t = cudnn_convolution_forward(
c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
c, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return output_t;
}

@@ -911,28 +929,28 @@ Tensor cudnn_convolution(
Tensor cudnn_convolution_transpose_backward_input(
const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic)
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
TensorArg grad_output { grad_output_t, "grad_output", 1 },
weight { weight_t, "weight", 2 };
return cudnn_convolution_forward(
"cudnn_convolution_transpose_backward_input",
grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,2> output_mask) {
bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
if (output_mask[1]) {
grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
return std::tuple<Tensor,Tensor>{grad_input, grad_weight};
@@ -949,16 +967,16 @@ void raw_cudnn_convolution_backward_input_out_32bit(
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
auto dataType = getCudnnDataType(grad_output);
ConvolutionArgs args{ grad_input, grad_output, weight };
args.handle = getCudnnHandle();
setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic);
setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
args.idesc.set(grad_input);
args.wdesc.set(weight, 0, grad_output.suggest_memory_format()==at::MemoryFormat::ChannelsLast);
args.odesc.set(grad_output);
args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);
AlgoIterator<cudnnConvolutionBwdDataAlgoPerf_t>(args, benchmark).try_all(
[&](const cudnnConvolutionBwdDataAlgoPerf_t &bwdDataAlgPerf){

@@ -967,6 +985,7 @@ void raw_cudnn_convolution_backward_input_out_32bit(
// update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out
// whether to use Tensor core kernels or not
// See Note [behavior of cudnnFind and cudnnGet]
ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType);
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType));
Constant one(dataType, 1);

@@ -987,8 +1006,8 @@ void raw_cudnn_convolution_backward_input_out(
const at::Tensor& grad_output,
const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit);
bool benchmark, bool deterministic, bool allow_tf32) {
split_batch_dim_to_32bit_out(grad_input, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 128, raw_cudnn_convolution_backward_input_out_32bit);
}

// NOTE [ Backward vs transpose convolutions ]

@@ -1007,7 +1026,7 @@ Tensor cudnn_convolution_backward_input(
CheckedFrom c,
IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});

@@ -1030,7 +1049,7 @@ Tensor cudnn_convolution_backward_input(
raw_cudnn_convolution_backward_input_out(
*grad_input, grad_output_contig, weight_contig,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return *grad_input;
}
@@ -1039,31 +1058,31 @@ Tensor cudnn_convolution_transpose_forward(
CheckedFrom c,
const TensorArg& grad_output, const TensorArg& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(),
padding, output_padding, stride, dilation, groups);
return cudnn_convolution_backward_input(c, input_size, grad_output, weight,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}

Tensor cudnn_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
return cudnn_convolution_backward_input(
"cudnn_convolution_backward_input",
input_size, grad_output, weight,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,2> output_mask) {
bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

@@ -1077,10 +1096,10 @@ std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
}
} else {
if (output_mask[0]) {
grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
if (output_mask[1]) {
grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}
}

@@ -1090,13 +1109,13 @@ std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
Tensor cudnn_convolution_transpose(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic)
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)
{
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
CheckedFrom c = "cudnn_convolution_transpose";
auto output_t = cudnn_convolution_transpose_forward(
c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return output_t;
}
@@ -1109,17 +1128,17 @@ Tensor cudnn_convolution_transpose(
void raw_cudnn_convolution_backward_weight_out_32bit(
const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
auto dataType = getCudnnDataType(input);
ConvolutionArgs args{ input, grad_output, grad_weight };
args.handle = getCudnnHandle();
setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic);
setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic, allow_tf32);
args.idesc.set(input);
args.wdesc.set(grad_weight, 0, input.suggest_memory_format()==at::MemoryFormat::ChannelsLast);
args.odesc.set(grad_output);
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, args.params.allow_tf32);
AlgoIterator<cudnnConvolutionBwdFilterAlgoPerf_t>(args, benchmark).try_all(
[&](const cudnnConvolutionBwdFilterAlgoPerf_t &bwdFilterAlgPerf){

@@ -1128,6 +1147,7 @@ void raw_cudnn_convolution_backward_weight_out_32bit(
// update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out
// whether to use Tensor core kernels or not
// See Note [behavior of cudnnFind and cudnnGet]
ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType);
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType));
Constant one(dataType, 1);

@@ -1146,14 +1166,14 @@ void raw_cudnn_convolution_backward_weight_out_32bit(
void raw_cudnn_convolution_backward_weight_out(
const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
bool benchmark, bool deterministic, bool allow_tf32) {
constexpr int64_t int_max = std::numeric_limits<int>::max();
const int64_t ni = input.numel();
const int64_t no = grad_output.numel();
// Assume the shape of the tensor is (N, C, D1, D2, ...)
// if N * C * D1 * D2 * ... <= int_max, then no need to split at all
if (ni <= int_max && no <= int_max) {
raw_cudnn_convolution_backward_weight_out_32bit(grad_weight, grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
raw_cudnn_convolution_backward_weight_out_32bit(grad_weight, grad_output, input, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return;
}
// else, if C * D1 * D2 * ... <= int_max, then we just need to split across the N dimension

@@ -1172,7 +1192,7 @@ void raw_cudnn_convolution_backward_weight_out(
Tensor input_ = input.narrow(0, start, split_size_);
Tensor grad_output_ = grad_output.narrow(0, start, split_size_);
Tensor grad_weight_ = at::empty_like(grad_weight);
raw_cudnn_convolution_backward_weight_out_32bit(grad_weight_, grad_output_, input_, padding, stride, dilation, groups, benchmark, deterministic);
raw_cudnn_convolution_backward_weight_out_32bit(grad_weight_, grad_output_, input_, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
grad_weight.add_(grad_weight_);
}
return;
@@ -1195,7 +1215,7 @@ Tensor cudnn_convolution_backward_weight(
CheckedFrom c,
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
auto layout = cudnn_conv_use_channels_last(input_t, grad_output_t) ?
at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

@@ -1221,7 +1241,7 @@ Tensor cudnn_convolution_backward_weight(
raw_cudnn_convolution_backward_weight_out(
*grad_weight, *grad_output_contig, *input,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
return grad_weight_t;
}

@@ -1231,12 +1251,12 @@ Tensor cudnn_convolution_backward_weight(
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
return cudnn_convolution_backward_weight(
"cudnn_convolution_backward_weight",
weight_size, grad_output_t, input_t,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}

Tensor cudnn_convolution_transpose_backward_weight(

@@ -1244,12 +1264,12 @@ Tensor cudnn_convolution_transpose_backward_weight(
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
bool benchmark, bool deterministic, bool allow_tf32)
{
return cudnn_convolution_backward_weight(
"cudnn_convolution_backward_weight",
weight_size, input_t, grad_output_t,
padding, stride, dilation, groups, benchmark, deterministic);
padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
}

}} // namespace at::native
@@ -1271,6 +1291,15 @@ Tensor cudnn_convolution_deprecated(
return output;
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_deprecated2(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic)
{
return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN());
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_transpose_deprecated(
const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */,

@@ -1284,4 +1313,13 @@ Tensor cudnn_convolution_transpose_deprecated(
return output;
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_transpose_deprecated2(
const Tensor& input_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic)
{
return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN());
}

}} // namespace at::native
@@ -147,7 +147,7 @@ namespace {
RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
RNNDescriptor rnn_desc;
rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo);
rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN());
return rnn_desc;
}
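Since the RNN descriptor now also reads the global context flag, cuDNN-backed LSTM/GRU layers follow the same switch as convolutions (the cuda.rst note further down lists GRU and LSTM among the affected modules). An illustrative sketch:

import torch

lstm = torch.nn.LSTM(512, 512, num_layers=2).cuda()
x = torch.randn(100, 32, 512, device="cuda")

torch.backends.cudnn.allow_tf32 = False   # cuDNN RNN matmuls stay in plain fp32
out_fp32, _ = lstm(x)

torch.backends.cudnn.allow_tf32 = True    # RNN matmuls may use TF32 on Ampere
out_tf32, _ = lstm(x)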
@@ -928,13 +928,16 @@
- func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
use_c10_dispatcher: full

- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
use_c10_dispatcher: full

- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
use_c10_dispatcher: full

- func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding) -> Tensor
use_c10_dispatcher: full

- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
use_c10_dispatcher: full

- func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor
@@ -1029,22 +1032,27 @@
dispatch:
CUDA: cudnn_convolution_deprecated

- func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution.deprecated2(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_deprecated2

- func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution

- func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_backward_input

- func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor)
- func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_backward

- func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_backward_weight
@@ -1054,24 +1062,29 @@
dispatch:
CUDA: cudnn_convolution_transpose_deprecated

- func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution_transpose.deprecated2(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_transpose_deprecated2

- func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_transpose

# NB: output_padding not strictly needed here, but it's helpful for the float
# backwards
- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor)
- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_transpose_backward

- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_transpose_backward_input

- func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
- func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
use_c10_dispatcher: full
dispatch:
CUDA: cudnn_convolution_transpose_backward_weight
@@ -140,7 +140,7 @@ constexpr uint64_t kProducedFileFormatVersion = 0x3L;
// should be increased too. The relationship is:
// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion
// >= kProducedFileFormatVersion
constexpr uint64_t kProducedBytecodeVersion = 0x3L;
constexpr uint64_t kProducedBytecodeVersion = 0x4L;

static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
"kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");
@@ -68,14 +68,19 @@ TF32 tensor cores are designed to achieve better performance on matmul and convo
`torch.float32` tensors by truncating input data to have 10 bits of mantissa, and accumulating
results with FP32 precision, maintaining FP32 dynamic range.

matmul and convolutions are controlled separately, and their corresponding flag can be accessed at:
matmuls and convolutions are controlled separately, and their corresponding flags can be accessed at:

.. code:: python

    # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
    torch.backends.cuda.matmul.allow_tf32 = True

    # The allow_tf32 flag for convolutions is not implemented yet
    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    torch.backends.cudnn.allow_tf32 = True

Note that besides matmuls and convolutions themselves, functions and nn modules that internally uses
matmuls or convolutions are also affected. These include `nn.Linear`, `nn.Conv*`, cdist, tensordot,
affine grid and grid sample, adaptive log softmax, GRU and LSTM.

To get an idea of the precision and speed, see the example code below:

@@ -107,7 +112,7 @@ is needed, users can disable TF32 by:

.. code:: python

    torch.backends.cuda.matmul.allow_tf32 = False
    # disabling of TF32 for cuDNN is not implemented yet
    torch.backends.cudnn.allow_tf32 = False

For more information about TF32, see:
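The precision/speed example referenced in the note above is not reproduced in this hunk. As a rough illustration of what it describes, the sketch below compares a TF32 matmul against strict FP32 on the same inputs. It is not part of the patch: it assumes an Ampere-class GPU (on older hardware the flags are accepted but change nothing), and the tensor sizes and the expectation about error magnitudes are illustrative only.

.. code:: python

    import torch

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    reference = (a.double() @ b.double()).float()  # FP64 reference result

    # TF32 enabled (the default): faster matmul/conv, reduced mantissa precision
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    tf32_out = a @ b

    # TF32 disabled: full FP32 precision for matmuls and cuDNN convolutions
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    fp32_out = a @ b

    # On TF32-capable hardware the first error is typically noticeably larger
    print((tf32_out - reference).abs().max())
    print((fp32_out - reference).abs().max())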
@@ -62,6 +62,16 @@ allow_list = [
    ("aten::atan2", datetime.date(2020, 7, 30)),
    ("aten::copy_", datetime.date(2020, 7, 30)),
    ("aten::sort", datetime.date(2020, 7, 30)),
    ('aten::_convolution', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_transpose', datetime.date(2020, 10, 15)),
    ('aten::_convolution_double_backward', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_backward_input', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_backward', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_backward_weight', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_transpose_backward', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_transpose_backward_input', datetime.date(2020, 10, 15)),
    ('aten::cudnn_convolution_transpose_backward_weight', datetime.date(2020, 10, 15)),
    ("aten::_cudnn_init_dropout_state", datetime.date(2020, 7, 30)),
    ("aten::sparse_coo_tensor", datetime.date(2020, 7, 30)),
    ("aten::_sparse_coo_tensor_with_dims", datetime.date(2020, 7, 30)),
@@ -83,7 +83,7 @@ void testLiteInterpreterConv() {
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True)
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  inputs.push_back(torch::ones({1, 1, 28, 28}));
@@ -528,13 +528,19 @@ class TestCuda(TestCase):
        q_copy[1].fill_(10)
        self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

    def test_allow_tf32_get_set(self):
    def test_cublas_allow_tf32_get_set(self):
        orig = torch.backends.cuda.matmul.allow_tf32
        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
        torch.backends.cuda.matmul.allow_tf32 = not orig
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
        torch.backends.cuda.matmul.allow_tf32 = orig

    def test_cudnn_allow_tf32_get_set(self):
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            self.assertFalse(torch.backends.cudnn.allow_tf32)
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            self.assertTrue(torch.backends.cudnn.allow_tf32)

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
test/test_nn.py
@@ -51,6 +51,7 @@ from hypothesis import given
|
|||
import torch.testing._internal.hypothesis_utils as hu
|
||||
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
|
||||
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
|
||||
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
|
|
@@ -63,9 +64,6 @@ if TEST_SCIPY:
|
|||
if TEST_NUMPY:
|
||||
import numpy as np
|
||||
|
||||
NO_HALF_TENSORTYPES = [torch.float,
|
||||
torch.double]
|
||||
|
||||
DOUBLE_TENSORTYPES = [torch.double]
|
||||
|
||||
|
||||
|
|
@@ -5194,73 +5192,6 @@ class TestNN(NNTestCase):
|
|||
output_cpu = rnn(input.cpu(), hx)
|
||||
self.assertEqual(output_cuda, output_cpu)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
|
||||
@repeat_test_for_types(NO_HALF_TENSORTYPES)
|
||||
def test_cuda_rnn_fused(self, dtype=torch.float):
|
||||
|
||||
def copy_rnn(rnn1, rnn2):
|
||||
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
||||
for x, y in zip(x_layer, y_layer):
|
||||
x.data.copy_(y.data)
|
||||
|
||||
def check_rnn_grads(rnn1, rnn2):
|
||||
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
||||
for x, y in zip(x_layer, y_layer):
|
||||
self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0)
|
||||
|
||||
input_size = 10
|
||||
hidden_size = 6
|
||||
num_layers = 2
|
||||
seq_length = 7
|
||||
batch = 6
|
||||
input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
|
||||
grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype)
|
||||
hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
|
||||
grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
|
||||
with torch.backends.cudnn.flags(enabled=False):
|
||||
for module in (nn.GRU, nn.LSTM):
|
||||
for bias in (True, False):
|
||||
rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype)
|
||||
rnn_cuda = module(input_size, hidden_size, num_layers, bias=bias).to("cuda", dtype)
|
||||
copy_rnn(rnn, rnn_cuda)
|
||||
|
||||
is_lstm = isinstance(rnn, nn.LSTM)
|
||||
if is_lstm:
|
||||
hx = (hx_val.clone().requires_grad_(True),
|
||||
hx_val.clone().add(1).requires_grad_(True))
|
||||
hx_cuda = (hx_val.clone().cuda().requires_grad_(True),
|
||||
hx_val.clone().cuda().add(1).requires_grad_(True))
|
||||
else:
|
||||
hx = hx_val.clone().requires_grad_(True)
|
||||
hx_cuda = hx_val.clone().cuda().requires_grad_(True)
|
||||
|
||||
inp = input_val.clone().requires_grad_(True)
|
||||
inp_cu = input_val.clone().cuda().requires_grad_(True)
|
||||
output1, hy1 = rnn(inp, hx)
|
||||
output2, hy2 = rnn_cuda(inp_cu, hx_cuda)
|
||||
if is_lstm:
|
||||
torch.autograd.backward(
|
||||
[output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1]
|
||||
)
|
||||
torch.autograd.backward(
|
||||
[output2, hy2[0], hy2[1]],
|
||||
[grad_output.cuda(), grad_hy.cuda(), (grad_hy + 1).cuda()]
|
||||
)
|
||||
else:
|
||||
torch.autograd.backward([output1, hy1], [grad_output, grad_hy])
|
||||
torch.autograd.backward([output2, hy2], [grad_output.cuda(), grad_hy.cuda()])
|
||||
|
||||
self.assertEqual(output1, output2)
|
||||
self.assertEqual(hy1, hy2)
|
||||
|
||||
check_rnn_grads(rnn, rnn_cuda)
|
||||
self.assertEqual(inp.grad.data, inp_cu.grad.data)
|
||||
if is_lstm:
|
||||
self.assertEqual(hx[0].grad.data, hx_cuda[0].grad.data)
|
||||
self.assertEqual(hx[1].grad.data, hx_cuda[1].grad.data)
|
||||
else:
|
||||
self.assertEqual(hx.grad.data, hx_cuda.grad.data)
|
||||
|
||||
def test_transformer_args_check(self):
|
||||
model_name = 'Transformer'
|
||||
d_model = 128
|
||||
|
|
@@ -7199,236 +7130,6 @@ class TestNN(NNTestCase):
|
|||
self.assertEqual(out_cpu, out_cuda)
|
||||
self.assertEqual(input_cpu.grad, input_gpu.grad)
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
def test_affine_2d_rotate0(self):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for device in device_():
|
||||
input_size = [1, 1, 3, 3]
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = [1, 1, 5, 5]
|
||||
angle_rad = 0.
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False)
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu').numpy()
|
||||
|
||||
assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
|
||||
assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
def test_affine_2d_rotate90(self):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for device, input_size2dsq, output_size2dsq in \
|
||||
itertools.product(device_(), input_size2dsq_(), output_size2dsq_()):
|
||||
input_size = input_size2dsq
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = output_size2dsq
|
||||
angle_rad = 0.25 * math.pi * 2
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=True)
|
||||
|
||||
if input_size2dsq == output_size2dsq:
|
||||
assert np.abs(scipy_ary.mean() - input_ary.mean()) < 1e-6
|
||||
assert np.abs(scipy_ary[0, 0] - input_ary[0, 0, 0, -1]).max() < 1e-6
|
||||
assert np.abs(scipy_ary[0, -1] - input_ary[0, 0, -1, -1]).max() < 1e-6
|
||||
assert np.abs(scipy_ary[-1, -1] - input_ary[0, 0, -1, 0]).max() < 1e-6
|
||||
assert np.abs(scipy_ary[-1, 0] - input_ary[0, 0, 0, 0]).max() < 1e-6
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu').numpy()
|
||||
|
||||
assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
|
||||
assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
def test_affine_2d_rotate45(self):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for device in device_():
|
||||
input_size = [1, 1, 3, 3]
|
||||
input_ary = np.array(np.zeros(input_size), dtype=np.float32)
|
||||
input_ary[0, 0, 0, :] = 0.5
|
||||
input_ary[0, 0, 2, 2] = 1.0
|
||||
output_size = [1, 1, 3, 3]
|
||||
angle_rad = 0.125 * math.pi * 2
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False)
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu').numpy()
|
||||
|
||||
assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
def test_affine_2d_rotateRandom(self):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for device, angle_rad, input_size2d, output_size2d in \
|
||||
itertools.product(device_(), angle_rad_(), input_size2d_(), output_size2d_()):
|
||||
|
||||
input_size = input_size2d
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
|
||||
output_size = output_size2d
|
||||
|
||||
input_ary[0, 0, 0, 0] = 2
|
||||
input_ary[0, 0, 0, -1] = 4
|
||||
input_ary[0, 0, -1, 0] = 6
|
||||
input_ary[0, 0, -1, -1] = 8
|
||||
|
||||
transform_tensor, transform_ary, grid_ary = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False)
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu').numpy()
|
||||
|
||||
affine_tensor = affine_tensor.to('cpu')
|
||||
|
||||
for r in range(affine_tensor.size(1)):
|
||||
for c in range(affine_tensor.size(2)):
|
||||
grid_out = np.dot(grid_ary, [r, c, 1])
|
||||
self.assertEqual(affine_tensor[0, r, c], grid_out[:2])
|
||||
|
||||
assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
def test_affine_3d_rotateRandom(self):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for device, angle_rad, axis_vector, input_size3d, output_size3d in \
|
||||
itertools.product(device_(), angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()):
|
||||
input_size = input_size3d
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = output_size3d
|
||||
|
||||
input_ary[0, 0, 0, 0, 0] = 2
|
||||
input_ary[0, 0, 0, 0, -1] = 3
|
||||
input_ary[0, 0, 0, -1, 0] = 4
|
||||
input_ary[0, 0, 0, -1, -1] = 5
|
||||
input_ary[0, 0, -1, 0, 0] = 6
|
||||
input_ary[0, 0, -1, 0, -1] = 7
|
||||
input_ary[0, 0, -1, -1, 0] = 8
|
||||
input_ary[0, 0, -1, -1, -1] = 9
|
||||
|
||||
transform_tensor, transform_ary, grid_ary = \
|
||||
_buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
|
||||
|
||||
scipy_ary = scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False)
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu').numpy()
|
||||
|
||||
affine_tensor = affine_tensor.to('cpu')
|
||||
|
||||
for i in range(affine_tensor.size(1)):
|
||||
for r in range(affine_tensor.size(2)):
|
||||
for c in range(affine_tensor.size(3)):
|
||||
grid_out = np.dot(grid_ary, [i, r, c, 1])
|
||||
self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3])
|
||||
|
||||
assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
|
||||
|
||||
def test_channel_shuffle(self):
|
||||
# 3D tensor
|
||||
x = torch.tensor(
|
||||
|
|
@@ -8915,6 +8616,20 @@ def add_test(test, decorator=None):
|
|||
kwargs['extra_args'] = test.extra_args
|
||||
|
||||
if 'dtype' in get_function_arglist(test.test_cuda):
|
||||
if tf32_is_not_fp32() and test.with_tf32:
|
||||
|
||||
def with_tf32_off(self, test=test, kwargs=kwargs):
|
||||
with tf32_off():
|
||||
test.test_cuda(self, dtype=torch.float, **kwargs)
|
||||
|
||||
add(cuda_test_name + '_fp32', with_tf32_off)
|
||||
|
||||
def with_tf32_on(self, test=test, kwargs=kwargs):
|
||||
with tf32_on(self, test.tf32_precision):
|
||||
test.test_cuda(self, dtype=torch.float, **kwargs)
|
||||
|
||||
add(cuda_test_name + '_tf32', with_tf32_on)
|
||||
else:
|
||||
add(cuda_test_name + '_float', lambda self,
|
||||
test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs))
|
||||
add(cuda_test_name + '_double', lambda self,
|
||||
|
|
@@ -8930,6 +8645,20 @@ def add_test(test, decorator=None):
|
|||
if getattr(test, 'check_bfloat16', True):
|
||||
add(cuda_test_name + '_bfloat16', test_bfloat16)
|
||||
|
||||
else:
|
||||
if tf32_is_not_fp32() and test.with_tf32:
|
||||
|
||||
def with_tf32_off(self, test=test, kwargs=kwargs):
|
||||
with tf32_off():
|
||||
test.test_cuda(self, **kwargs)
|
||||
|
||||
add(cuda_test_name + '_fp32', with_tf32_off)
|
||||
|
||||
def with_tf32_on(self, test=test, kwargs=kwargs):
|
||||
with tf32_on(self, test.tf32_precision):
|
||||
test.test_cuda(self, **kwargs)
|
||||
|
||||
add(cuda_test_name + '_tf32', with_tf32_on)
|
||||
else:
|
||||
add(cuda_test_name, lambda self, test=test, kwargs=kwargs: test.test_cuda(self, **kwargs))
|
||||
|
||||
|
|
@@ -9071,7 +8800,9 @@ class _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss):
|
|||
add_test(NewModuleTest(
|
||||
constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]),
|
||||
input_size=(4, 16),
|
||||
fullname='AdaptiveLogSoftmax'))
|
||||
fullname='AdaptiveLogSoftmax',
|
||||
with_tf32=True,
|
||||
tf32_precision=0.005))
|
||||
|
||||
|
||||
# The following are helpers for TestNN.test_affine_*
|
||||
|
|
@@ -9482,6 +9213,239 @@ class TestNNDeviceType(NNTestCase):
|
|||
self.assertEqual(p.grad, torch.zeros_like(p.grad))
|
||||
self.assertEqual(inp.grad, torch.zeros_like(inp))
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
@tf32_on_and_off()
|
||||
def test_affine_2d_rotate0(self, device):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
input_size = [1, 1, 3, 3]
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = [1, 1, 5, 5]
|
||||
angle_rad = 0.
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False))
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu')
|
||||
|
||||
self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
|
||||
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
@tf32_on_and_off(0.001)
|
||||
def test_affine_2d_rotate90(self, device):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for input_size2dsq, output_size2dsq in \
|
||||
itertools.product(input_size2dsq_(), output_size2dsq_()):
|
||||
input_size = input_size2dsq
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = output_size2dsq
|
||||
angle_rad = 0.25 * math.pi * 2
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=True))
|
||||
|
||||
if input_size2dsq == output_size2dsq:
|
||||
self.assertEqual(scipy_ary.mean(), input_ary.mean())
|
||||
self.assertEqual(scipy_ary[0, 0], input_ary[0, 0, 0, -1])
|
||||
self.assertEqual(scipy_ary[0, -1], input_ary[0, 0, -1, -1])
|
||||
self.assertEqual(scipy_ary[-1, -1], input_ary[0, 0, -1, 0])
|
||||
self.assertEqual(scipy_ary[-1, 0], input_ary[0, 0, 0, 0])
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu')
|
||||
|
||||
self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
|
||||
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
@tf32_on_and_off(0.005)
|
||||
def test_affine_2d_rotate45(self, device):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
input_size = [1, 1, 3, 3]
|
||||
input_ary = np.array(np.zeros(input_size), dtype=np.float32)
|
||||
input_ary[0, 0, 0, :] = 0.5
|
||||
input_ary[0, 0, 2, 2] = 1.0
|
||||
output_size = [1, 1, 3, 3]
|
||||
angle_rad = 0.125 * math.pi * 2
|
||||
|
||||
transform_tensor, transform_ary, offset = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
offset=offset,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False))
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu')
|
||||
|
||||
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
@tf32_on_and_off(0.005)
|
||||
def test_affine_2d_rotateRandom(self, device):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for angle_rad, input_size2d, output_size2d in \
|
||||
itertools.product(angle_rad_(), input_size2d_(), output_size2d_()):
|
||||
|
||||
input_size = input_size2d
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
|
||||
output_size = output_size2d
|
||||
|
||||
input_ary[0, 0, 0, 0] = 2
|
||||
input_ary[0, 0, 0, -1] = 4
|
||||
input_ary[0, 0, -1, 0] = 6
|
||||
input_ary[0, 0, -1, -1] = 8
|
||||
|
||||
transform_tensor, transform_ary, grid_ary = \
|
||||
_buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
|
||||
|
||||
scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False))
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu')
|
||||
|
||||
affine_tensor = affine_tensor.to('cpu')
|
||||
|
||||
for r in range(affine_tensor.size(1)):
|
||||
for c in range(affine_tensor.size(2)):
|
||||
grid_out = np.dot(grid_ary, [r, c, 1])
|
||||
self.assertEqual(affine_tensor[0, r, c], grid_out[:2])
|
||||
|
||||
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
|
||||
|
||||
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
|
||||
"Scipy v1.0 and/or numpy not found")
|
||||
@tf32_on_and_off(0.005)
|
||||
def test_affine_3d_rotateRandom(self, device):
|
||||
# scipy before 1.0.0 do not support homogeneous coordinate
|
||||
# scipy.ndimage.affine_transform, so we need to skip.
|
||||
for angle_rad, axis_vector, input_size3d, output_size3d in \
|
||||
itertools.product(angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()):
|
||||
input_size = input_size3d
|
||||
input_ary = np.array(np.random.random(input_size), dtype=np.float32)
|
||||
output_size = output_size3d
|
||||
|
||||
input_ary[0, 0, 0, 0, 0] = 2
|
||||
input_ary[0, 0, 0, 0, -1] = 3
|
||||
input_ary[0, 0, 0, -1, 0] = 4
|
||||
input_ary[0, 0, 0, -1, -1] = 5
|
||||
input_ary[0, 0, -1, 0, 0] = 6
|
||||
input_ary[0, 0, -1, 0, -1] = 7
|
||||
input_ary[0, 0, -1, -1, 0] = 8
|
||||
input_ary[0, 0, -1, -1, -1] = 9
|
||||
|
||||
transform_tensor, transform_ary, grid_ary = \
|
||||
_buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
|
||||
|
||||
scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
|
||||
input_ary[0, 0],
|
||||
transform_ary,
|
||||
output_shape=output_size[2:],
|
||||
order=1,
|
||||
mode='nearest',
|
||||
prefilter=False))
|
||||
|
||||
affine_tensor = torch.nn.functional.affine_grid(
|
||||
transform_tensor,
|
||||
torch.Size(output_size),
|
||||
align_corners=True
|
||||
)
|
||||
|
||||
gridsample_ary = torch.nn.functional.grid_sample(
|
||||
torch.tensor(input_ary, device=device).to(device),
|
||||
affine_tensor,
|
||||
padding_mode='border',
|
||||
align_corners=True
|
||||
).to('cpu')
|
||||
|
||||
affine_tensor = affine_tensor.to('cpu')
|
||||
|
||||
for i in range(affine_tensor.size(1)):
|
||||
for r in range(affine_tensor.size(2)):
|
||||
for c in range(affine_tensor.size(3)):
|
||||
grid_out = np.dot(grid_ary, [i, r, c, 1])
|
||||
self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3])
|
||||
|
||||
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
|
||||
|
||||
def test_Dropout(self, device):
|
||||
input = torch.Tensor(1000)
|
||||
self._test_dropout(nn.Dropout, device, input)
|
||||
|
|
@@ -9594,6 +9558,74 @@ class TestNNDeviceType(NNTestCase):
|
|||
inp = torch.randn(3, 0, 10, 10, device=device)
|
||||
mod(inp)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float, torch.double)
|
||||
@tf32_on_and_off(0.005)
|
||||
def test_rnn_fused(self, device, dtype):
|
||||
|
||||
def copy_rnn(rnn1, rnn2):
|
||||
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
||||
for x, y in zip(x_layer, y_layer):
|
||||
x.data.copy_(y.data)
|
||||
|
||||
def check_rnn_grads(rnn1, rnn2):
|
||||
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
||||
for x, y in zip(x_layer, y_layer):
|
||||
self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0)
|
||||
|
||||
input_size = 10
|
||||
hidden_size = 6
|
||||
num_layers = 2
|
||||
seq_length = 7
|
||||
batch = 6
|
||||
input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
|
||||
grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype)
|
||||
hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
|
||||
grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
|
||||
with torch.backends.cudnn.flags(enabled=False, allow_tf32=None):
|
||||
for module in (nn.GRU, nn.LSTM):
|
||||
for bias in (True, False):
|
||||
rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype)
|
||||
rnn_device = module(input_size, hidden_size, num_layers, bias=bias).to(device, dtype)
|
||||
copy_rnn(rnn, rnn_device)
|
||||
|
||||
is_lstm = isinstance(rnn, nn.LSTM)
|
||||
if is_lstm:
|
||||
hx = (hx_val.clone().requires_grad_(True),
|
||||
hx_val.clone().add(1).requires_grad_(True))
|
||||
hx_device = (hx_val.clone().to(device).requires_grad_(True),
|
||||
hx_val.clone().to(device).add(1).requires_grad_(True))
|
||||
else:
|
||||
hx = hx_val.clone().requires_grad_(True)
|
||||
hx_device = hx_val.clone().to(device).requires_grad_(True)
|
||||
|
||||
inp = input_val.clone().requires_grad_(True)
|
||||
inp_cu = input_val.clone().to(device).requires_grad_(True)
|
||||
output1, hy1 = rnn(inp, hx)
|
||||
output2, hy2 = rnn_device(inp_cu, hx_device)
|
||||
if is_lstm:
|
||||
torch.autograd.backward(
|
||||
[output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1]
|
||||
)
|
||||
torch.autograd.backward(
|
||||
[output2, hy2[0], hy2[1]],
|
||||
[grad_output.to(device), grad_hy.to(device), (grad_hy + 1).to(device)]
|
||||
)
|
||||
else:
|
||||
torch.autograd.backward([output1, hy1], [grad_output, grad_hy])
|
||||
torch.autograd.backward([output2, hy2], [grad_output.to(device), grad_hy.to(device)])
|
||||
|
||||
self.assertEqual(output1, output2)
|
||||
self.assertEqual(hy1, hy2)
|
||||
|
||||
check_rnn_grads(rnn, rnn_device)
|
||||
self.assertEqual(inp.grad, inp_cu.grad)
|
||||
if is_lstm:
|
||||
self.assertEqual(hx[0].grad, hx_device[0].grad)
|
||||
self.assertEqual(hx[1].grad, hx_device[1].grad)
|
||||
else:
|
||||
self.assertEqual(hx.grad, hx_device.grad)
|
||||
|
||||
def test_BatchNorm_empty(self, device):
|
||||
mod = torch.nn.BatchNorm2d(3).to(device)
|
||||
inp = torch.randn(0, 3, 2, 2, device=device)
|
||||
|
|
@@ -10209,6 +10241,7 @@ class TestNNDeviceType(NNTestCase):
|
|||
self.assertEqual(out1, out2)
|
||||
|
||||
@onlyCUDA
|
||||
@tf32_on_and_off(0.005)
|
||||
def test_grid_sample_large(self, device):
|
||||
def issue_35202():
|
||||
input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True)
|
||||
|
|
@@ -10235,7 +10268,7 @@ class TestNNDeviceType(NNTestCase):
|
|||
result.backward(torch.ones_like(result))
|
||||
expected_grad = torch.ones_like(image)
|
||||
expected_grad[0, 0, 1, 1, 1] = 0
|
||||
self.assertTrue(torch.allclose(image.grad, expected_grad, atol=1e-3))
|
||||
self.assertEqual(image.grad, expected_grad, atol=0.005, rtol=0)
|
||||
issue_24823_1(torch.half)
|
||||
issue_24823_1(torch.float)
|
||||
issue_24823_1(torch.double)
|
||||
|
|
|
|||
|
|
@@ -1405,50 +1405,50 @@
|
|||
input, weight, bias: "grad.defined() ? convolution_backward_overrideable(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, false, output_padding, groups, false, false, false, grad_input_mask)
|
||||
grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, false, output_padding, groups, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? slow_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, empty_like(grad, at::MemoryFormat::Contiguous), empty_like(grad, at::MemoryFormat::Contiguous), grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: slow_conv_transpose2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? slow_conv_transpose3d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, empty_like(grad, at::MemoryFormat::Preserve), empty_like(grad, at::MemoryFormat::Preserve), grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: slow_conv_transpose3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input)
|
||||
self, weight, bias: "grad.defined() ? thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor
|
||||
self, weight: "grad.defined() ? thnn_conv_depthwise2d_backward(grad.contiguous(), self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple<Tensor, Tensor>()"
|
||||
bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1)
|
||||
|
||||
- name: thnn_conv_depthwise2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[2] output_mask) -> (Tensor grad_input, Tensor grad_weight)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input)
|
||||
self, weight, bias: "grad.defined() ? slow_conv3d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: slow_conv3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? slow_conv_dilated2d_backward(grad, self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: slow_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? slow_conv_dilated3d_backward(grad, self, weight, kernel_size, stride, padding, dilation, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: slow_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, false, grad_input_mask)
|
||||
|
||||
- name: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
|
||||
self: col2im_backward(grad, kernel_size, dilation, padding, stride)
|
||||
|
|
@@ -1661,17 +1661,17 @@
|
|||
- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
|
||||
log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
|
||||
|
||||
- name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
|
||||
self, weight: "grad.defined() ? cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor>()"
|
||||
- name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
|
||||
self, weight: "grad.defined() ? cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, grad_input_mask) : std::tuple<Tensor, Tensor>()"
|
||||
|
||||
- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)
|
||||
- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, allow_tf32, grad_input_mask)
|
||||
|
||||
- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
|
||||
self, weight: "grad.defined() ? cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor>()"
|
||||
- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
|
||||
self, weight: "grad.defined() ? cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, grad_input_mask) : std::tuple<Tensor, Tensor>()"
|
||||
|
||||
- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
|
||||
- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, allow_tf32, grad_input_mask)
|
||||
|
||||
# The above backward definitions are equivalent to the definitions below. Why do we bundle
|
||||
# everything up? It's because it's more convenient to define double backwards
|
||||
|
|
@@ -1734,19 +1734,19 @@
|
|||
self, weight, bias: "grad.defined() ? miopen_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, false, grad_input_mask)
|
||||
|
||||
- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? miopen_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, false, grad_input_mask)
|
||||
|
||||
- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
|
||||
self, weight, bias: "grad.defined() ? miopen_depthwise_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: miopen_depthwise_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, false, grad_input_mask)
|
||||
|
||||
- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
|
||||
input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
|
@@ -1769,7 +1769,7 @@
|
|||
self, weight, bias: "grad.defined() ? mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
|
||||
|
||||
- name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, grad_input_mask)
|
||||
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, false, grad_input_mask)
|
||||
|
||||
# fft
|
||||
- name: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor
|
||||
|
|
|
|||
|
|
@@ -193,6 +193,8 @@ def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
def _get_deterministic() -> _bool: ... # THPModule_deterministic
def _set_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministic
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
@@ -83,26 +83,32 @@ def is_acceptable(tensor):
    return True


def set_flags(_enabled, _benchmark, _deterministic):
def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None):
    orig_flags = (torch._C._get_cudnn_enabled(),
                  torch._C._get_cudnn_benchmark(),
                  torch._C._get_cudnn_deterministic())
                  torch._C._get_cudnn_deterministic(),
                  torch._C._get_cudnn_allow_tf32())
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    return orig_flags


@contextmanager
def flags(enabled=False, benchmark=False, deterministic=False):
def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, benchmark, deterministic)
        orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(orig_flags[0], orig_flags[1], orig_flags[2])
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:

@@ -116,6 +122,7 @@ class CudnnModule(PropModule):
    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
    benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
    allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)

# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
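A short usage sketch of the updated context manager (not part of the patch): passing None leaves a flag untouched, so TF32 can be scoped off for a single region and the previous settings are restored on exit. The module and input shapes below are made up for illustration.

.. code:: python

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3).cuda()
    x = torch.randn(1, 3, 32, 32, device="cuda")

    # Only allow_tf32 is overridden here; the other flags keep their current values.
    with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
        strict_fp32_out = conv(x)  # cuDNN convolutions run without TF32 inside this block

    # On exit the previous value is restored (True unless it was changed elsewhere).
    print(torch.backends.cudnn.allow_tf32)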
@@ -388,6 +388,20 @@ PyObject *THPModule_fromDLPack(PyObject *_unused, PyObject *data)
  END_HANDLE_TH_ERRORS
}

PyObject *THPModule_setAllowTF32CuDNN(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_allow_tf32_cublas expects a bool, "
                  "but got %s", THPUtils_typename(arg));
  at::globalContext().setAllowTF32CuDNN(arg == Py_True);
  Py_RETURN_NONE;
}

PyObject *THPModule_allowTF32CuDNN(PyObject *_unused, PyObject *noargs)
{
  if (at::globalContext().allowTF32CuDNN()) Py_RETURN_TRUE;
  else Py_RETURN_FALSE;
}

PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(PyBool_Check(arg), "set_enabled_cudnn expects a bool, "

@@ -577,6 +591,8 @@ static PyMethodDef TorchMethods[] = {
  {"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr},
  {"_get_mkldnn_enabled", (PyCFunction)THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
  {"_set_mkldnn_enabled", (PyCFunction)THPModule_setUserEnabledMkldnn, METH_O, nullptr},
  {"_get_cudnn_allow_tf32", (PyCFunction)THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
  {"_set_cudnn_allow_tf32", (PyCFunction)THPModule_setAllowTF32CuDNN, METH_O, nullptr},
  {"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
  {"_set_cudnn_benchmark", (PyCFunction)THPModule_setBenchmarkCuDNN, METH_O, nullptr},
  {"_get_cudnn_deterministic", (PyCFunction)THPModule_deterministicCuDNN, METH_NOARGS, nullptr},
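As a quick sanity check of the bindings registered above (again, not part of the patch), the private getter/setter pair should stay in sync with the public `torch.backends.cudnn.allow_tf32` attribute, mirroring the new test in test_cuda.py:

.. code:: python

    import torch

    torch.backends.cudnn.allow_tf32 = False
    assert torch._C._get_cudnn_allow_tf32() is False

    torch._C._set_cudnn_allow_tf32(True)
    assert torch.backends.cudnn.allow_tf32 is True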
@@ -32,12 +32,21 @@ void replaceConvBiasWithGetAttr(Module& module) {
|
|||
// Thus assumes that tracing will have always gotten rid of aten::conv2d or
|
||||
// aten::conv3d. If it did not, BN folding will fail.
|
||||
const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
|
||||
%conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
|
||||
return (%conv_out) )");
|
||||
const PatternInfo& pattern_convolution_deprecated =
|
||||
PatternInfo::parse_from_str(R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
|
||||
return (%conv_out) )");
|
||||
auto replace_pattern = [&](const PatternInfo& pattern_convolution) {
|
||||
const Graph& pattern_convolution_graph = *pattern_convolution.pattern_graph;
|
||||
const auto& convolution_vmap = pattern_convolution.vmap;
|
||||
|
||||
|
|
@@ -55,6 +64,9 @@ void replaceConvBiasWithGetAttr(Module& module) {
|
|||
constexpr size_t conv_bias_index = 2;
|
||||
conv_node->replaceInput(conv_bias_index, bias_attr_val);
|
||||
}
|
||||
};
|
||||
replace_pattern(pattern_convolution);
|
||||
replace_pattern(pattern_convolution_deprecated);
|
||||
}
|
||||
|
||||
void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
|
||||
|
|
|
|||
|
|
@@ -63,18 +63,25 @@ std::unordered_map<std::string, c10::IValue> getConvParams(
|
|||
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
// TODO: remove constant prop in the pass
|
||||
ConstantPropagation(graph);
|
||||
std::string convolution = R"(
|
||||
std::string convolution_deprecated = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
|
||||
return (%r) )";
|
||||
std::string convolution = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
|
||||
%r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
|
||||
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
|
|
@ -88,14 +95,14 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
|||
std::string conv1d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
|
||||
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv3d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
|
||||
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
|
|
@ -167,9 +174,11 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
|||
|
||||
SubgraphRewriter rewriter_conv1d;
|
||||
rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
|
||||
rewriter_conv1d.RegisterRewritePattern(convolution_deprecated, conv1d);
|
||||
rewriter_conv1d.runOnGraph(graph, filter_conv1d);
|
||||
SubgraphRewriter rewriter_conv2d;
|
||||
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
|
||||
rewriter_conv2d.RegisterRewritePattern(convolution_deprecated, conv2d);
|
||||
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
|
||||
SubgraphRewriter rewriter_conv2d_transpose;
|
||||
rewriter_conv2d_transpose.RegisterRewritePattern(
|
||||
|
|
@ -177,6 +186,7 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
|||
rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose);
|
||||
SubgraphRewriter rewriter_conv3d;
|
||||
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
|
||||
rewriter_conv3d.RegisterRewritePattern(convolution_deprecated, conv3d);
|
||||
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1160,7 +1160,8 @@ class ShapePropagator {
        "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
        "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
        "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
        "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor",
        "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", // deprecated _convolution
        "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
        "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
        "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
        "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",

@ -1109,9 +1109,9 @@ def log_softmax(g, input, dim, dtype=None):
    return return_op


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i', 'i')
def _convolution(g, input, weight, bias, stride, padding, dilation,
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32):
    weight_size = weight.type().sizes()

    args = [input, weight]

@ -1145,32 +1145,32 @@ def _convolution(g, input, weight, bias, stride, padding, dilation,

@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i')
def conv1d(g, input, weight, bias, stride, padding, dilation, groups):
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i')
def conv2d(g, input, weight, bias, stride, padding, dilation, groups):
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i')
def conv3d(g, input, weight, bias, stride, padding, dilation, groups):
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, False, (), groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is')
def conv_transpose1d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is')
def conv_transpose2d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is')
def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None)
    return _convolution(g, input, weight, bias, stride, padding, dilation, True, output_padding, groups, None, None, None, None)


@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')

@ -61,8 +61,12 @@ class AutocastTestLists(object):

        # The remaining lists organize ops that autocast treats explicitly.
        self.torch_fp16 = [
            # deprecated _convolution
            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
                                                              (0, 0), 1, False, True, True)),
            # the current _convolution
            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
                                                              (0, 0), 1, False, True, True, True)),
            ("_convolution_nogroup", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0))),
            ("conv1d", conv_args_fp32[0]),
            ("conv2d", conv_args_fp32[1]),

@ -72,13 +76,17 @@ class AutocastTestLists(object):
            ("conv_transpose2d", conv_args_fp32[1], TEST_WITH_ROCM),
            ("conv_transpose3d", conv_args_fp32[2], TEST_WITH_ROCM),
            ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
            # cudnn_convolutions with bias
            # deprecated cudnn_convolutions with bias
            ("cudnn_convolution", conv_args_fp32[1] + bias_fp32 + ((0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM),
            ("cudnn_convolution_transpose", conv_args_fp32[1] + bias_fp32 + ((0, 0), (0, 0), (1, 1),
                                                                             (1, 1), 1, False, True), TEST_WITH_ROCM),
            # cudnn_convolutions with no bias
            # deprecated cudnn_convolutions with no allow_tf32 flag
            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM),
            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1), (1, 1), 1, False, True), TEST_WITH_ROCM),
            # the current cudnn_convolutions
            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
                                                                 (1, 1), 1, False, True, True), TEST_WITH_ROCM),
            ("prelu", pointwise0_fp32 + element0_fp32),
            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
            ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),

@ -5,6 +5,7 @@ import torch
import torch.cuda
from torch.testing._internal.common_utils import TEST_NUMBA
import inspect
import contextlib


TEST_CUDA = torch.cuda.is_available()

@ -49,6 +50,31 @@ def tf32_is_not_fp32():
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precison = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precison

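A short usage sketch of the two helpers added above, assuming a device-generic test instance that exposes precision and assertEqual; the module, shapes, and tolerance are illustrative only:

    def _check_conv_tf32_vs_fp32(self):
        conv = torch.nn.Conv2d(8, 8, 3).cuda()
        x = torch.randn(2, 8, 32, 32, device='cuda')
        with tf32_off():
            y_fp32 = conv(x)                      # cuBLAS and cuDNN both forced to true FP32
        with tf32_on(self, tf32_precision=0.005):
            y_tf32 = conv(x)                      # TF32 allowed; self.precision is loosened
        self.assertEqual(y_fp32, y_tf32)
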
# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the

@ -64,33 +90,22 @@ def tf32_is_not_fp32():
# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
# precision to check values.
def tf32_on_and_off(tf32_precision=1e-5):
    def call_with_tf32_on_and_off(self, function_call):
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        old_precison = self.precision

        def with_tf32_disabled():
            torch.backends.cuda.matmul.allow_tf32 = False
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

        def with_tf32_enabled():
            torch.backends.cuda.matmul.allow_tf32 = True
            self.precision = tf32_precision
    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

        try:
            with_tf32_disabled()
            with_tf32_enabled()
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
            self.precision = old_precison

    def wrapper(f):
        nargs = len(inspect.signature(f).parameters)
        if nargs == 2:
            @functools.wraps(f)
            def wrapped(self, device):
                if self.device_type == 'cuda' and tf32_is_not_fp32():
                    call_with_tf32_on_and_off(self, lambda: f(self, device))
                    with_tf32_disabled(self, lambda: f(self, device))
                    with_tf32_enabled(self, lambda: f(self, device))
                else:
                    f(self, device)
        else:

@ -99,7 +114,8 @@ def tf32_on_and_off(tf32_precision=1e-5):
            @functools.wraps(f)
            def wrapped(self, device, dtype):
                if self.device_type == 'cuda' and dtype in {torch.float32, torch.complex64} and tf32_is_not_fp32():
                    call_with_tf32_on_and_off(self, lambda: f(self, device, dtype))
                    with_tf32_disabled(self, lambda: f(self, device, dtype))
                    with_tf32_enabled(self, lambda: f(self, device, dtype))
                else:
                    f(self, device, dtype)

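A sketch of how the rewritten decorator is meant to be applied in a device-generic test. The common_cuda import and the decorator usage follow this diff; the TestCase scaffolding, the common_device_type imports, and the test body are assumptions based on PyTorch's usual test harness:

    import torch
    from torch.testing._internal.common_utils import TestCase, run_tests
    from torch.testing._internal.common_device_type import dtypes, instantiate_device_type_tests
    from torch.testing._internal.common_cuda import tf32_on_and_off


    class TestMatmulTF32(TestCase):
        @dtypes(torch.float32, torch.float64)
        @tf32_on_and_off(0.005)
        def test_matmul(self, device, dtype):
            a = torch.randn(64, 64, device=device, dtype=dtype)
            b = torch.randn(64, 64, device=device, dtype=dtype)
            # On a CUDA float32 run the body executes twice: once under tf32_off()
            # and once under tf32_on(self, 0.005), which loosens self.precision.
            self.assertEqual(torch.matmul(a, b), a.mm(b))


    instantiate_device_type_tests(TestMatmulTF32, globals())

    if __name__ == '__main__':
        run_tests()
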
@ -106,6 +106,8 @@ module_tests = [
        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
        input_size=(4, 10),
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Linear',

@ -113,7 +115,9 @@ module_tests = [
        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t())
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Threshold',

@ -1595,6 +1599,7 @@ new_module_tests = [
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
    ),
    dict(
        module_name='Conv1d',

@ -1603,6 +1608,8 @@ new_module_tests = [
        input_size=(2, 4, 10),
        cudnn=True,
        desc='stride',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',

@ -1611,6 +1618,7 @@ new_module_tests = [
        input_size=(2, 4, 10),
        cudnn=True,
        desc='pad1',
        with_tf32=True,
    ),
    dict(
        module_name='Conv1d',

@ -1619,6 +1627,7 @@ new_module_tests = [
        input_size=(2, 4, 10),
        cudnn=True,
        desc='pad2',
        with_tf32=True,
    ),
    dict(
        module_name='Conv1d',

@ -1627,6 +1636,7 @@ new_module_tests = [
        input_size=(1, 4, 1),
        cudnn=True,
        desc='pad1size1',
        with_tf32=True,
    ),
    dict(
        module_name='Conv1d',

@ -1635,6 +1645,7 @@ new_module_tests = [
        input_size=(1, 4, 1),
        cudnn=True,
        desc='pad2size1',
        with_tf32=True,
    ),
    dict(
        module_name='Conv1d',

@ -1644,12 +1655,14 @@ new_module_tests = [
        cudnn=True,
        desc='zero_batch',
        test_cuda=(not TEST_WITH_ROCM),
        with_tf32=True,
    ),
    dict(
        fullname='Conv1d_dilated',
        constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
        input_size=(2, 4, 10),
        with_tf32=True,
    ),
    dict(
        fullname='Conv1d_groups',

@ -1657,6 +1670,7 @@ new_module_tests = [
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
        input_size=(2, 4, 6),
        cudnn=True,
        with_tf32=True,
    ),
    dict(
        fullname='ConvTranspose1d',

@ -1664,6 +1678,8 @@ new_module_tests = [
        cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
        cudnn=True,
        input_size=(1, 3, 7),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='ConvTranspose1d',

@ -1673,6 +1689,8 @@ new_module_tests = [
        input_size=(1, 3, 6),
        cudnn=True,
        desc='no_bias',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='ConvTranspose1d',

@ -1682,6 +1700,7 @@ new_module_tests = [
        input_size=(1, 3, 6),
        cudnn=True,
        desc='dilated',
        with_tf32=True,
    ),
    dict(
        fullname='ConvTranspose1d_groups',

@ -1690,6 +1709,8 @@ new_module_tests = [
            .stride(3).padding(1).output_padding(1).groups(2)''',
        cudnn=True,
        input_size=(2, 4, 7),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='MaxPool1d',

@ -1711,6 +1732,7 @@ new_module_tests = [
        input_size=(2, 3, 7, 5),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        module_name='Conv2d',

@ -1720,6 +1742,8 @@ new_module_tests = [
        cudnn=True,
        desc='strided',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv2d',

@ -1729,6 +1753,8 @@ new_module_tests = [
        cudnn=True,
        desc='padding',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv2d',

@ -1738,6 +1764,7 @@ new_module_tests = [
        cudnn=True,
        desc='dilated',
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        module_name='Conv2d',

@ -1748,6 +1775,7 @@ new_module_tests = [
        cudnn=True,
        desc='no_bias',
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        module_name='Conv2d',

@ -1758,6 +1786,7 @@ new_module_tests = [
        desc='zero_batch',
        check_with_long_tensor=True,
        test_cuda=(not TEST_WITH_ROCM),
        with_tf32=True,
    ),
    dict(
        fullname='Conv2d_groups',

@ -1766,6 +1795,7 @@ new_module_tests = [
        input_size=(2, 4, 6, 5),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        fullname='Conv2d_groups_thnn',

@ -1773,6 +1803,7 @@ new_module_tests = [
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
        input_size=(2, 4, 6, 5),
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        module_name='ConvTranspose2d',

@ -1782,6 +1813,8 @@ new_module_tests = [
        cudnn=True,
        input_size=(1, 3, 7, 6),
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='ConvTranspose2d',

@ -1797,6 +1830,8 @@ new_module_tests = [
        cudnn=True,
        desc='dilated',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='ConvTranspose2d',

@ -1807,6 +1842,8 @@ new_module_tests = [
        cudnn=True,
        desc='no_bias',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='ConvTranspose2d_groups',

@ -1815,36 +1852,47 @@ new_module_tests = [
        input_size=(1, 2, 4, 5),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
    ),
    dict(
        fullname='Conv2d_depthwise',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_with_multiplier',
        constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_strided',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_padded',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_dilated',
        constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
        input_size=(2, 4, 5, 5),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='MaxPool2d',

@ -2067,6 +2115,8 @@ new_module_tests = [
        input_size=(1, 2, 4, 5, 4),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv3d',

@ -2077,6 +2127,8 @@ new_module_tests = [
        cudnn=True,
        desc='no_bias',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',

@ -2086,6 +2138,8 @@ new_module_tests = [
        cudnn=True,
        desc='stride',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv3d',

@ -2095,6 +2149,8 @@ new_module_tests = [
        cudnn=True,
        desc='stride_padding',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.01,
    ),
    dict(
        module_name='Conv3d',

@ -2105,6 +2161,7 @@ new_module_tests = [
        check_with_long_tensor=True,
        desc='zero_batch',
        test_cuda=(not TEST_WITH_ROCM),
        with_tf32=True,
    ),
    dict(
        fullname='Conv3d_groups',

@ -2113,18 +2170,22 @@ new_module_tests = [
        input_size=(1, 2, 4, 5, 4),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv3d_dilated',
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
        input_size=(2, 3, 5, 5, 5),
        with_tf32=True,
    ),
    dict(
        fullname='Conv3d_dilated_strided',
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
        input_size=(2, 3, 5, 5, 5),
        with_tf32=True,
    ),
    dict(
        module_name='ConvTranspose3d',

@ -2132,6 +2193,7 @@ new_module_tests = [
        cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
        cudnn=True,
        input_size=(1, 2, 4, 5, 4),
        with_tf32=True,
    ),
    dict(
        module_name='ConvTranspose3d',

@ -2141,6 +2203,7 @@ new_module_tests = [
        cudnn=True,
        input_size=(1, 2, 4, 5, 4),
        desc='dilated',
        with_tf32=True,
    ),
    dict(
        module_name='MaxPool3d',

@ -3458,6 +3521,8 @@ for padding_mode, cpp_padding_mode in zip(
            output_size=output_size,
            cudnn=True,
            desc='{}_stride2_pad2'.format(padding_mode),
            with_tf32=True,
            tf32_precision=0.05
        ),
    )

@ -4876,6 +4941,8 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
        self.check_inplace = kwargs.get('check_inplace', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.skip_double = kwargs.get('skip_double', False)
        self.with_tf32 = kwargs.get('with_tf32', False)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.test_cpu = kwargs.get('test_cpu', True)

    def _do_test(self, test_case, module, input):

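The NewModuleTest constructor above defaults with_tf32 to False and tf32_precision to 0.001, so TF32 coverage is opt-in per entry. A minimal sketch of a new_module_tests entry opting in, mirroring the entries added in this diff (the specific module, sizes, and tolerance are illustrative):

    dict(
        fullname='Conv2d_example_tf32',
        constructor=lambda: nn.Conv2d(4, 5, (3, 3)),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 5, {3, 3})',
        input_size=(2, 4, 6, 6),
        cudnn=True,
        with_tf32=True,         # also exercise this module with TF32 allowed
        tf32_precision=0.005,   # looser tolerance used for the TF32 run
    ),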