mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO. Manually edited some references of the removed `CAFFE2_API`: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
187 lines
4.5 KiB
C++
187 lines
4.5 KiB
C++
#ifndef CAFFE2_UTILS_MATH_UTILS_H_
|
|
#define CAFFE2_UTILS_MATH_UTILS_H_
|
|
|
|
#include <vector>
|
|
|
|
#include "caffe2/core/common.h"
|
|
|
|
// MATH_UTILS_DECL is the declaration prefix shared by the small inline
// helpers in this header. If any of the CUDA / HIP compiler macros tested
// below is defined, the helpers are additionally marked
// __host__ __device__; otherwise they are plain inline functions.
#if defined(__CUDA_ARCH__) ||          \
    defined(__HIP_DEVICE_COMPILE__) || \
    defined(__HIP__) ||                \
    (defined(__clang__) && defined(__CUDA__))
#define MATH_UTILS_DECL inline __host__ __device__
#else
#define MATH_UTILS_DECL inline
#endif
|
|
|
|
namespace caffe2 {
|
|
namespace math {
|
|
|
|
namespace utils {
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Not(const T x) {
|
|
return !x;
|
|
}
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Sign(const T x) {
|
|
return x > 0 ? T(1) : (x < 0 ? T(-1) : T(0));
|
|
}
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Negate(const T x) {
|
|
return -x;
|
|
}
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Inv(const T x) {
|
|
return T(1) / x;
|
|
}
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Square(const T x) {
|
|
return x * x;
|
|
}
|
|
|
|
template <typename T>
|
|
MATH_UTILS_DECL T Cube(const T x) {
|
|
return x * x * x;
|
|
}
|
|
|
|
// Function uses casting from int to unsigned to compare if value of
|
|
// parameter a is greater or equal to zero and lower than value of
|
|
// parameter b. The b parameter is of type signed and is always
|
|
// positive,
|
|
// therefore its value is always lower than 0x800... where casting
|
|
// negative value of a parameter converts it to value higher than
|
|
// 0x800...
|
|
// The casting allows to use one condition instead of two.
|
|
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
|
|
return static_cast<unsigned int>(a) < static_cast<unsigned int>(b);
|
|
}
|
|
|
|
// Increase the index digits by one based on dims.
|
|
template <typename TIndex>
|
|
TORCH_API void
|
|
IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index);
|
|
|
|
// Get index value from dims and index digits.
|
|
template <typename TIndex>
|
|
TORCH_API TIndex
|
|
GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index);
|
|
|
|
// Checks if the input permutation is an identity permutation;
|
|
TORCH_API bool IsIdentityPermutation(const int n, const int* perm);
|
|
|
|
// Checks a reduction described by the input shape `X_dims` and the output
// shape `Y_dims`, both arrays of `ndim` entries, and returns the verdict.
// NOTE(review): declaration only -- the exact compatibility rule between
// X_dims and Y_dims is defined in the implementation file; confirm there.
TORCH_API bool CheckReduceDims(
    const int ndim,
    const int* X_dims,
    const int* Y_dims);
|
|
|
|
// Tests whether reducing shape `X_dims` to shape `Y_dims` (both of length
// `ndim`) is a row-wise reduction. `rows` and `cols` are output pointers.
// NOTE(review): presumably *rows / *cols receive the sizes of the
// equivalent 2-D problem when the result is true; what they hold on a
// false result is not visible from this declaration -- confirm.
TORCH_API bool IsRowwiseReduce(
    const int ndim, const int* X_dims, const int* Y_dims,
    int* rows, int* cols);
|
|
|
|
// Tests whether reducing shape `X_dims` to shape `Y_dims` (both of length
// `ndim`) is a column-wise reduction. `rows` and `cols` are output
// pointers.
// NOTE(review): presumably *rows / *cols receive the sizes of the
// equivalent 2-D problem when the result is true; what they hold on a
// false result is not visible from this declaration -- confirm.
TORCH_API bool IsColwiseReduce(
    const int ndim, const int* X_dims, const int* Y_dims,
    int* rows, int* cols);
|
|
|
|
// Tests whether reducing shape `X_dims` to shape `Y_dims` (both of length
// `ndim`) is a "both ends" reduction. `pre`, `mid` and `nxt` are output
// pointers.
// NOTE(review): presumably the shape is viewed as (pre, mid, nxt) with the
// two outer parts reduced; the definition is not in this header -- confirm
// both the meaning and the output values on a false result.
TORCH_API bool IsBothEndsReduce(
    const int ndim, const int* X_dims, const int* Y_dims,
    int* pre, int* mid, int* nxt);
|
|
|
|
// Computest the broadcast binary operation dims.
|
|
template <typename TIndex>
|
|
TORCH_API void ComputeBroadcastBinaryOpDims(
|
|
const int A_ndim,
|
|
const TIndex* A_dims,
|
|
const int B_ndim,
|
|
const TIndex* B_dims,
|
|
TIndex* A_broadcast_dims,
|
|
TIndex* B_broadcast_dims,
|
|
TIndex* C_broadcast_dims);
|
|
|
|
// Tests whether a binary op on operands shaped `A_dims` and `B_dims` (both
// of length `ndim`) is a row-wise broadcast. `rows`, `cols` and
// `broadcast_1st` are output pointers.
// NOTE(review): presumably *broadcast_1st tells whether the first operand
// (A) is the broadcast one and *rows / *cols give the 2-D sizes; declaration
// only -- confirm.
TORCH_API bool IsRowwiseBroadcastBinaryOp(
    const int ndim, const int* A_dims, const int* B_dims,
    int* rows, int* cols,
    bool* broadcast_1st);
