slow_conv3d grad_weight: call gemm directly (#65759)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65759

Test Plan: Imported from OSS

Reviewed By: dagitses

Differential Revision: D31257873

Pulled By: ngimel

fbshipit-source-id: 1612c0be10b2aa269c807c7b9f5470172ed68dc1
Author: Peter Bell (2021-10-08 09:53:28 -07:00), committed by Facebook GitHub Bot
Parent: dfb64b3287
Commit: 0020a151c6

@@ -382,32 +382,32 @@ void slow_conv3d_backward_out_cpu_template(
   });
 }
 
 template <typename scalar_t>
 void slow_conv3d_backward_weight_frame(
-    Tensor& grad_weight,
-    Tensor& grad_output,
-    const Tensor& finput,
+    TensorAccessor<scalar_t, 2> grad_weight,
+    TensorAccessor<scalar_t, 4> grad_output,
+    TensorAccessor<scalar_t, 2> finput,
     int64_t groups) {
-  auto grad_output_2d = groups > 1
-      ? grad_output.view(
-            {groups,
-             grad_output.size(0) / groups,
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)})
-      : grad_output.view(
-            {grad_output.size(0),
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)});
+  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
+  // Note gemm expects fortran order, so all 3 matrices are transposed.
+  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
+  const int64_t m = grad_weight.size(1);
+  const int64_t n = grad_weight.size(0) / groups;
+  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
-  if (groups > 1) {
-    auto grad_weight_g = grad_weight.reshape(
-        {groups, grad_weight.size(0) / groups, grad_weight.size(1)});
-    Tensor tfinput =
-        finput.reshape({groups, finput.size(0) / groups, finput.size(1)})
-            .permute({0, 2, 1})
-            .contiguous();
-    grad_weight_g.baddbmm_(grad_output_2d, tfinput);
-  } else {
-    const Tensor tfinput = finput.transpose(0, 1);
-    grad_weight.addmm_(grad_output_2d, tfinput);
-  }
+  const int64_t lda = k;
+  const int64_t ldb = k;
+  const int64_t ldc = m;
+  at::native::cpublas::gemm_batched_with_stride(
+      TransposeType::Transpose,
+      TransposeType::NoTranspose,
+      groups, m, n, k,
+      static_cast<scalar_t>(1),
+      finput.data(), lda, finput.stride(0) * m,
+      grad_output.data(), ldb, grad_output.stride(0) * n,
+      static_cast<scalar_t>(1),
+      grad_weight.data(), ldc, grad_weight.stride(0) * n);
 }
 
 static void slow_conv3d_backward_parameters_out_cpu_template(
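The key to the hunk above is the "fortran order" comment: a BLAS-style gemm is column-major, so it reads each contiguous (row-major) ATen buffer as the transpose of the matrix it stores, and swapping the two operands cancels the extra transpose. Below is a standalone sketch of that argument-swap trick for the groups == 1 case; the naive gemm_colmajor helper is hypothetical and only stands in for a column-major gemm such as cpublas::gemm, it is not the ATen API.

// Standalone illustration (not ATen code). gemm_colmajor is a hypothetical
// naive column-major (Fortran-order) gemm:
//   C (m x n, col-major, leading dim ldc) = alpha * op(A) * op(B) + beta * C
#include <cassert>
#include <vector>

void gemm_colmajor(char transa, char transb,
                   int m, int n, int k, float alpha,
                   const float* a, int lda,
                   const float* b, int ldb,
                   float beta, float* c, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // op(A)(i, p) and op(B)(p, j) under column-major storage
        const float av = (transa == 't') ? a[i * lda + p] : a[p * lda + i];
        const float bv = (transb == 't') ? b[p * ldb + j] : b[j * ldb + p];
        acc += av * bv;
      }
      c[j * ldc + i] = alpha * acc + beta * c[j * ldc + i];
    }
  }
}

int main() {
  // Toy sizes mirroring the diff (groups == 1):
  //   m = grad_weight.size(1), n = grad_weight.size(0), k = output volume
  const int m = 3, n = 2, k = 4;

  // Row-major buffers, as contiguous ATen tensors are laid out:
  //   grad_output_2d: n x k, finput: m x k, grad_weight: n x m
  std::vector<float> grad_output = {1, 2, 3, 4,   5, 6, 7, 8};
  std::vector<float> finput      = {1, 0, 1, 0,   2, 1, 0, 1,   0, 3, 1, 2};
  std::vector<float> grad_weight(n * m, 0.f);

  // Same argument arrangement as the gemm call in the diff. The col-major
  // gemm sees the row-major grad_weight (n x m) buffer as m x n, i.e.
  // grad_weight^T, and finput * grad_output^T == (grad_output * finput^T)^T.
  gemm_colmajor('t', 'n', m, n, k, /*alpha=*/1.f,
                finput.data(), /*lda=*/k,
                grad_output.data(), /*ldb=*/k,
                /*beta=*/1.f, grad_weight.data(), /*ldc=*/m);

  // Reference: grad_weight[i][j] += sum_p grad_output[i][p] * finput[j][p]
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      float ref = 0.f;
      for (int p = 0; p < k; ++p) {
        ref += grad_output[i * k + p] * finput[j * k + p];
      }
      assert(grad_weight[i * m + j] == ref);
    }
  }
  return 0;
}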
@@ -456,12 +456,19 @@ static void slow_conv3d_backward_parameters_out_cpu_template(
   auto grad_output_contiguous = grad_output.contiguous();
 
   const int64_t batch_size = input.size(0);
-  for (int64_t t = 0; t < batch_size; t++) {
-    Tensor grad_output_t = grad_output_contiguous[t];
-    Tensor finput_t = finput[t];
-    slow_conv3d_backward_weight_frame(
-        grad_weight_2d, grad_output_t, finput_t, groups);
-  }
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      kBFloat16, input.scalar_type(), "slow_conv3d_cpu_grad_weight", [&] {
+        auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
+        auto grad_output_a = grad_output_contiguous.accessor<scalar_t, 5>();
+        auto finput_a = finput.accessor<scalar_t, 3>();
+        for (int64_t t = 0; t < batch_size; t++) {
+          auto grad_output_t = grad_output_a[t];
+          auto finput_t = finput_a[t];
+          slow_conv3d_backward_weight_frame(
+              grad_weight_2d_a, grad_output_t, finput_t, groups);
+        }
+      });
 }
 
 } // namespace
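On the caller side, the second hunk moves the per-frame loop inside AT_DISPATCH_FLOATING_TYPES_AND, so the dtype switch happens once and each frame reaches the kernel as a plain TensorAccessor. A minimal sketch of that dispatch-plus-accessor pattern (a toy reduction, not code from this commit), assuming only the public ATen C++ API:

// Minimal sketch of the dispatch + accessor pattern (toy example, not part
// of this commit). t is assumed to be a 2-D CPU tensor with a float, double,
// or bfloat16 dtype.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

double sum2d(const at::Tensor& t) {
  double total = 0.0;
  // The dtype switch happens once, outside the loop; inside the lambda,
  // scalar_t is the concrete element type and the accessor gives typed,
  // stride-aware indexing without per-element dispatch.
  AT_DISPATCH_FLOATING_TYPES_AND(
      at::kBFloat16, t.scalar_type(), "sum2d_example", [&] {
        auto a = t.accessor<scalar_t, 2>();
        for (int64_t i = 0; i < a.size(0); ++i) {
          for (int64_t j = 0; j < a.size(1); ++j) {
            total += static_cast<double>(a[i][j]);
          }
        }
      });
  return total;
}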