slow_conv3d grad_weight: call gemm directly (#65759)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65759
Test Plan: Imported from OSS
Reviewed By: dagitses
Differential Revision: D31257873
Pulled By: ngimel
fbshipit-source-id: 1612c0be10b2aa269c807c7b9f5470172ed68dc1
This commit is contained in:
parent dfb64b3287
commit 0020a151c6
@@ -382,32 +382,32 @@ void slow_conv3d_backward_out_cpu_template(
   });
 }
 
+template <typename scalar_t>
 void slow_conv3d_backward_weight_frame(
-    Tensor& grad_weight,
-    Tensor& grad_output,
-    const Tensor& finput,
+    TensorAccessor<scalar_t, 2> grad_weight,
+    TensorAccessor<scalar_t, 4> grad_output,
+    TensorAccessor<scalar_t, 2> finput,
     int64_t groups) {
-  auto grad_output_2d = groups > 1
-      ? grad_output.view(
-            {groups,
-             grad_output.size(0) / groups,
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)})
-      : grad_output.view(
-            {grad_output.size(0),
-             grad_output.size(1) * grad_output.size(2) * grad_output.size(3)});
+  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
+  // Note gemm expects fortran order, so all 3 matrices are transposed.
+  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
+  const int64_t m = grad_weight.size(1);
+  const int64_t n = grad_weight.size(0) / groups;
+  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
 
-  if (groups > 1) {
-    auto grad_weight_g = grad_weight.reshape(
-        {groups, grad_weight.size(0) / groups, grad_weight.size(1)});
-    Tensor tfinput =
-        finput.reshape({groups, finput.size(0) / groups, finput.size(1)})
-            .permute({0, 2, 1})
-            .contiguous();
-    grad_weight_g.baddbmm_(grad_output_2d, tfinput);
-  } else {
-    const Tensor tfinput = finput.transpose(0, 1);
-    grad_weight.addmm_(grad_output_2d, tfinput);
-  }
+  const int64_t lda = k;
+  const int64_t ldb = k;
+  const int64_t ldc = m;
+
+  at::native::cpublas::gemm_batched_with_stride(
+      TransposeType::Transpose,
+      TransposeType::NoTranspose,
+      groups, m, n, k,
+      static_cast<scalar_t>(1),
+      finput.data(), lda, finput.stride(0) * m,
+      grad_output.data(), ldb, grad_output.stride(0) * n,
+      static_cast<scalar_t>(1),
+      grad_weight.data(), ldc, grad_weight.stride(0) * n);
 }
 
 static void slow_conv3d_backward_parameters_out_cpu_template(
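The comments added in the hunk above lean on the identity C == AB <=> T(C) == T(B)T(A): handing row-major buffers to a fortran-order (column-major) gemm with the operands swapped yields the row-major product with no explicit transposes. The following standalone sketch is only an illustration of that argument-swap trick, not PyTorch code; gemm_colmajor and the example matrices are made up for the demonstration.

// Illustration of the row-major/column-major argument-swap trick: a
// column-major gemm that computes T(C) = T(B) * T(A) writes exactly the
// memory layout of the row-major product C = A * B.
#include <cstdio>
#include <vector>

// Naive column-major gemm: C(m x n) += A(m x k) * B(k x n), all column-major.
static void gemm_colmajor(int m, int n, int k,
                          const float* A, int lda,
                          const float* B, int ldb,
                          float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      for (int p = 0; p < k; ++p)
        C[i + j * ldc] += A[i + p * lda] * B[p + j * ldb];
}

int main() {
  // Row-major A (2 x 3) and B (3 x 4); we want row-major C (2 x 4) = A * B.
  std::vector<float> A = {1, 2, 3, 4, 5, 6};
  std::vector<float> B = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<float> C(8, 0.f);
  // Hand the buffers to the column-major routine with A and B swapped: it sees
  // T(B) (4 x 3) and T(A) (3 x 2) and writes T(C) (4 x 2) column-major, which
  // is the same memory layout as C (2 x 4) row-major. Note that the row and
  // column dimensions swap roles in the call.
  gemm_colmajor(/*m=*/4, /*n=*/2, /*k=*/3,
                B.data(), /*lda=*/4,
                A.data(), /*ldb=*/3,
                C.data(), /*ldc=*/4);
  // Expected output: 38 44 50 56 / 83 98 113 128
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 4; ++j) std::printf("%g ", C[i * 4 + j]);
    std::printf("\n");
  }
  return 0;
}

The same cancellation is what lets the kernel express grad_weight += grad_output_2d * finput.T as a single strided batched gemm per group instead of the old addmm_/baddbmm_ path.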
@@ -456,12 +456,19 @@ static void slow_conv3d_backward_parameters_out_cpu_template(
   auto grad_output_contiguous = grad_output.contiguous();
 
   const int64_t batch_size = input.size(0);
-  for (int64_t t = 0; t < batch_size; t++) {
-    Tensor grad_output_t = grad_output_contiguous[t];
-    Tensor finput_t = finput[t];
-    slow_conv3d_backward_weight_frame(
-        grad_weight_2d, grad_output_t, finput_t, groups);
-  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      kBFloat16, input.scalar_type(), "slow_conv3d_cpu_grad_weight", [&] {
+        auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
+        auto grad_output_a = grad_output_contiguous.accessor<scalar_t, 5>();
+        auto finput_a = finput.accessor<scalar_t, 3>();
+        for (int64_t t = 0; t < batch_size; t++) {
+          auto grad_output_t = grad_output_a[t];
+          auto finput_t = finput_a[t];
+          slow_conv3d_backward_weight_frame(
+              grad_weight_2d_a, grad_output_t, finput_t, groups);
+        }
+      });
 }
 
 } // namespace
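For context on the call introduced in the first hunk: gemm_batched_with_stride is an ATen-internal cpublas helper, and the sketch below is only a hedged, conceptual rendering of what such a strided batched interface amounts to, namely one ordinary gemm per batch entry with each operand advanced by a fixed element stride. The name gemm_batched_with_stride_sketch and the GemmFn callback are hypothetical, and the transpose flags are omitted for brevity.

// Conceptual sketch only, not ATen's implementation. In the diff, batch ==
// groups and the strides (finput.stride(0) * m, grad_output.stride(0) * n,
// grad_weight.stride(0) * n) step each pointer to the next group's block.
#include <cstdint>

template <typename scalar_t, typename GemmFn>
void gemm_batched_with_stride_sketch(
    GemmFn gemm,  // gemm(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    int64_t batch, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t* a, int64_t lda, int64_t stride_a,
    const scalar_t* b, int64_t ldb, int64_t stride_b,
    scalar_t beta,
    scalar_t* c, int64_t ldc, int64_t stride_c) {
  for (int64_t g = 0; g < batch; ++g) {
    // One plain gemm per group; every operand is offset by its batch stride.
    gemm(m, n, k, alpha,
         a + g * stride_a, lda,
         b + g * stride_b, ldb,
         beta,
         c + g * stride_c, ldc);
  }
}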