diff --git a/caffe2/operators/elementwise_mul_gradient_op.cc b/caffe2/operators/elementwise_mul_gradient_op.cc
index 50655043497..b08bb377a28 100644
--- a/caffe2/operators/elementwise_mul_gradient_op.cc
+++ b/caffe2/operators/elementwise_mul_gradient_op.cc
@@ -7,6 +7,219 @@
 
 namespace caffe2 {
 
+namespace {
+
+template <typename TGrad, typename TIn>
+void ComputeMulGradient(
+    const int ndim,
+    const int* A_dims,
+    const int* B_dims,
+    const int* C_dims,
+    const TGrad* dC,
+    const TIn* A,
+    const TIn* B,
+    TGrad* dA,
+    TGrad* dB,
+    CPUContext* context) {
+  const int A_size =
+      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
+  const int B_size =
+      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
+  const int C_size =
+      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
+  math::Set(A_size, TGrad(0), dA, context);
+  math::Set(B_size, TGrad(0), dB, context);
+  std::vector<int> index(ndim, 0);
+  for (int C_index = 0; C_index < C_size; ++C_index) {
+    const int A_index =
+        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
+    const int B_index =
+        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
+    dA[A_index] += dC[C_index] * B[B_index];
+    dB[B_index] += dC[C_index] * A[A_index];
+    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
+  }
+}
+
+// A : input not to broadcast whose size is common_size x broadcast_size
+// B : input to broadcast whose size is common_size
+void ComputeMulGradient(
+    const int common_size,
+    const int broadcast_size,
+    const float* dC,
+    const float* A,
+    const float* B,
+    float* dA,
+    float* dB,
+    CPUContext* context) {
+  for (int i = 0; i < common_size; ++i) {
+    caffe2::math::Scale(
+        broadcast_size,
+        B[i],
+        dC + i * broadcast_size,
+        dA + i * broadcast_size,
+        context);
+    caffe2::math::Dot(
+        broadcast_size,
+        dC + i * broadcast_size,
+        A + i * broadcast_size,
+        dB + i,
+        context);
+  }
+}
+
+void ComputeMulGradient(
+    const int size,
+    const float* dC,
+    const float* A,
+    const float* B,
+    float* dA,
+    float* dB) {
+  for (int i = 0; i < size; ++i) {
+    dA[i] = dC[i] * B[i];
+    dB[i] = dC[i] * A[i];
+  }
+}
+
+} // namespace
+
+template <>
+template <typename TGrad, typename TIn, typename TOut>
+bool MulFunctor<CPUContext>::Backward(
+    const std::vector<int>& A_dims,
+    const std::vector<int>& B_dims,
+    const TGrad* dC,
+    const TIn* A,
+    const TIn* B,
+    const TOut* /* C */,
+    TGrad* dA,
+    TGrad* dB,
+    CPUContext* context) const {
+  if (A_dims == B_dims) {
+    const int size = std::accumulate(
+        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
+    math::Mul(size, dC, B, dA, context);
+    math::Mul(size, dC, A, dB, context);
+    return true;
+  }
+
+  const int ndim = std::max(A_dims.size(), B_dims.size());
+  if (ndim == 0) {
+    return true;
+  }
+
+  std::vector<int> A_broadcast_dims(ndim);
+  std::vector<int> B_broadcast_dims(ndim);
+  std::vector<int> C_broadcast_dims(ndim);
+  math::utils::ComputeBroadcastBinaryOpDims(
+      A_dims.size(),
+      A_dims.data(),
+      B_dims.size(),
+      B_dims.data(),
+      A_broadcast_dims.data(),
+      B_broadcast_dims.data(),
+      C_broadcast_dims.data());
+
+  const int C_size = std::accumulate(
+      C_broadcast_dims.cbegin(),
+      C_broadcast_dims.cbegin() + ndim,
+      1,
+      std::multiplies<int>());
+  if (C_size == 0) {
+    const int A_size = std::accumulate(
+        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
+    const int B_size = std::accumulate(
+        B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
+    math::Set(A_size, TGrad(0), dA, context);
+    math::Set(B_size, TGrad(0), dB, context);
+    return true;
+  }
+
+  // Flatten dims as much as possible.
+  // We say A is broadcast at dim d if A_broadcast_dims[d] <= 1.
+  // Two consecutive dims d and d + 1 can be flattened if
+  //  A and B are broadcast at dim d, or
+  //  A and B are broadcast at dim d + 1, or
+  //  A is broadcast at dim d and d + 1, or
+  //  B is broadcast at dim d and d + 1, or
+  //  A and B are not broadcast at dim d and d + 1.
+  std::vector<int> A_broadcast_dims_flattened, B_broadcast_dims_flattened,
+      C_broadcast_dims_flattened;
+  A_broadcast_dims_flattened.reserve(ndim);
+  B_broadcast_dims_flattened.reserve(ndim);
+
+  A_broadcast_dims_flattened.push_back(A_broadcast_dims[0]);
+  B_broadcast_dims_flattened.push_back(B_broadcast_dims[0]);
+
+  for (int i = 1; i < ndim; ++i) {
+    int A_old = A_broadcast_dims_flattened.back();
+    int B_old = B_broadcast_dims_flattened.back();
+    int A_new = A_broadcast_dims[i];
+    int B_new = B_broadcast_dims[i];
+    if ((A_old == 1 && B_old == 1) || (A_new == 1 && B_new == 1) ||
+        (A_old == 1 && A_new == 1) || (B_old == 1 && B_new == 1) ||
+        (A_old > 1 && B_old > 1 && A_new > 1 && B_new > 1)) {
+      A_broadcast_dims_flattened.back() *= A_new;
+      B_broadcast_dims_flattened.back() *= B_new;
+    } else {
+      A_broadcast_dims_flattened.push_back(A_new);
+      B_broadcast_dims_flattened.push_back(B_new);
+    }
+  }
+
+  int ndim_flattened = A_broadcast_dims_flattened.size();
+  C_broadcast_dims_flattened.resize(ndim_flattened);
+  for (int i = 0; i < ndim_flattened; ++i) {
+    C_broadcast_dims_flattened[i] =
+        std::max(A_broadcast_dims_flattened[i], B_broadcast_dims_flattened[i]);
+  }
+
+  if (std::is_same<TGrad, float>::value && std::is_same<TIn, float>::value &&
+      ndim_flattened <= 2 &&
+      A_broadcast_dims_flattened[0] == B_broadcast_dims_flattened[0] &&
+      (ndim_flattened == 1 || A_broadcast_dims_flattened[1] <= 1 ||
+       B_broadcast_dims_flattened[1] <= 1)) {
+    if (ndim_flattened == 2) {
+      // Fast path when we have 2 flattened dimensions and the second
+      // dimension is broadcast.
+      bool broadcast_B = B_broadcast_dims_flattened[1] <= 1;
+      ComputeMulGradient(
+          C_broadcast_dims_flattened[0],
+          C_broadcast_dims_flattened[1],
+          reinterpret_cast<const float*>(dC),
+          reinterpret_cast<const float*>(broadcast_B ? A : B),
+          reinterpret_cast<const float*>(broadcast_B ? B : A),
+          reinterpret_cast<float*>(broadcast_B ? dA : dB),
+          reinterpret_cast<float*>(broadcast_B ? dB : dA),
+          context);
+    } else {
+      // Fast path when we have 1 flattened dimension.
+      assert(ndim_flattened == 1);
+      ComputeMulGradient(
+          C_broadcast_dims_flattened[0],
+          reinterpret_cast<const float*>(dC),
+          reinterpret_cast<const float*>(A),
+          reinterpret_cast<const float*>(B),
+          reinterpret_cast<float*>(dA),
+          reinterpret_cast<float*>(dB));
+    }
+  } else {
+    ComputeMulGradient(
+        ndim_flattened,
+        A_broadcast_dims_flattened.data(),
+        B_broadcast_dims_flattened.data(),
+        C_broadcast_dims_flattened.data(),
+        dC,
+        A,
+        B,
+        dA,
+        dB,
+        context);
+  }
+
+  return true;
+}
+
 REGISTER_CPU_OPERATOR(
     MulGradient,
     BinaryElementwiseGradientOp<
diff --git a/caffe2/operators/elementwise_mul_op.h b/caffe2/operators/elementwise_mul_op.h
index 6b31fe3684d..f1c42edc48b 100644
--- a/caffe2/operators/elementwise_mul_op.h
+++ b/caffe2/operators/elementwise_mul_op.h
@@ -8,42 +8,6 @@
 
 namespace caffe2 {
 
-namespace {
-
-template <typename TGrad, typename TIn>
-void ComputeMulGradient(
-    const int ndim,
-    const int* A_dims,
-    const int* B_dims,
-    const int* C_dims,
-    const TGrad* dC,
-    const TIn* A,
-    const TIn* B,
-    TGrad* dA,
-    TGrad* dB,
-    CPUContext* context) {
-  const int A_size =
-      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
-  const int B_size =
-      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
-  const int C_size =
-      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
-  math::Set(A_size, TGrad(0), dA, context);
-  math::Set(B_size, TGrad(0), dB, context);
-  std::vector<int> index(ndim, 0);
-  for (int C_index = 0; C_index < C_size; ++C_index) {
-    const int A_index =
-        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
-    const int B_index =
-        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
-    dA[A_index] += dC[C_index] * B[B_index];
-    dB[B_index] += dC[C_index] * A[A_index];
-    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
-  }
-}
-
-} // namespace
-
 template <class Context>
 struct MulFunctor {
   template <typename TIn, typename TOut>
@@ -79,51 +43,6 @@ struct MulFunctor {
       Context* context) const;
 };
 
-template <>
-template <typename TGrad, typename TIn, typename TOut>
-bool MulFunctor<CPUContext>::Backward(
-    const std::vector<int>& A_dims,
-    const std::vector<int>& B_dims,
-    const TGrad* dC,
-    const TIn* A,
-    const TIn* B,
-    const TOut* /* C */,
-    TGrad* dA,
-    TGrad* dB,
-    CPUContext* context) const {
-  if (A_dims == B_dims) {
-    const int size = std::accumulate(
-        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
-    math::Mul(size, dC, B, dA, context);
-    math::Mul(size, dC, A, dB, context);
-    return true;
-  }
-  const int ndim = std::max(A_dims.size(), B_dims.size());
-  std::vector<int> A_broadcast_dims(ndim);
-  std::vector<int> B_broadcast_dims(ndim);
-  std::vector<int> C_broadcast_dims(ndim);
-  math::utils::ComputeBroadcastBinaryOpDims(
-      A_dims.size(),
-      A_dims.data(),
-      B_dims.size(),
-      B_dims.data(),
-      A_broadcast_dims.data(),
-      B_broadcast_dims.data(),
-      C_broadcast_dims.data());
-  ComputeMulGradient(
-      ndim,
-      A_broadcast_dims.data(),
-      B_broadcast_dims.data(),
-      C_broadcast_dims.data(),
-      dC,
-      A,
-      B,
-      dA,
-      dB,
-      context);
-  return true;
-}
-
 } // namespace caffe2
 
 #endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
diff --git a/caffe2/python/operator_test/mul_gradient_benchmark.py b/caffe2/python/operator_test/mul_gradient_benchmark.py
new file mode 100644
index 00000000000..72167623940
--- /dev/null
+++ b/caffe2/python/operator_test/mul_gradient_benchmark.py
@@ -0,0 +1,38 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+import numpy as np
+
+from caffe2.python import core, workspace
+
+
+def benchmark_mul_gradient(args):
+    workspace.FeedBlob("dC", np.random.rand(args.m, args.n).astype(np.float32))
+    workspace.FeedBlob("A", np.random.rand(args.m, args.n).astype(np.float32))
+    workspace.FeedBlob("B", np.random.rand(args.m).astype(np.float32))
+
+    net = core.Net("mynet")
+    net.MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0)
+    workspace.CreateNet(net)
+
+    workspace.BenchmarkNet(net.Name(), 1, args.iteration, True)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="benchmark for MulGradient.")
+    parser.add_argument(
+        '-m', type=int, default=9508,
+        help="The number of rows of A")
+    parser.add_argument(
+        "-n", type=int, default=80,
+        help="The number of columns of A")
+    parser.add_argument(
+        '-i', "--iteration", type=int, default=100,
+        help="The number of iterations.")
+    args, extra_args = parser.parse_known_args()
+    core.GlobalInit(['python'] + extra_args)
+    benchmark_mul_gradient(args)
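
Note on what the benchmark exercises (illustrative, not part of the patch): with the blobs fed above, A is (m, n) and B is (m,) broadcast along the rows, so after dim flattening MulGradient hits the new 2-D fast path with common_size = m and broadcast_size = n. The NumPy sketch below spells out the gradients that path computes; mul_gradient_reference is a made-up name. The script itself can be run directly, e.g. python caffe2/python/operator_test/mul_gradient_benchmark.py -m 9508 -n 80 -i 100.

# Illustrative sketch only (not part of the patch): reference gradients for
# C = A * B where A is (m, n) and B is (m,) broadcast along the rows, the
# case the benchmark above feeds to MulGradient.
import numpy as np


def mul_gradient_reference(dC, A, B):
    # dA has A's shape: each row of dC is scaled by the broadcast value B[i],
    # which is what the 2-D fast path does with math::Scale.
    dA = dC * B[:, np.newaxis]
    # dB has B's shape: the broadcast axis is reduced with one dot product
    # per row, which is what the 2-D fast path does with math::Dot.
    dB = (dC * A).sum(axis=1)
    return dA, dB


if __name__ == "__main__":
    m, n = 4, 3
    A = np.random.rand(m, n).astype(np.float32)
    B = np.random.rand(m).astype(np.float32)
    dC = np.random.rand(m, n).astype(np.float32)
    dA, dB = mul_gradient_reference(dC, A, B)
    assert dA.shape == A.shape and dB.shape == B.shape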
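
The dim-flattening heuristic is what decides when that fast path applies, so the sketch below is a rough Python transcription of the flattening loop in MulFunctor<CPUContext>::Backward (again illustrative only; flatten_broadcast_dims is a made-up name), with one worked example where two trailing broadcast dims of B collapse into the 2-D Scale/Dot case.

# Illustrative sketch only (not part of the patch): a Python transcription of
# the dim-flattening loop in MulFunctor<CPUContext>::Backward.
def flatten_broadcast_dims(A_bdims, B_bdims):
    A_flat, B_flat = [A_bdims[0]], [B_bdims[0]]
    for A_new, B_new in zip(A_bdims[1:], B_bdims[1:]):
        A_old, B_old = A_flat[-1], B_flat[-1]
        # Merge dim i into the previous dim only when the merge cannot change
        # which operand is broadcast inside the merged block.
        if ((A_old == 1 and B_old == 1) or (A_new == 1 and B_new == 1) or
                (A_old == 1 and A_new == 1) or (B_old == 1 and B_new == 1) or
                (A_old > 1 and B_old > 1 and A_new > 1 and B_new > 1)):
            A_flat[-1] *= A_new
            B_flat[-1] *= B_new
        else:
            A_flat.append(A_new)
            B_flat.append(B_new)
    C_flat = [max(a, b) for a, b in zip(A_flat, B_flat)]
    return A_flat, B_flat, C_flat


# A of shape (8, 4, 5) multiplied by B broadcast as (8, 1, 1): the two
# trailing broadcast dims of B merge, leaving the 2-D case
# (common_size = 8, broadcast_size = 20) handled by the Scale/Dot fast path.
print(flatten_broadcast_dims([8, 4, 5], [8, 1, 1]))  # ([8, 20], [8, 1], [8, 20])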