optimize MulGradient for common shapes (#19705)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19705

Optimizes the case where a run of consecutive non-broadcast dims is followed by a run of consecutive broadcast dims.
For example, MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0) where A.shape == dC.shape == [9508, 80] and B.shape == [80].
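
For intuition, the broadcast gradient in this pattern reduces to a per-row scale (for the non-broadcast input's gradient) and a per-row dot product (for the broadcast input's gradient), which is what the new fast path computes with math::Scale and math::Dot. A rough NumPy sketch, using illustrative shapes that match the benchmark script added below (dC and A of shape (M, N), B of shape (M,) broadcast across the columns):

import numpy as np

M, N = 9508, 80
dC = np.random.rand(M, N).astype(np.float32)
A = np.random.rand(M, N).astype(np.float32)
B = np.random.rand(M).astype(np.float32)

# Forward: C = A * B[:, None], i.e. B is broadcast across the N columns.
# Backward:
dA = dC * B[:, None]        # one scale per row (math::Scale in the fast path)
dB = (dC * A).sum(axis=1)   # one dot product per row (math::Dot in the fast path)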

Test Plan:
On SKL T6:

Running mul_gradient_benchmark without this optimization:
Operator #0 (dA, MulGradient) 11.9119 ms/iter

After this optimization:
Operator #0 (dA, MulGradient) 0.672759 ms/iter (~17.7x faster)

D15291800 needs to land first to fix the unit test error.

Reviewed By: dmudiger

Differential Revision: D15075415

fbshipit-source-id: 0f97be17cf8f1dacbafa34cd637fb8bc1c5e5387
Yuchen Hao 2019-12-11 11:36:14 -08:00 committed by Facebook Github Bot
parent a53b39f09d
commit 4a751dfc20
3 changed files with 251 additions and 81 deletions


@@ -7,6 +7,219 @@
namespace caffe2 {
namespace {
template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

// A : input not to broadcast whose size is common_size x broadcast_size
// B : input to broadcast whose size is common_size
void ComputeMulGradient(
    const int common_size,
    const int broadcast_size,
    const float* dC,
    const float* A,
    const float* B,
    float* dA,
    float* dB,
    CPUContext* context) {
  for (int i = 0; i < common_size; ++i) {
    caffe2::math::Scale(
        broadcast_size,
        B[i],
        dC + i * broadcast_size,
        dA + i * broadcast_size,
        context);
    caffe2::math::Dot(
        broadcast_size,
        dC + i * broadcast_size,
        A + i * broadcast_size,
        dB + i,
        context);
  }
}

void ComputeMulGradient(
    const int size,
    const float* dC,
    const float* A,
    const float* B,
    float* dA,
    float* dB) {
  for (int i = 0; i < size; ++i) {
    dA[i] = dC[i] * B[i];
    dB[i] = dC[i] * A[i];
  }
}

} // namespace

template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  const int ndim = std::max(A_dims.size(), B_dims.size());
  if (ndim == 0) {
    return true;
  }
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  const int C_size = std::accumulate(
      C_broadcast_dims.cbegin(),
      C_broadcast_dims.cbegin() + ndim,
      1,
      std::multiplies<int>());
  if (C_size == 0) {
    const int A_size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    const int B_size = std::accumulate(
        B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
    math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
    math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
    return true;
  }
  // Flatten dims as much as possible.
  // We say A is broadcast at dim d if A_broadcast_dims[d] <= 1.
  // Two consecutive dims d and d + 1 can be flattened if one of the following
  // holds:
  //  - both A and B are broadcast at dim d, or
  //  - both A and B are broadcast at dim d + 1, or
  //  - A is broadcast at both dim d and dim d + 1, or
  //  - B is broadcast at both dim d and dim d + 1, or
  //  - neither A nor B is broadcast at either dim d or dim d + 1.
  std::vector<int> A_broadcast_dims_flattened, B_broadcast_dims_flattened,
      C_broadcast_dims_flattened;
  A_broadcast_dims_flattened.reserve(ndim);
  B_broadcast_dims_flattened.reserve(ndim);
  A_broadcast_dims_flattened.push_back(A_broadcast_dims[0]);
  B_broadcast_dims_flattened.push_back(B_broadcast_dims[0]);
  for (int i = 1; i < ndim; ++i) {
    int A_old = A_broadcast_dims_flattened.back();
    int B_old = B_broadcast_dims_flattened.back();
    int A_new = A_broadcast_dims[i];
    int B_new = B_broadcast_dims[i];
    if ((A_old == 1 && B_old == 1) || (A_new == 1 && B_new == 1) ||
        (A_old == 1 && A_new == 1) || (B_old == 1 && B_new == 1) ||
        (A_old > 1 && B_old > 1 && A_new > 1 && B_new > 1)) {
      A_broadcast_dims_flattened.back() *= A_new;
      B_broadcast_dims_flattened.back() *= B_new;
    } else {
      A_broadcast_dims_flattened.push_back(A_new);
      B_broadcast_dims_flattened.push_back(B_new);
    }
  }
  int ndim_flattened = A_broadcast_dims_flattened.size();
  C_broadcast_dims_flattened.resize(ndim_flattened);
  for (int i = 0; i < ndim_flattened; ++i) {
    C_broadcast_dims_flattened[i] =
        std::max(A_broadcast_dims_flattened[i], B_broadcast_dims_flattened[i]);
  }
  if (std::is_same<TGrad, float>::value && std::is_same<TIn, float>::value &&
      ndim_flattened <= 2 &&
      A_broadcast_dims_flattened[0] == B_broadcast_dims_flattened[0] &&
      (ndim_flattened == 1 || A_broadcast_dims_flattened[1] <= 1 ||
       B_broadcast_dims_flattened[1] <= 1)) {
    if (ndim_flattened == 2) {
      // fast path when we have 2 flattened dimensions and the second dimension
      // is broadcasted.
      bool broadcast_B = B_broadcast_dims_flattened[1] <= 1;
      ComputeMulGradient(
          C_broadcast_dims_flattened[0],
          C_broadcast_dims_flattened[1],
          reinterpret_cast<const float*>(dC),
          reinterpret_cast<const float*>(broadcast_B ? A : B),
          reinterpret_cast<const float*>(broadcast_B ? B : A),
          reinterpret_cast<float*>(broadcast_B ? dA : dB),
          reinterpret_cast<float*>(broadcast_B ? dB : dA),
          context);
    } else {
      // fast path when we have 1 flattened dimension
      assert(ndim_flattened == 1);
      ComputeMulGradient(
          C_broadcast_dims_flattened[0],
          reinterpret_cast<const float*>(dC),
          reinterpret_cast<const float*>(A),
          reinterpret_cast<const float*>(B),
          reinterpret_cast<float*>(dA),
          reinterpret_cast<float*>(dB));
    }
  } else {
    ComputeMulGradient<TGrad, TIn>(
        ndim_flattened,
        A_broadcast_dims_flattened.data(),
        B_broadcast_dims_flattened.data(),
        C_broadcast_dims_flattened.data(),
        dC,
        A,
        B,
        dA,
        dB,
        context);
  }
  return true;
}
REGISTER_CPU_OPERATOR(
    MulGradient,
    BinaryElementwiseGradientOp<
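
For reference, the dim-flattening step above can be mirrored in a few lines of Python. This is an illustrative sketch (not part of the diff), operating on the already-aligned broadcast dims; the helper name flatten_broadcast_dims is made up for the example:

def flatten_broadcast_dims(A_dims, B_dims):
    # Mirrors the merging rule in MulFunctor<CPUContext>::Backward above:
    # adjacent dims are merged when the accumulated dim is trivial in both
    # inputs, the new dim is trivial in both, the same input is broadcast in
    # both, or neither input is broadcast in either.
    A_flat, B_flat = [A_dims[0]], [B_dims[0]]
    for a, b in zip(A_dims[1:], B_dims[1:]):
        a_old, b_old = A_flat[-1], B_flat[-1]
        if ((a_old == 1 and b_old == 1) or (a == 1 and b == 1) or
                (a_old == 1 and a == 1) or (b_old == 1 and b == 1) or
                (a_old > 1 and b_old > 1 and a > 1 and b > 1)):
            A_flat[-1] *= a
            B_flat[-1] *= b
        else:
            A_flat.append(a)
            B_flat.append(b)
    return A_flat, B_flat


# The benchmark script's case (B of length m broadcast across the columns)
# stays 2-D with the second dim broadcast, which is exactly the shape the new
# Scale/Dot fast path handles:
print(flatten_broadcast_dims([9508, 80], [9508, 1]))  # ([9508, 80], [9508, 1])
# A 3-D tensor multiplied by a 1-D tensor flattens from 3 aligned dims to 2:
print(flatten_broadcast_dims([2, 3, 4], [1, 1, 4]))   # ([6, 4], [1, 4])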


@@ -8,42 +8,6 @@
namespace caffe2 {
namespace {
template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

} // namespace
template <class Context>
struct MulFunctor {
  template <typename TIn, typename TOut>

@@ -79,51 +43,6 @@ struct MulFunctor {
      Context* context) const;
};
template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  const int ndim = std::max(A_dims.size(), B_dims.size());
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  ComputeMulGradient<TGrad, TIn>(
      ndim,
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data(),
      dC,
      A,
      B,
      dA,
      dB,
      context);
  return true;
}
} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_


@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse

import numpy as np

from caffe2.python import core, workspace


def benchmark_mul_gradient(args):
    workspace.FeedBlob("dC", np.random.rand(args.m, args.n).astype(np.float32))
    workspace.FeedBlob("A", np.random.rand(args.m, args.n).astype(np.float32))
    workspace.FeedBlob("B", np.random.rand(args.m).astype(np.float32))

    net = core.Net("mynet")
    net.MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0)
    workspace.CreateNet(net)
    workspace.BenchmarkNet(net.Name(), 1, args.iteration, True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="benchmark for MulGradient.")
    parser.add_argument(
        '-m', type=int, default=9508,
        help="The number of rows of A")
    parser.add_argument(
        "-n", type=int, default=80,
        help="The number of columns of A")
    parser.add_argument(
        '-i', "--iteration", type=int, default=100,
        help="The number of iterations.")
    args, extra_args = parser.parse_known_args()
    core.GlobalInit(['python'] + extra_args)
    benchmark_mul_gradient(args)
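
A quick way to sanity-check the benchmark's outputs against a NumPy baseline; this snippet is not part of the commit and assumes it runs right after benchmark_mul_gradient(args) in the same workspace:

dC = workspace.FetchBlob("dC")
A = workspace.FetchBlob("A")
B = workspace.FetchBlob("B")
# With broadcast=True, axis=0 and B of shape (m,):
# dA[i, j] = dC[i, j] * B[i] and dB[i] = sum_j dC[i, j] * A[i, j]
np.testing.assert_allclose(
    workspace.FetchBlob("dA"), dC * B[:, np.newaxis], rtol=1e-3)
np.testing.assert_allclose(
    workspace.FetchBlob("dB"), (dC * A).sum(axis=1), rtol=1e-3)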