|
|
|
|
// Tests whether a binary op on operands shaped `A_dims` and `B_dims` (both
// of length `ndim`) is a column-wise broadcast. `rows`, `cols` and
// `broadcast_1st` are output pointers.
// NOTE(review): presumably *broadcast_1st tells whether the first operand
// (A) is the broadcast one and *rows / *cols give the 2-D sizes; declaration
// only -- confirm.
TORCH_API bool IsColwiseBroadcastBinaryOp(
    const int ndim, const int* A_dims, const int* B_dims,
    int* rows, int* cols,
    bool* broadcast_1st);
|
|
|
|
// Tests whether a binary op on operands shaped `A_dims` and `B_dims` (both
// of length `ndim`) is a "both ends" broadcast. `pre`, `mid`, `nxt` and
// `broadcast_1st` are output pointers.
// NOTE(review): presumably the shape is viewed as (pre, mid, nxt) and
// *broadcast_1st tells whether the first operand (A) is the broadcast one;
// declaration only -- confirm.
TORCH_API bool IsBothEndsBroadcastBinaryOp(
    const int ndim, const int* A_dims, const int* B_dims,
    int* pre, int* mid, int* nxt,
    bool* broadcast_1st);
|
|
|
|
// Tests whether the transpose described by `axes` (an array of `ndim`
// entries) is a batched 2-D transpose.
// NOTE(review): presumably this means only the last two axes are swapped;
// the definition is not in this header -- confirm.
TORCH_API bool IsBatchTranspose2D(
    const int ndim,
    const int* axes);
|
|
|
|
// Overload taking an explicit list of reduced axes: from `num_dims` total
// dimensions and the `num_reduce_axes` entries of `reduce_axes`, writes a
// transpose axis order into `transpose_axes`.
// NOTE(review): presumably transpose_axes has num_dims entries with the
// reduced axes grouped together; declaration only -- confirm the ordering.
TORCH_API void ComputeTransposeAxesForReduceOp(
    const int num_dims,
    const int num_reduce_axes, const int* reduce_axes,
    int* transpose_axes);
|
|
|
|
// Overload driven by a dims array: from the `ndim` entries of `dims`,
// writes a transpose axis order into `axes`.
// NOTE(review): how `dims` marks an axis as reduced is not visible from
// this declaration -- confirm against the implementation.
TORCH_API void ComputeTransposeAxesForReduceOp(
    const int ndim,
    const int* dims,
    int* axes);
|
|
|
|
template <typename TIndex>
|
|
TORCH_API void ComputeTransposedStrides(
|
|
int ndim,
|
|
const TIndex* dims,
|
|
const int* axes,
|
|
TIndex* strides);
|
|
|
|
} // namespace utils
|
|
|
|
// Ceiling division: evaluates (a + b - 1) / b, which equals ceil(a / b)
// for integral T with a >= 0 and b > 0. The intermediate sum a + b - 1 is
// not guarded, so the caller must keep it within the range of T.
template <typename T>
constexpr T DivUp(const T a, const T b) {
  return ((a + b) - T(1)) / b;
}
|
|
|
|
// Rounds a up to a multiple of b: the ceiling quotient (a + b - 1) / b,
// converted to T, multiplied by b. A value of a that is already a multiple
// of b is returned unchanged. Neither the intermediate sum nor the final
// product is guarded against overflow; that is the caller's job.
template <typename T>
constexpr T RoundUp(const T a, const T b) {
  // The cast reproduces the conversion to T that a T-returning ceiling
  // division helper would apply before the multiplication.
  return static_cast<T>(((a + b) - T(1)) / b) * b;
}
|
|
|
|
// Returns floor(log2(n)) for a positive integer n, computed by repeatedly
// halving n. `p` is the running count of halvings; callers normally leave
// it at its default of 0. Any n <= 1 returns p unchanged.
template <typename T>
constexpr int IntegerLog2(T n, int p = 0) {
  return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
}

// Returns the smallest power of two strictly greater than v for v >= 1
// (so an exact power of two is doubled: 4 -> 8, 5 -> 8). A v of 0 yields 2.
// The result must fit in T; overflow is not checked.
template <typename T>
constexpr T IntegerNextHighestPowerOf2(T v) {
  // The power-of-two test (non-zero with a single bit set) is written
  // inline: this header declares no helper for it, and relying on one
  // being declared by some other include made the template fail to
  // instantiate for built-in integer types.
  return ((v && !(v & (v - 1))) ? T(2) * v
                                : (T(1) << (IntegerLog2(v) + 1)));
}
|
|
|
|
} // namespace math
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_UTILS_MATH_UTILS_H_
|