Avoid undefined symbol error when building AdIndexer with LTO (#19009)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19009

Move the definition of `MulFunctor<>::Backward()` into a header file, so that every translation unit that instantiates the member template can see its definition.
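To see why the definition has to live in the header, here is a minimal sketch of the failure mode (file, type, and function names are hypothetical, not the actual AdIndexer sources). A member function template that is defined only in one .cc file is never instantiated for callers in other translation units, so the linker finds no definition:

// mul_functor.h -- declaration only (hypothetical)
struct CPUContext; // stand-in for caffe2::CPUContext

template <class Context>
struct MulFunctor {
  template <typename T>
  bool Backward(const T* dC, T* dA, Context* context) const;
};

// mul_functor.cc -- definition hidden in a single translation unit
#include "mul_functor.h"

template <>
template <typename T>
bool MulFunctor<CPUContext>::Backward(const T* dC, T* dA, CPUContext*) const {
  *dA = *dC; // body is irrelevant; only where it lives matters
  return true;
}
// Nothing in this file calls Backward<float>, and there is no explicit
// instantiation, so mul_functor.o contains no Backward<float> symbol.

// caller.cc -- some other translation unit linked into the binary
#include "mul_functor.h"

bool CallBackward(const MulFunctor<CPUContext>& f, const float* dC, float* dA) {
  // Compiles fine against the declaration alone; the link then fails
  // with an undefined symbol for MulFunctor<CPUContext>::Backward<float>.
  return f.Backward(dC, dA, nullptr);
}

Defining the template in the header, as this diff does, lets every translation unit instantiate Backward for whatever types it needs; the alternative would be explicit instantiations in the .cc file.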

Reviewed By: BIT-silence

Differential Revision: D14823230

fbshipit-source-id: 1efaec01863fcc02dcbe7e788d376e72f8564501
Author: Mark Santaniello · 2019-04-15 23:40:21 -07:00 · committed by Facebook Github Bot
parent ada10ad416 · commit 20fc7b6ec7
2 changed files with 81 additions and 81 deletions

caffe2/operators/elementwise_mul_gradient_op.cc

@@ -7,87 +7,6 @@

namespace caffe2 {

namespace {

template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

} // namespace

template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  const int ndim = std::max(A_dims.size(), B_dims.size());
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  ComputeMulGradient<TGrad, TIn>(
      ndim,
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data(),
      dC,
      A,
      B,
      dA,
      dB,
      context);
  return true;
}

REGISTER_CPU_OPERATOR(
    MulGradient,
    BinaryElementwiseGradientOp<

caffe2/operators/elementwise_mul_op.h

@@ -8,6 +8,42 @@

namespace caffe2 {

namespace {

template <typename TGrad, typename TIn>
void ComputeMulGradient(
    const int ndim,
    const int* A_dims,
    const int* B_dims,
    const int* C_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) {
  const int A_size =
      std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
  const int B_size =
      std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
  const int C_size =
      std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
  std::vector<int> index(ndim, 0);
  for (int C_index = 0; C_index < C_size; ++C_index) {
    const int A_index =
        math::utils::GetIndexFromDims(ndim, A_dims, index.data());
    const int B_index =
        math::utils::GetIndexFromDims(ndim, B_dims, index.data());
    dA[A_index] += dC[C_index] * B[B_index];
    dB[B_index] += dC[C_index] * A[A_index];
    math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
  }
}

} // namespace

template <class Context>
struct MulFunctor {
  template <typename TIn, typename TOut>
@@ -43,6 +79,51 @@ struct MulFunctor {
      Context* context) const;
};

template <>
template <typename TGrad, typename TIn, typename TOut>
bool MulFunctor<CPUContext>::Backward(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    const TGrad* dC,
    const TIn* A,
    const TIn* B,
    const TOut* /* C */,
    TGrad* dA,
    TGrad* dB,
    CPUContext* context) const {
  if (A_dims == B_dims) {
    const int size = std::accumulate(
        A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
    math::Mul(size, dC, B, dA, context);
    math::Mul(size, dC, A, dB, context);
    return true;
  }
  const int ndim = std::max(A_dims.size(), B_dims.size());
  std::vector<int> A_broadcast_dims(ndim);
  std::vector<int> B_broadcast_dims(ndim);
  std::vector<int> C_broadcast_dims(ndim);
  math::utils::ComputeBroadcastBinaryOpDims(
      A_dims.size(),
      A_dims.data(),
      B_dims.size(),
      B_dims.data(),
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data());
  ComputeMulGradient<TGrad, TIn>(
      ndim,
      A_broadcast_dims.data(),
      B_broadcast_dims.data(),
      C_broadcast_dims.data(),
      dC,
      A,
      B,
      dA,
      dB,
      context);
  return true;
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
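
As a reference for what the relocated code computes: for C = A * B with broadcasting, the gradients are dA = dC * B and dB = dC * A, with each product accumulated back over any broadcast dimensions. The standalone sketch below (plain loops, no caffe2 utilities; shapes and values are made up for illustration) walks the same indices as ComputeMulGradient for A of shape {2, 3} and B of shape {1, 3}:

#include <cstdio>
#include <vector>

int main() {
  const int rows = 2, cols = 3;
  std::vector<float> A = {1, 2, 3, 4, 5, 6}; // 2x3
  std::vector<float> B = {10, 20, 30};       // 1x3, broadcast over rows
  std::vector<float> dC(rows * cols, 1.0f);  // upstream gradient, all ones
  std::vector<float> dA(rows * cols, 0.0f);  // zero-filled, like math::Set
  std::vector<float> dB(cols, 0.0f);

  // One pass over C's index space, mirroring the loop over C_index:
  // each position contributes to exactly one slot of dA and one of dB.
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      const int C_index = i * cols + j; // A shares C's full shape here
      const int B_index = j;            // B's row dimension is broadcast
      dA[C_index] += dC[C_index] * B[B_index];
      dB[B_index] += dC[C_index] * A[C_index];
    }
  }

  // dA = {10, 20, 30, 10, 20, 30}; dB = {1+4, 2+5, 3+6} = {5, 7, 9}
  for (float v : dB) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}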