mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Avoid undefined symbol error when building AdIndexer LTO (#19009)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19009 Move the definition of `MulFunctor<>::Backward()` into a header file. Reviewed By: BIT-silence Differential Revision: D14823230 fbshipit-source-id: 1efaec01863fcc02dcbe7e788d376e72f8564501
This commit is contained in:
parent
ada10ad416
commit
20fc7b6ec7
|
|
@ -7,87 +7,6 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename TGrad, typename TIn>
|
||||
void ComputeMulGradient(
|
||||
const int ndim,
|
||||
const int* A_dims,
|
||||
const int* B_dims,
|
||||
const int* C_dims,
|
||||
const TGrad* dC,
|
||||
const TIn* A,
|
||||
const TIn* B,
|
||||
TGrad* dA,
|
||||
TGrad* dB,
|
||||
CPUContext* context) {
|
||||
const int A_size =
|
||||
std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
|
||||
const int B_size =
|
||||
std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
|
||||
const int C_size =
|
||||
std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
|
||||
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||
std::vector<int> index(ndim, 0);
|
||||
for (int C_index = 0; C_index < C_size; ++C_index) {
|
||||
const int A_index =
|
||||
math::utils::GetIndexFromDims(ndim, A_dims, index.data());
|
||||
const int B_index =
|
||||
math::utils::GetIndexFromDims(ndim, B_dims, index.data());
|
||||
dA[A_index] += dC[C_index] * B[B_index];
|
||||
dB[B_index] += dC[C_index] * A[A_index];
|
||||
math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
template <typename TGrad, typename TIn, typename TOut>
|
||||
bool MulFunctor<CPUContext>::Backward(
|
||||
const std::vector<int>& A_dims,
|
||||
const std::vector<int>& B_dims,
|
||||
const TGrad* dC,
|
||||
const TIn* A,
|
||||
const TIn* B,
|
||||
const TOut* /* C */,
|
||||
TGrad* dA,
|
||||
TGrad* dB,
|
||||
CPUContext* context) const {
|
||||
if (A_dims == B_dims) {
|
||||
const int size = std::accumulate(
|
||||
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||
math::Mul(size, dC, B, dA, context);
|
||||
math::Mul(size, dC, A, dB, context);
|
||||
return true;
|
||||
}
|
||||
const int ndim = std::max(A_dims.size(), B_dims.size());
|
||||
std::vector<int> A_broadcast_dims(ndim);
|
||||
std::vector<int> B_broadcast_dims(ndim);
|
||||
std::vector<int> C_broadcast_dims(ndim);
|
||||
math::utils::ComputeBroadcastBinaryOpDims(
|
||||
A_dims.size(),
|
||||
A_dims.data(),
|
||||
B_dims.size(),
|
||||
B_dims.data(),
|
||||
A_broadcast_dims.data(),
|
||||
B_broadcast_dims.data(),
|
||||
C_broadcast_dims.data());
|
||||
ComputeMulGradient<TGrad, TIn>(
|
||||
ndim,
|
||||
A_broadcast_dims.data(),
|
||||
B_broadcast_dims.data(),
|
||||
C_broadcast_dims.data(),
|
||||
dC,
|
||||
A,
|
||||
B,
|
||||
dA,
|
||||
dB,
|
||||
context);
|
||||
return true;
|
||||
}
|
||||
|
||||
REGISTER_CPU_OPERATOR(
|
||||
MulGradient,
|
||||
BinaryElementwiseGradientOp<
|
||||
|
|
|
|||
|
|
@ -8,6 +8,42 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename TGrad, typename TIn>
|
||||
void ComputeMulGradient(
|
||||
const int ndim,
|
||||
const int* A_dims,
|
||||
const int* B_dims,
|
||||
const int* C_dims,
|
||||
const TGrad* dC,
|
||||
const TIn* A,
|
||||
const TIn* B,
|
||||
TGrad* dA,
|
||||
TGrad* dB,
|
||||
CPUContext* context) {
|
||||
const int A_size =
|
||||
std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
|
||||
const int B_size =
|
||||
std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
|
||||
const int C_size =
|
||||
std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
|
||||
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||
std::vector<int> index(ndim, 0);
|
||||
for (int C_index = 0; C_index < C_size; ++C_index) {
|
||||
const int A_index =
|
||||
math::utils::GetIndexFromDims(ndim, A_dims, index.data());
|
||||
const int B_index =
|
||||
math::utils::GetIndexFromDims(ndim, B_dims, index.data());
|
||||
dA[A_index] += dC[C_index] * B[B_index];
|
||||
dB[B_index] += dC[C_index] * A[A_index];
|
||||
math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <class Context>
|
||||
struct MulFunctor {
|
||||
template <typename TIn, typename TOut>
|
||||
|
|
@ -43,6 +79,51 @@ struct MulFunctor {
|
|||
Context* context) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
template <typename TGrad, typename TIn, typename TOut>
|
||||
bool MulFunctor<CPUContext>::Backward(
|
||||
const std::vector<int>& A_dims,
|
||||
const std::vector<int>& B_dims,
|
||||
const TGrad* dC,
|
||||
const TIn* A,
|
||||
const TIn* B,
|
||||
const TOut* /* C */,
|
||||
TGrad* dA,
|
||||
TGrad* dB,
|
||||
CPUContext* context) const {
|
||||
if (A_dims == B_dims) {
|
||||
const int size = std::accumulate(
|
||||
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||
math::Mul(size, dC, B, dA, context);
|
||||
math::Mul(size, dC, A, dB, context);
|
||||
return true;
|
||||
}
|
||||
const int ndim = std::max(A_dims.size(), B_dims.size());
|
||||
std::vector<int> A_broadcast_dims(ndim);
|
||||
std::vector<int> B_broadcast_dims(ndim);
|
||||
std::vector<int> C_broadcast_dims(ndim);
|
||||
math::utils::ComputeBroadcastBinaryOpDims(
|
||||
A_dims.size(),
|
||||
A_dims.data(),
|
||||
B_dims.size(),
|
||||
B_dims.data(),
|
||||
A_broadcast_dims.data(),
|
||||
B_broadcast_dims.data(),
|
||||
C_broadcast_dims.data());
|
||||
ComputeMulGradient<TGrad, TIn>(
|
||||
ndim,
|
||||
A_broadcast_dims.data(),
|
||||
B_broadcast_dims.data(),
|
||||
C_broadcast_dims.data(),
|
||||
dC,
|
||||
A,
|
||||
B,
|
||||
dA,
|
||||
dB,
|
||||
context);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user