optimize MulGradient for common shapes (#19705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19705

Optimizes the case where a run of consecutive dims that are not broadcast is followed by a run of consecutive dims that are broadcast. For example, MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0) where A.shape == dC.shape == [9508, 80] and B.shape == [80].

Test Plan:
On SKL T6, running mul_gradient_benchmark without this optimization:
Operator #0 (dA, MulGradient) 11.9119 ms/iter
With this optimization:
Operator #0 (dA, MulGradient) 0.672759 ms/iter

D15291800 needs to land first to fix the unit test error.

Reviewed By: dmudiger

Differential Revision: D15075415

fbshipit-source-id: 0f97be17cf8f1dacbafa34cd637fb8bc1c5e5387
parent a53b39f09d
commit 4a751dfc20
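For reference, a minimal numpy sketch of the gradients this operator produces in the broadcast case exercised by the benchmark added below (dC and A of shape (m, n), B of shape (m,)). The shape/axis interpretation here (B multiplying every column of A when axis=0) is an illustration drawn from the benchmark, not part of the patch:

    # Illustrative numpy reference for the benchmark's broadcast case
    # (assumption: B of shape (m,) scales each row of A when axis=0).
    import numpy as np

    m, n = 9508, 80
    dC = np.random.rand(m, n).astype(np.float32)
    A = np.random.rand(m, n).astype(np.float32)
    B = np.random.rand(m).astype(np.float32)

    # Forward is C = A * B[:, None], so the gradients are:
    dA = dC * B[:, np.newaxis]   # scale each row of dC by B[i]
    dB = (dC * A).sum(axis=1)    # reduce over the broadcast (column) dim

These are exactly the per-row Scale and Dot operations the fast path below performs.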
@@ -7,6 +7,219 @@

namespace caffe2 {

namespace {

template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

// A : input not to broadcast whose size is common_size x broadcast_size
// B : input to broadcast whose size is common_size
void ComputeMulGradient(
    const int common_size,
    const int broadcast_size,
    const float* dC,
    const float* A,
    const float* B,
    float* dA,
    float* dB,
    CPUContext* context) {
  for (int i = 0; i < common_size; ++i) {
    caffe2::math::Scale(
        broadcast_size,
        B[i],
        dC + i * broadcast_size,
        dA + i * broadcast_size,
        context);
    caffe2::math::Dot(
        broadcast_size,
        dC + i * broadcast_size,
        A + i * broadcast_size,
        dB + i,
        context);
  }
}

void ComputeMulGradient(
    const int size,
    const float* dC,
    const float* A,
    const float* B,
    float* dA,
    float* dB) {
  for (int i = 0; i < size; ++i) {
    dA[i] = dC[i] * B[i];
    dB[i] = dC[i] * A[i];
  }
}

} // namespace

template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }

  const int ndim = std::max(A_dims.size(), B_dims.size());
  if (ndim == 0) {
    return true;
  }

  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());

  const int C_size = std::accumulate(
      C_broadcast_dims.cbegin(),
      C_broadcast_dims.cbegin() + ndim,
      1,
      std::multiplies<int>());
  if (C_size == 0) {
    const int A_size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    const int B_size = std::accumulate(
        B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
    math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
    math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
    return true;
  }

  // Flatten dims as much as possible.
  // We say A is broadcast at dim d if A_broadcast_dims[d] <= 1.
  // Two consecutive dims d and d + 1 can be flattened if
  //   A and B are both broadcast at dim d, or
  //   A and B are both broadcast at dim d + 1, or
  //   A is broadcast at both dim d and d + 1, or
  //   B is broadcast at both dim d and d + 1, or
  //   neither A nor B is broadcast at dim d or d + 1.
  std::vector<int> A_broadcast_dims_flattened, B_broadcast_dims_flattened,
      C_broadcast_dims_flattened;
  A_broadcast_dims_flattened.reserve(ndim);
  B_broadcast_dims_flattened.reserve(ndim);

  A_broadcast_dims_flattened.push_back(A_broadcast_dims[0]);
  B_broadcast_dims_flattened.push_back(B_broadcast_dims[0]);

  for (int i = 1; i < ndim; ++i) {
    int A_old = A_broadcast_dims_flattened.back();
    int B_old = B_broadcast_dims_flattened.back();
    int A_new = A_broadcast_dims[i];
    int B_new = B_broadcast_dims[i];
    if ((A_old == 1 && B_old == 1) || (A_new == 1 && B_new == 1) ||
        (A_old == 1 && A_new == 1) || (B_old == 1 && B_new == 1) ||
        (A_old > 1 && B_old > 1 && A_new > 1 && B_new > 1)) {
      A_broadcast_dims_flattened.back() *= A_new;
      B_broadcast_dims_flattened.back() *= B_new;
    } else {
      A_broadcast_dims_flattened.push_back(A_new);
      B_broadcast_dims_flattened.push_back(B_new);
    }
  }

  int ndim_flattened = A_broadcast_dims_flattened.size();
  C_broadcast_dims_flattened.resize(ndim_flattened);
  for (int i = 0; i < ndim_flattened; ++i) {
    C_broadcast_dims_flattened[i] =
        std::max(A_broadcast_dims_flattened[i], B_broadcast_dims_flattened[i]);
  }

  if (std::is_same<TGrad, float>::value && std::is_same<TIn, float>::value &&
      ndim_flattened <= 2 &&
      A_broadcast_dims_flattened[0] == B_broadcast_dims_flattened[0] &&
      (ndim_flattened == 1 || A_broadcast_dims_flattened[1] <= 1 ||
       B_broadcast_dims_flattened[1] <= 1)) {
    if (ndim_flattened == 2) {
      // Fast path when we have 2 flattened dimensions and the second dimension
      // is broadcast.
      bool broadcast_B = B_broadcast_dims_flattened[1] <= 1;
      ComputeMulGradient(
          C_broadcast_dims_flattened[0],
          C_broadcast_dims_flattened[1],
          reinterpret_cast<const float*>(dC),
          reinterpret_cast<const float*>(broadcast_B ? A : B),
          reinterpret_cast<const float*>(broadcast_B ? B : A),
          reinterpret_cast<float*>(broadcast_B ? dA : dB),
          reinterpret_cast<float*>(broadcast_B ? dB : dA),
          context);
    } else {
      // Fast path when we have 1 flattened dimension.
      assert(ndim_flattened == 1);
      ComputeMulGradient(
          C_broadcast_dims_flattened[0],
          reinterpret_cast<const float*>(dC),
          reinterpret_cast<const float*>(A),
          reinterpret_cast<const float*>(B),
          reinterpret_cast<float*>(dA),
          reinterpret_cast<float*>(dB));
    }
  } else {
    ComputeMulGradient<TGrad, TIn>(
        ndim_flattened,
        A_broadcast_dims_flattened.data(),
        B_broadcast_dims_flattened.data(),
        C_broadcast_dims_flattened.data(),
        dC,
        A,
        B,
        dA,
        dB,
        context);
  }

  return true;
}

REGISTER_CPU_OPERATOR(
    MulGradient,
    BinaryElementwiseGradientOp<
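The dim-flattening step above can be illustrated with a small standalone Python sketch (a hypothetical helper, not part of the patch) that applies the same merge rule to a pair of broadcast-dim vectors:

    # Hypothetical sketch of the dim-flattening rule above: merge two
    # consecutive dims when both inputs are broadcast at one of them, one
    # input is broadcast at both, or neither input is broadcast at either.
    def flatten_broadcast_dims(A_dims, B_dims):
        A_flat, B_flat = [A_dims[0]], [B_dims[0]]
        for A_new, B_new in zip(A_dims[1:], B_dims[1:]):
            A_old, B_old = A_flat[-1], B_flat[-1]
            if ((A_old == 1 and B_old == 1) or (A_new == 1 and B_new == 1) or
                    (A_old == 1 and A_new == 1) or (B_old == 1 and B_new == 1) or
                    (A_old > 1 and B_old > 1 and A_new > 1 and B_new > 1)):
                A_flat[-1] *= A_new
                B_flat[-1] *= B_new
            else:
                A_flat.append(A_new)
                B_flat.append(B_new)
        return A_flat, B_flat

For example, the benchmark case with A's dims [9508, 80] against B aligned to [9508, 1] stays two-dimensional, so the two-dim fast path applies, while A = [2, 3, 4] against B = [2, 3, 1] flattens to ([6, 4], [6, 1]).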
@@ -8,42 +8,6 @@

namespace caffe2 {

namespace {

template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

} // namespace

template <class Context>
struct MulFunctor {
  template <typename TIn, typename TOut>

@@ -79,51 +43,6 @@ struct MulFunctor {
      Context* context) const;
};

template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  const int ndim = std::max(A_dims.size(), B_dims.size());
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  ComputeMulGradient<TGrad, TIn>(
      ndim,
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data(),
      dC,
      A,
      B,
      dA,
      dB,
      context);
  return true;
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
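For reference, a numpy sketch (not part of the patch) of what this generic accumulation loop computes. It assumes A and B have already been padded to dC's rank with size-1 entries on their broadcast dims, which is the layout the broadcast-dims helpers produce:

    # Illustrative numpy mirror of the generic ComputeMulGradient loop.
    # Assumes A and B already have the same ndim as dC, with 1s on broadcast dims.
    import numpy as np

    def mul_gradient_reference(dC, A, B):
        dA = np.zeros_like(A, dtype=dC.dtype)
        dB = np.zeros_like(B, dtype=dC.dtype)
        for idx in np.ndindex(*dC.shape):
            a_idx = tuple(i if d > 1 else 0 for i, d in zip(idx, A.shape))
            b_idx = tuple(i if d > 1 else 0 for i, d in zip(idx, B.shape))
            dA[a_idx] += dC[idx] * B[b_idx]
            dB[b_idx] += dC[idx] * A[a_idx]
        return dA, dB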
caffe2/python/operator_test/mul_gradient_benchmark.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import numpy as np

from caffe2.python import core, workspace


def benchmark_mul_gradient(args):
    workspace.FeedBlob("dC", np.random.rand(args.m, args.n).astype(np.float32))
    workspace.FeedBlob("A", np.random.rand(args.m, args.n).astype(np.float32))
    workspace.FeedBlob("B", np.random.rand(args.m).astype(np.float32))

    net = core.Net("mynet")
    net.MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0)
    workspace.CreateNet(net)

    workspace.BenchmarkNet(net.Name(), 1, args.iteration, True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="benchmark for MulGradient.")
    parser.add_argument(
        '-m', type=int, default=9508,
        help="The number of rows of A")
    parser.add_argument(
        "-n", type=int, default=80,
        help="The number of columns of A")
    parser.add_argument(
        '-i', "--iteration", type=int, default=100,
        help="The number of iterations.")
    args, extra_args = parser.parse_known_args()
    core.GlobalInit(['python'] + extra_args)
    benchmark_mul_gradient(args)
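The benchmark can be run directly, e.g. python caffe2/python/operator_test/mul_gradient_benchmark.py -m 9508 -n 80 -i 100, using the flags defined above. If a quick correctness check is wanted alongside the timing, something along these lines could be added after workspace.CreateNet(net); this is a hypothetical addition, not part of the committed benchmark, comparing the operator's outputs against a numpy reference for the (m, n) x (m,) broadcast case:

    # Hypothetical sanity check (not in the committed benchmark): run the net
    # once and compare against a numpy reference for the axis=0 broadcast case.
    workspace.RunNetOnce(net)
    dC, A, B = (workspace.FetchBlob(b) for b in ("dC", "A", "B"))
    np.testing.assert_allclose(
        workspace.FetchBlob("dA"), dC * B[:, np.newaxis], rtol=1e-4)
    np.testing.assert_allclose(
        workspace.FetchBlob("dB"), (dC * A).sum(axis=1), rtol=1e-4)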