slow_conv3d grad_weight: call gemm directly (#65759)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65759
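
The change rewrites slow_conv3d_backward_weight_frame to take TensorAccessors and call at::native::cpublas::gemm_batched_with_stride directly, instead of building temporaries and going through baddbmm_/addmm_; the dtype dispatch moves up into slow_conv3d_backward_parameters_out_cpu_template.

For reference, here is a minimal sketch (not code from this PR) of the per-frame update the kernel performs, written with plain ATen ops and restricted to groups == 1; the shapes in the comments are read off the existing kernel and should be treated as assumptions of the sketch:

#include <ATen/ATen.h>

// Reference-only helper: accumulates the conv3d weight gradient for one
// batch element, mirroring the comment in the new kernel:
//   grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
void grad_weight_frame_reference(
    at::Tensor& grad_weight_2d,     // [C_out, C_in * kT * kH * kW]
    const at::Tensor& grad_output,  // [C_out, oT, oH, oW], one batch element
    const at::Tensor& finput) {     // [C_in * kT * kH * kW, oT * oH * oW]
  // Flatten the spatial dimensions of grad_output into a single matrix dim.
  auto grad_output_2d = grad_output.reshape({grad_output.size(0), -1});
  // grad_weight_2d += grad_output_2d * finput^T. The PR issues this as one
  // gemm per group on raw pointers rather than through addmm_/baddbmm_.
  grad_weight_2d.addmm_(grad_output_2d, finput.t());
}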

Test Plan: Imported from OSS

Reviewed By: dagitses

Differential Revision: D31257873

Pulled By: ngimel

fbshipit-source-id: 1612c0be10b2aa269c807c7b9f5470172ed68dc1
Authored by Peter Bell on 2021-10-08 09:53:28 -07:00; committed by Facebook GitHub Bot
parent dfb64b3287
commit 0020a151c6


@@ -382,32 +382,32 @@ void slow_conv3d_backward_out_cpu_template(
   });
 }
 
+template <typename scalar_t>
 void slow_conv3d_backward_weight_frame(
-    Tensor& grad_weight,
-    Tensor& grad_output,
-    const Tensor& finput,
+    TensorAccessor<scalar_t, 2> grad_weight,
+    TensorAccessor<scalar_t, 4> grad_output,
+    TensorAccessor<scalar_t, 2> finput,
     int64_t groups) {
-  auto grad_output_2d = groups > 1
-      ? grad_output.view(
-            {groups,
-             grad_output.size(0) / groups,
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)})
-      : grad_output.view(
-            {grad_output.size(0),
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)});
-  if (groups > 1) {
-    auto grad_weight_g = grad_weight.reshape(
-        {groups, grad_weight.size(0) / groups, grad_weight.size(1)});
-    Tensor tfinput =
-        finput.reshape({groups, finput.size(0) / groups, finput.size(1)})
-            .permute({0, 2, 1})
-            .contiguous();
-    grad_weight_g.baddbmm_(grad_output_2d, tfinput);
-  } else {
-    const Tensor tfinput = finput.transpose(0, 1);
-    grad_weight.addmm_(grad_output_2d, tfinput);
-  }
+  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
+  // Note gemm expects fortran order, so all 3 matrices are transposed.
+  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
+  const int64_t m = grad_weight.size(1);
+  const int64_t n = grad_weight.size(0) / groups;
+  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
+
+  const int64_t lda = k;
+  const int64_t ldb = k;
+  const int64_t ldc = m;
+
+  at::native::cpublas::gemm_batched_with_stride(
+      TransposeType::Transpose,
+      TransposeType::NoTranspose,
+      groups, m, n, k,
+      static_cast<scalar_t>(1),
+      finput.data(), lda, finput.stride(0) * m,
+      grad_output.data(), ldb, grad_output.stride(0) * n,
+      static_cast<scalar_t>(1),
+      grad_weight.data(), ldc, grad_weight.stride(0) * n);
 }
 
 static void slow_conv3d_backward_parameters_out_cpu_template(
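
The comment block in the new kernel relies on the identity T(C) == T(B)T(A): a column-major (Fortran-order) gemm applied to row-major buffers effectively sees every matrix transposed, so swapping the operand order gives back the row-major product. The following self-contained sketch (not PR code; it assumes a hand-rolled column-major gemm named gemm_colmajor and small hard-coded matrices) demonstrates the cancellation:

#include <cstdio>
#include <vector>

// Reference column-major gemm: C(m x n) += A(m x k) * B(k x n), all buffers
// stored column-major with leading dimensions lda, ldb, ldc.
void gemm_colmajor(int m, int n, int k,
                   const float* A, int lda,
                   const float* B, int ldb,
                   float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int p = 0; p < k; ++p)
      for (int i = 0; i < m; ++i)
        C[i + j * ldc] += A[i + p * lda] * B[p + j * ldb];
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B (2x2).
  std::vector<float> A = {1, 2, 3,
                          4, 5, 6};
  std::vector<float> B = {7,  8,
                          9, 10,
                          11, 12};
  std::vector<float> C(4, 0.f);

  // Viewed as column-major, the row-major buffers hold A^T and B^T.
  // gemm(B^T, A^T) = (A * B)^T, which stored column-major is exactly the
  // row-major C = A * B. Note the swapped operand order: this is the same
  // cancellation the new kernel relies on.
  gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3,
                /*A=*/B.data(), /*lda=*/2,
                /*B=*/A.data(), /*ldb=*/3,
                /*C=*/C.data(), /*ldc=*/2);

  // Expect row-major C = [[58, 64], [139, 154]].
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
}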
@@ -456,12 +456,19 @@ static void slow_conv3d_backward_parameters_out_cpu_template(
   auto grad_output_contiguous = grad_output.contiguous();
 
   const int64_t batch_size = input.size(0);
-  for (int64_t t = 0; t < batch_size; t++) {
-    Tensor grad_output_t = grad_output_contiguous[t];
-    Tensor finput_t = finput[t];
-    slow_conv3d_backward_weight_frame(
-        grad_weight_2d, grad_output_t, finput_t, groups);
-  }
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      kBFloat16, input.scalar_type(), "slow_conv3d_cpu_grad_weight", [&] {
+        auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
+        auto grad_output_a = grad_output_contiguous.accessor<scalar_t, 5>();
+        auto finput_a = finput.accessor<scalar_t, 3>();
+
+        for (int64_t t = 0; t < batch_size; t++) {
+          auto grad_output_t = grad_output_a[t];
+          auto finput_t = finput_a[t];
+          slow_conv3d_backward_weight_frame(
+              grad_weight_2d_a, grad_output_t, finput_t, groups);
+        }
+      });
 }
 
 } // namespace
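
The caller-side change follows a common ATen pattern: dispatch once at the outermost level, build typed TensorAccessors there, and hand plain accessors to a scalar_t-templated kernel so the inner loop never goes through Tensor methods. A hedged standalone sketch of that pattern (the names sum_frame and sum_batch are made up for illustration):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Type-generic kernel: works on a plain 2-D accessor, no Tensor calls inside.
template <typename scalar_t>
scalar_t sum_frame(at::TensorAccessor<scalar_t, 2> frame) {
  scalar_t acc = 0;
  for (int64_t i = 0; i < frame.size(0); ++i) {
    for (int64_t j = 0; j < frame.size(1); ++j) {
      acc += frame[i][j];
    }
  }
  return acc;
}

// Caller: dispatches once, builds the accessors, then loops over the batch.
double sum_batch(const at::Tensor& batch) {  // expects a CPU [N, H, W] tensor
  double total = 0;
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::kBFloat16, batch.scalar_type(), "sum_batch_example", [&] {
        auto batch_a = batch.accessor<scalar_t, 3>();
        for (int64_t t = 0; t < batch_a.size(0); ++t) {
          // Indexing a 3-D accessor yields a 2-D accessor, just like
          // grad_output_a[t] and finput_a[t] in the hunk above.
          total += static_cast<double>(sum_frame<scalar_t>(batch_a[t]));
        }
      });
  return total;
}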