Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19009
Move the definition of `MulFunctor<>::Backward()` into a header file.
Reviewed By: BIT-silence
Differential Revision: D14823230
fbshipit-source-id: 1efaec01863fcc02dcbe7e788d376e72f8564501
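Context for the move: `Backward()` is a member function template (on TGrad, TIn, TOut), so its definition must be visible at every point of instantiation; keeping it in a .cc file would prevent other translation units from instantiating it. A minimal sketch of the pattern, using illustrative stand-in names (Functor, CpuCtx) rather than the real Caffe2 types:

// header_sketch.h -- illustrative only; mirrors the layout of MulFunctor below.
struct CpuCtx {};

template <class Context>
struct Functor {
  template <typename T>
  bool Backward(const T* dC, T* dA) const; // declared for every Context
};

// Specialized for CpuCtx but still templated on T: the body remains a
// template, so it may (and must) live in the header for other translation
// units to instantiate Backward<float>, Backward<double>, and so on.
template <>
template <typename T>
bool Functor<CpuCtx>::Backward(const T* dC, T* dA) const {
  *dA = *dC; // trivial body, just for the sketch
  return true;
}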
130 lines
3.3 KiB
C++
#ifndef CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_

#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

#include "caffe2/operators/elementwise_ops.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

// Shared CPU helper for the broadcast case: accumulates dA += dC * B and
// dB += dC * A over the broadcast axes of C = A * B.
template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  // Walk every element of C; dimensions of size 1 in A or B map repeatedly
  // onto the same slot, which is how the gradients get reduced.
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}
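
// Worked example (illustrative, not in the original source): with
// A_dims = {2, 1}, B_dims = {1, 3}, C_dims = {2, 3}, the loop above visits
// all six C indices. Row i of C maps onto A[i] and column j onto B[j], so
// the += lines realize dA[i] = sum_j dC[i][j] * B[j] and
// dB[j] = sum_i dC[i][j] * A[i], i.e. reductions over the broadcast axes.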

} // namespace

template <class Context>
struct MulFunctor {
  // Forward: C = A * B with broadcasting, delegated to math::Mul.
  template <typename TIn, typename TOut>
  bool Forward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TIn* A,
      const TIn* B,
      TOut* C,
      Context* context) const {
    math::Mul(
        A_dims.size(),
        A_dims.data(),
        B_dims.size(),
        B_dims.data(),
        A,
        B,
        C,
        context);
    return true;
  }
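
  // Example (illustrative, not in the original source): with A_dims = {2, 1}
  // and B_dims = {3}, Forward() computes C[i][j] = A[i] * B[j] into a
  // buffer of shape {2, 3} via math::Mul's broadcast support.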

  // Backward: given dC, produce dA and dB. Declared for every Context; the
  // CPUContext specialization is defined below.
  template <typename TGrad, typename TIn, typename TOut>
  bool Backward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TGrad* dC_data,
      const TIn* A_data,
      const TIn* B_data,
      const TOut* C_data,
      TGrad* dA_data,
      TGrad* dB_data,
      Context* context) const;
};

// CPU specialization of Backward(). It is still a function template (on
// TGrad, TIn, TOut), so defining it here in the header is ODR-safe and lets
// any translation unit instantiate it.
template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    // Fast path: no broadcasting, so the gradients are plain elementwise
    // products: dA = dC * B and dB = dC * A.
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  // General path: expand A and B onto a common ndim-dimensional shape, then
  // accumulate the gradients with the broadcast-aware helper above.
  const int ndim = std::max(A_dims.size(), B_dims.size());
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  ComputeMulGradient<TGrad, TIn>(
      ndim,
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data(),
      dC,
      A,
      B,
      dA,
      dB,
      context);
  return true;
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
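
For reference, a minimal sketch of driving the CPU Backward() with a broadcast shape (the harness and shapes are illustrative, not part of the file):

#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/operators/elementwise_mul_op.h"

int main() {
  caffe2::CPUContext context;
  caffe2::MulFunctor<caffe2::CPUContext> functor;
  // A has shape {2, 1}, B has shape {3}; C broadcasts to {2, 3}.
  const std::vector<int> A_dims{2, 1};
  const std::vector<int> B_dims{3};
  const std::vector<float> A{1.0f, 2.0f};
  const std::vector<float> B{3.0f, 4.0f, 5.0f};
  const std::vector<float> dC(6, 1.0f); // upstream gradient, all ones
  std::vector<float> dA(2), dB(3);
  functor.Backward<float, float, float>(
      A_dims, B_dims, dC.data(), A.data(), B.data(),
      /* C (unused on CPU) */ nullptr, dA.data(), dB.data(), &context);
  // Expect dA = {12, 12} (each dC row dotted with B) and dB = {3, 3, 3}
  // (each dC column dotted with A).
  return 0;
}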