mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
optimize MulGradient for common shapes (#19705)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19705 Optimizing for a case when there's a consecutive dims that are not broadcasted followed by another consecutive dims that are broadcasted. For example, MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0) where A.shape == dC.shape == [9508, 80] and B.shape == [80] . Test Plan: In SKL T6, Running mul_gradient_benchmark without this optimization Operator #0 (dA, MulGradient) 11.9119 ms/iter After this optimization, Operator #0 (dA, MulGradient) 0.672759 ms/iter Need to land D15291800 before to fix the unit test error Reviewed By: dmudiger Differential Revision: D15075415 fbshipit-source-id: 0f97be17cf8f1dacbafa34cd637fb8bc1c5e5387
This commit is contained in:
parent
a53b39f09d
commit
4a751dfc20
|
|
@ -7,6 +7,219 @@
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename TGrad, typename TIn>
|
||||||
|
void ComputeMulGradient(
|
||||||
|
const int ndim,
|
||||||
|
const int* A_dims,
|
||||||
|
const int* B_dims,
|
||||||
|
const int* C_dims,
|
||||||
|
const TGrad* dC,
|
||||||
|
const TIn* A,
|
||||||
|
const TIn* B,
|
||||||
|
TGrad* dA,
|
||||||
|
TGrad* dB,
|
||||||
|
CPUContext* context) {
|
||||||
|
const int A_size =
|
||||||
|
std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
|
||||||
|
const int B_size =
|
||||||
|
std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
|
||||||
|
const int C_size =
|
||||||
|
std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
|
||||||
|
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||||
|
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||||
|
std::vector<int> index(ndim, 0);
|
||||||
|
for (int C_index = 0; C_index < C_size; ++C_index) {
|
||||||
|
const int A_index =
|
||||||
|
math::utils::GetIndexFromDims(ndim, A_dims, index.data());
|
||||||
|
const int B_index =
|
||||||
|
math::utils::GetIndexFromDims(ndim, B_dims, index.data());
|
||||||
|
dA[A_index] += dC[C_index] * B[B_index];
|
||||||
|
dB[B_index] += dC[C_index] * A[A_index];
|
||||||
|
math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A : input not to broadcast whose size is common_size x broadcast_size
|
||||||
|
// B : input to broadcast whose size is common_size
|
||||||
|
void ComputeMulGradient(
|
||||||
|
const int common_size,
|
||||||
|
const int broadcast_size,
|
||||||
|
const float* dC,
|
||||||
|
const float* A,
|
||||||
|
const float* B,
|
||||||
|
float* dA,
|
||||||
|
float* dB,
|
||||||
|
CPUContext* context) {
|
||||||
|
for (int i = 0; i < common_size; ++i) {
|
||||||
|
caffe2::math::Scale(
|
||||||
|
broadcast_size,
|
||||||
|
B[i],
|
||||||
|
dC + i * broadcast_size,
|
||||||
|
dA + i * broadcast_size,
|
||||||
|
context);
|
||||||
|
caffe2::math::Dot(
|
||||||
|
broadcast_size,
|
||||||
|
dC + i * broadcast_size,
|
||||||
|
A + i * broadcast_size,
|
||||||
|
dB + i,
|
||||||
|
context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputeMulGradient(
|
||||||
|
const int size,
|
||||||
|
const float* dC,
|
||||||
|
const float* A,
|
||||||
|
const float* B,
|
||||||
|
float* dA,
|
||||||
|
float* dB) {
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
dA[i] = dC[i] * B[i];
|
||||||
|
dB[i] = dC[i] * A[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <>
|
||||||
|
template <typename TGrad, typename TIn, typename TOut>
|
||||||
|
bool MulFunctor<CPUContext>::Backward(
|
||||||
|
const std::vector<int>& A_dims,
|
||||||
|
const std::vector<int>& B_dims,
|
||||||
|
const TGrad* dC,
|
||||||
|
const TIn* A,
|
||||||
|
const TIn* B,
|
||||||
|
const TOut* /* C */,
|
||||||
|
TGrad* dA,
|
||||||
|
TGrad* dB,
|
||||||
|
CPUContext* context) const {
|
||||||
|
if (A_dims == B_dims) {
|
||||||
|
const int size = std::accumulate(
|
||||||
|
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||||
|
math::Mul(size, dC, B, dA, context);
|
||||||
|
math::Mul(size, dC, A, dB, context);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ndim = std::max(A_dims.size(), B_dims.size());
|
||||||
|
if (ndim == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> A_broadcast_dims(ndim);
|
||||||
|
std::vector<int> B_broadcast_dims(ndim);
|
||||||
|
std::vector<int> C_broadcast_dims(ndim);
|
||||||
|
math::utils::ComputeBroadcastBinaryOpDims(
|
||||||
|
A_dims.size(),
|
||||||
|
A_dims.data(),
|
||||||
|
B_dims.size(),
|
||||||
|
B_dims.data(),
|
||||||
|
A_broadcast_dims.data(),
|
||||||
|
B_broadcast_dims.data(),
|
||||||
|
C_broadcast_dims.data());
|
||||||
|
|
||||||
|
const int C_size = std::accumulate(
|
||||||
|
C_broadcast_dims.cbegin(),
|
||||||
|
C_broadcast_dims.cbegin() + ndim,
|
||||||
|
1,
|
||||||
|
std::multiplies<int>());
|
||||||
|
if (C_size == 0) {
|
||||||
|
const int A_size = std::accumulate(
|
||||||
|
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||||
|
const int B_size = std::accumulate(
|
||||||
|
B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
|
||||||
|
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||||
|
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten dims as much as possible
|
||||||
|
// We call A is broadcasted at dim d if A_broadcast_dims[d] <= 1
|
||||||
|
// Two consecutive dims d and d+1 can be flattened if
|
||||||
|
// A and B are broadcasted at dim d, or
|
||||||
|
// A and B are broadcasted at dim d + 1, or
|
||||||
|
// A is broadcasted at dim d and d + 1, or
|
||||||
|
// B is broadcasted at dim d and d + 1, or
|
||||||
|
// A and B are not broadcasted at dim d and d + 1
|
||||||
|
std::vector<int> A_broadcast_dims_flattened, B_broadcast_dims_flattened,
|
||||||
|
C_broadcast_dims_flattened;
|
||||||
|
A_broadcast_dims_flattened.reserve(ndim);
|
||||||
|
B_broadcast_dims_flattened.reserve(ndim);
|
||||||
|
|
||||||
|
A_broadcast_dims_flattened.push_back(A_broadcast_dims[0]);
|
||||||
|
B_broadcast_dims_flattened.push_back(B_broadcast_dims[0]);
|
||||||
|
|
||||||
|
for (int i = 1; i < ndim; ++i) {
|
||||||
|
int A_old = A_broadcast_dims_flattened.back();
|
||||||
|
int B_old = B_broadcast_dims_flattened.back();
|
||||||
|
int A_new = A_broadcast_dims[i];
|
||||||
|
int B_new = B_broadcast_dims[i];
|
||||||
|
if ((A_old == 1 && B_old == 1) || (A_new == 1 && B_new == 1) ||
|
||||||
|
(A_old == 1 && A_new == 1) || (B_old == 1 && B_new == 1) ||
|
||||||
|
(A_old > 1 && B_old > 1 && A_new > 1 && B_new > 1)) {
|
||||||
|
A_broadcast_dims_flattened.back() *= A_new;
|
||||||
|
B_broadcast_dims_flattened.back() *= B_new;
|
||||||
|
} else {
|
||||||
|
A_broadcast_dims_flattened.push_back(A_new);
|
||||||
|
B_broadcast_dims_flattened.push_back(B_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int ndim_flattened = A_broadcast_dims_flattened.size();
|
||||||
|
C_broadcast_dims_flattened.resize(ndim_flattened);
|
||||||
|
for (int i = 0; i < ndim_flattened; ++i) {
|
||||||
|
C_broadcast_dims_flattened[i] =
|
||||||
|
std::max(A_broadcast_dims_flattened[i], B_broadcast_dims_flattened[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::is_same<TGrad, float>::value && std::is_same<TIn, float>::value &&
|
||||||
|
ndim_flattened <= 2 &&
|
||||||
|
A_broadcast_dims_flattened[0] == B_broadcast_dims_flattened[0] &&
|
||||||
|
(ndim_flattened == 1 || A_broadcast_dims_flattened[1] <= 1 ||
|
||||||
|
B_broadcast_dims_flattened[1] <= 1)) {
|
||||||
|
if (ndim_flattened == 2) {
|
||||||
|
// fast path when we have 2 flattened dimensions and the second dimension
|
||||||
|
// is broadcasted.
|
||||||
|
bool broadcast_B = B_broadcast_dims_flattened[1] <= 1;
|
||||||
|
ComputeMulGradient(
|
||||||
|
C_broadcast_dims_flattened[0],
|
||||||
|
C_broadcast_dims_flattened[1],
|
||||||
|
reinterpret_cast<const float*>(dC),
|
||||||
|
reinterpret_cast<const float*>(broadcast_B ? A : B),
|
||||||
|
reinterpret_cast<const float*>(broadcast_B ? B : A),
|
||||||
|
reinterpret_cast<float*>(broadcast_B ? dA : dB),
|
||||||
|
reinterpret_cast<float*>(broadcast_B ? dB : dA),
|
||||||
|
context);
|
||||||
|
} else {
|
||||||
|
// fast path when we have 1 flattened dimension
|
||||||
|
assert(ndim_flattened == 1);
|
||||||
|
ComputeMulGradient(
|
||||||
|
C_broadcast_dims_flattened[0],
|
||||||
|
reinterpret_cast<const float*>(dC),
|
||||||
|
reinterpret_cast<const float*>(A),
|
||||||
|
reinterpret_cast<const float*>(B),
|
||||||
|
reinterpret_cast<float*>(dA),
|
||||||
|
reinterpret_cast<float*>(dB));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ComputeMulGradient<TGrad, TIn>(
|
||||||
|
ndim_flattened,
|
||||||
|
A_broadcast_dims_flattened.data(),
|
||||||
|
B_broadcast_dims_flattened.data(),
|
||||||
|
C_broadcast_dims_flattened.data(),
|
||||||
|
dC,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
dA,
|
||||||
|
dB,
|
||||||
|
context);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_CPU_OPERATOR(
|
REGISTER_CPU_OPERATOR(
|
||||||
MulGradient,
|
MulGradient,
|
||||||
BinaryElementwiseGradientOp<
|
BinaryElementwiseGradientOp<
|
||||||
|
|
|
||||||
|
|
@ -8,42 +8,6 @@
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename TGrad, typename TIn>
|
|
||||||
void ComputeMulGradient(
|
|
||||||
const int ndim,
|
|
||||||
const int* A_dims,
|
|
||||||
const int* B_dims,
|
|
||||||
const int* C_dims,
|
|
||||||
const TGrad* dC,
|
|
||||||
const TIn* A,
|
|
||||||
const TIn* B,
|
|
||||||
TGrad* dA,
|
|
||||||
TGrad* dB,
|
|
||||||
CPUContext* context) {
|
|
||||||
const int A_size =
|
|
||||||
std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
|
|
||||||
const int B_size =
|
|
||||||
std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
|
|
||||||
const int C_size =
|
|
||||||
std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
|
|
||||||
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
|
||||||
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
|
||||||
std::vector<int> index(ndim, 0);
|
|
||||||
for (int C_index = 0; C_index < C_size; ++C_index) {
|
|
||||||
const int A_index =
|
|
||||||
math::utils::GetIndexFromDims(ndim, A_dims, index.data());
|
|
||||||
const int B_index =
|
|
||||||
math::utils::GetIndexFromDims(ndim, B_dims, index.data());
|
|
||||||
dA[A_index] += dC[C_index] * B[B_index];
|
|
||||||
dB[B_index] += dC[C_index] * A[A_index];
|
|
||||||
math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <class Context>
|
template <class Context>
|
||||||
struct MulFunctor {
|
struct MulFunctor {
|
||||||
template <typename TIn, typename TOut>
|
template <typename TIn, typename TOut>
|
||||||
|
|
@ -79,51 +43,6 @@ struct MulFunctor {
|
||||||
Context* context) const;
|
Context* context) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
template <typename TGrad, typename TIn, typename TOut>
|
|
||||||
bool MulFunctor<CPUContext>::Backward(
|
|
||||||
const std::vector<int>& A_dims,
|
|
||||||
const std::vector<int>& B_dims,
|
|
||||||
const TGrad* dC,
|
|
||||||
const TIn* A,
|
|
||||||
const TIn* B,
|
|
||||||
const TOut* /* C */,
|
|
||||||
TGrad* dA,
|
|
||||||
TGrad* dB,
|
|
||||||
CPUContext* context) const {
|
|
||||||
if (A_dims == B_dims) {
|
|
||||||
const int size = std::accumulate(
|
|
||||||
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
|
||||||
math::Mul(size, dC, B, dA, context);
|
|
||||||
math::Mul(size, dC, A, dB, context);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
const int ndim = std::max(A_dims.size(), B_dims.size());
|
|
||||||
std::vector<int> A_broadcast_dims(ndim);
|
|
||||||
std::vector<int> B_broadcast_dims(ndim);
|
|
||||||
std::vector<int> C_broadcast_dims(ndim);
|
|
||||||
math::utils::ComputeBroadcastBinaryOpDims(
|
|
||||||
A_dims.size(),
|
|
||||||
A_dims.data(),
|
|
||||||
B_dims.size(),
|
|
||||||
B_dims.data(),
|
|
||||||
A_broadcast_dims.data(),
|
|
||||||
B_broadcast_dims.data(),
|
|
||||||
C_broadcast_dims.data());
|
|
||||||
ComputeMulGradient<TGrad, TIn>(
|
|
||||||
ndim,
|
|
||||||
A_broadcast_dims.data(),
|
|
||||||
B_broadcast_dims.data(),
|
|
||||||
C_broadcast_dims.data(),
|
|
||||||
dC,
|
|
||||||
A,
|
|
||||||
B,
|
|
||||||
dA,
|
|
||||||
dB,
|
|
||||||
context);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace caffe2
|
} // namespace caffe2
|
||||||
|
|
||||||
#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
|
#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
|
||||||
|
|
|
||||||
38
caffe2/python/operator_test/mul_gradient_benchmark.py
Normal file
38
caffe2/python/operator_test/mul_gradient_benchmark.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from caffe2.python import core, workspace
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_mul_gradient(args):
|
||||||
|
workspace.FeedBlob("dC", np.random.rand(args.m, args.n).astype(np.float32))
|
||||||
|
workspace.FeedBlob("A", np.random.rand(args.m, args.n).astype(np.float32))
|
||||||
|
workspace.FeedBlob("B", np.random.rand(args.m).astype(np.float32))
|
||||||
|
|
||||||
|
net = core.Net("mynet")
|
||||||
|
net.MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0)
|
||||||
|
workspace.CreateNet(net)
|
||||||
|
|
||||||
|
workspace.BenchmarkNet(net.Name(), 1, args.iteration, True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="benchmark for MulGradient.")
|
||||||
|
parser.add_argument(
|
||||||
|
'-m', type=int, default=9508,
|
||||||
|
help="The number of rows of A")
|
||||||
|
parser.add_argument(
|
||||||
|
"-n", type=int, default=80,
|
||||||
|
help="The number of columns of A")
|
||||||
|
parser.add_argument(
|
||||||
|
'-i', "--iteration", type=int, default=100,
|
||||||
|
help="The number of iterations.")
|
||||||
|
args, extra_args = parser.parse_known_args()
|
||||||
|
core.GlobalInit(['python'] + extra_args)
|
||||||
|
benchmark_mul_gradient(args)
|
||||||
Loading…
Reference in New Issue
Block a user