Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Windows DLL build with Caffe2 code (#11266)
Summary: This is an experimental build on top of what orionr and mingzhe09088 built. Essentially, the idea is that we will need separate *_API versions for different shared libraries. If this theory is right, I'll try to clean up the design a bit and document it properly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/11266

Reviewed By: orionr

Differential Revision: D9682942

Pulled By: Yangqing

fbshipit-source-id: c79653199e67a1500c9174f39f8b0357324763f3
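For context on the "separate *_API versions" idea: on Windows a symbol must be __declspec(dllexport) in exactly the one DLL that defines it and __declspec(dllimport) everywhere else, so a single shared CAFFE2_API macro cannot serve both caffe2.dll and caffe2_gpu.dll. A minimal sketch of the per-library pattern this PR moves toward (the MYLIB_* names are illustrative placeholders, not identifiers from this diff):

// mylib_macros.h -- one such header per shared library
#ifdef _WIN32
#  ifdef MYLIB_BUILD_MAIN_LIB        // defined only while compiling mylib.dll itself
#    define MYLIB_API __declspec(dllexport)
#  else
#    define MYLIB_API __declspec(dllimport)
#  endif
#else
#  define MYLIB_API __attribute__((__visibility__("default")))
#endif

MYLIB_API void SomeFunction();       // exported by mylib.dll, imported by its consumers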
This commit is contained in:
parent 34c0043aae
commit 68613cf5a2
@@ -404,24 +404,16 @@ else()
endif()

# ---[ Modules
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
add_subdirectory(modules)
endif()

# ---[ Binaries
# Binaries will be built after the Caffe2 main libraries and the modules
# are built. For the binaries, they will be linked to the Caffe2 main
# libraries, as well as all the modules that are built with Caffe2 (the ones
# built in the previous Modules section above).
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
if (BUILD_BINARY)
add_subdirectory(binaries)
endif()
endif()

include(cmake/Summary.cmake)
caffe2_print_configuration_summary()

@@ -1,7 +1,7 @@
#pragma once

#ifdef _WIN32
-# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_BUILD_MAIN_LIB)
+# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
# define AT_CUDA_API __declspec(dllexport)
# else
# define AT_CUDA_API __declspec(dllimport)

@@ -22,7 +22,7 @@
#endif

#ifdef _WIN32
-# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_BUILD_MAIN_LIB)
+# if defined(ATen_cuda_EXPORTS) || defined(caffe2_gpu_EXPORTS) || defined(CAFFE2_CUDA_BUILD_MAIN_LIB)
# define THC_API THC_EXTERNC __declspec(dllexport)
# define THC_CLASS __declspec(dllexport)
# else

@@ -65,11 +65,6 @@ endif()
# ---[ Caffe2 build
# Note: the folders that are being commented out have not been properly
# addressed yet.
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (MSVC AND BUILD_SHARED_LIBS)
add_subdirectory(utils)
else()
add_subdirectory(proto)
add_subdirectory(contrib)
add_subdirectory(core)

@@ -101,7 +96,6 @@ add_subdirectory(sgd)
add_subdirectory(share)
# add_subdirectory(test) # todo: use caffe2_gtest_main instead of gtest_main because we will need to call GlobalInit
add_subdirectory(transforms)
endif()

# Advanced: if we have white list specified, we will do intersections for all
# main lib srcs.

@@ -166,16 +160,8 @@ if (FALSE)
endif()

# ---[ List of libraries to link with
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
add_library(caffe2_protos STATIC $<TARGET_OBJECTS:Caffe2_PROTO>)
add_dependencies(caffe2_protos Caffe2_PROTO)
else()
# Do not include caffe2 or caffe protos, but rather have it only be
# a library to attach local protobuf.
add_library(caffe2_protos STATIC utils/dummy.cpp)
endif()
# If we are going to link protobuf locally inside caffe2 libraries, what we will do is
# to create a helper static library that always contains libprotobuf source files, and
# link the caffe2 related dependent libraries to it.

@@ -341,7 +327,7 @@ if(USE_CUDA)
# NB: This must be target_compile_definitions, not target_compile_options,
# as the latter is not respected by nvcc
if (MSVC)
-target_compile_definitions(caffe2_gpu PRIVATE "-DCAFFE2_BUILD_MAIN_LIB")
+target_compile_definitions(caffe2_gpu PRIVATE "-DCAFFE2_CUDA_BUILD_MAIN_LIB")
endif()

# Set standard properties on the target

@@ -401,9 +387,6 @@ if ($ENV{WERROR})
endif()

# ---[ Test binaries.
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
if (BUILD_TEST)
set(Caffe2_ALL_TEST_SRCS ${Caffe2_CPU_TEST_SRCS})
if (USE_CUDA)

@@ -439,7 +422,6 @@ if (BUILD_TEST)
endforeach()
endif()
endif()
endif()

set(__aten_test_dir "test/aten")
if (NOT USE_ROCM)

@@ -466,9 +448,6 @@ if (NOT USE_ROCM)
endif()
endif()

# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
if (BUILD_PYTHON)
# Python site-packages
# Get canonical directory for python site packages (relative to install

@@ -657,7 +636,6 @@ if (BUILD_PYTHON)
install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH}
FILES_MATCHING PATTERN "*.py")
endif()
endif()

# Finally, set the Caffe2_MAIN_LIBS variable in the parent scope.
set(Caffe2_MAIN_LIBS ${Caffe2_MAIN_LIBS} PARENT_SCOPE)

@@ -25,6 +25,42 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/common.h"

// Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT. On Windows, this corresponds to
// different declarations (dllexport and dllimport). On Linux/Mac, it just
// resolves to the same "default visibility" setting.
#if defined(_MSC_VER)
#if defined(CAFFE2_BUILD_SHARED_LIBS)
#define CAFFE2_CUDA_EXPORT __declspec(dllexport)
#define CAFFE2_CUDA_IMPORT __declspec(dllimport)
#else
#define CAFFE2_CUDA_EXPORT
#define CAFFE2_CUDA_IMPORT
#endif
#else
#if defined(__GNUC__)
#define CAFFE2_CUDA_EXPORT __attribute__((__visibility__("default")))
#else
#define CAFFE2_CUDA_EXPORT
#endif
#define CAFFE2_CUDA_IMPORT CAFFE2_CUDA_EXPORT
#endif

// CAFFE2_CUDA_API is a macro that, depends on whether you are building the
// main caffe2 library or not, resolves to either CAFFE2_CUDA_EXPORT or
// CAFFE2_CUDA_IMPORT.
//
// This is used in e.g. Caffe2's protobuf files: when building the main library,
// it is defined as CAFFE2_CUDA_EXPORT to fix a Windows global-variable-in-dll
// issue, and for anyone dependent on Caffe2 it will be defined as
// CAFFE2_CUDA_IMPORT.

#ifdef CAFFE2_CUDA_BUILD_MAIN_LIB
#define CAFFE2_CUDA_API CAFFE2_CUDA_EXPORT
#else
#define CAFFE2_CUDA_API CAFFE2_CUDA_IMPORT
#endif


// This is a macro defined for cuda fp16 support. In default, cuda fp16 is
// supported by NVCC 7.5, but it is also included in the Tegra X1 platform with
// a (custom?) NVCC 7.0. As a result, we would normally just check the cuda
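With the macro block above in place, a declaration in a caffe2_gpu header picks the right storage class automatically. A sketch of the expansion under the different build configurations (Foo is a hypothetical symbol, not part of this diff):

// In a caffe2_gpu header:
CAFFE2_CUDA_API void Foo();

// When compiling caffe2_gpu itself, CAFFE2_CUDA_BUILD_MAIN_LIB is defined
// (see the target_compile_definitions change above), so on MSVC shared builds
// this becomes:
//   __declspec(dllexport) void Foo();
// When compiling anything that merely links against caffe2_gpu:
//   __declspec(dllimport) void Foo();
// On GCC/Clang both sides collapse to default visibility:
//   __attribute__((__visibility__("default"))) void Foo();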
@@ -33,7 +33,7 @@ enum class CudaMemoryPoolType {
*
* The memory pool is set up during caffe2's global initialization time.
*/
-CAFFE2_API CudaMemoryPoolType GetCudaMemoryPoolType();
+CAFFE2_CUDA_API CudaMemoryPoolType GetCudaMemoryPoolType();

/**
* A struct to host thread-local cuda objects.

@@ -44,7 +44,7 @@ CAFFE2_API CudaMemoryPoolType GetCudaMemoryPoolType();
* and deallocating these objects at the thread scope. This class is solely
* used inside CUDAContext and should not be used externally.
*/
-class CAFFE2_API ThreadLocalCUDAObjects {
+class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
friend class CUDAContext;

private:

@@ -135,9 +135,9 @@ class CAFFE2_API ThreadLocalCUDAObjects {
#endif // CAFFE2_USE_CUDNN
};

-CAFFE2_API BaseStaticContext* GetCUDAStaticContext();
+CAFFE2_CUDA_API BaseStaticContext* GetCUDAStaticContext();

-class CAFFE2_API CUDAContext final : public BaseContext {
+class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
public:
// The default cuda context constructor.
explicit CUDAContext(const int gpu_id = -1);

@@ -332,7 +327,7 @@ inline void CPUContext::CopyBytes<CPUContext, CUDAContext>(
* GPU present during runtime, at global initialization time we will set
* the CPU memory allocator to allocate pinned memory.
*/
-struct CAFFE2_API PinnedCPUAllocator final : CPUAllocator {
+struct CAFFE2_CUDA_API PinnedCPUAllocator final : CPUAllocator {
PinnedCPUAllocator() {}
~PinnedCPUAllocator() override {}
std::pair<void*, MemoryDeleter> New(size_t nbytes) override {

@@ -381,7 +381,7 @@ struct CAFFE2_API PinnedCPUAllocator final : CPUAllocator {
DefaultCPUAllocator baseAllocator_;
};

-class CAFFE2_API CUDAStaticContext final : public BaseStaticContext {
+class CAFFE2_CUDA_API CUDAStaticContext final : public BaseStaticContext {
public:
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;

@@ -370,7 +370,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
}

public:
-static constexpr int kNoNetPositionSet = -1;
+static const int kNoNetPositionSet = -1;

private:
Workspace* operator_ws_;

@@ -447,7 +447,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
// run on different devices. You should then implement the RunOnDevice()
// function.
template <class Context>
-class CAFFE2_API Operator : public OperatorBase {
+class Operator : public OperatorBase {
public:
explicit Operator(const OperatorDef& operator_def, Workspace* ws)
: OperatorBase(operator_def, ws), context_(operator_def.device_option()) {

@@ -835,7 +835,7 @@ CAFFE_DECLARE_REGISTRY(
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...) \
-extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
+CAFFE2_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();\
static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
} \

@@ -854,7 +854,7 @@ CAFFE_DECLARE_REGISTRY(
#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR(name, ...) \
-extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
+CAFFE2_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \
CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
} \

@@ -879,7 +879,7 @@ CAFFE_DECLARE_REGISTRY(
#define REGISTER_HIP_OPERATOR_CREATOR(key, ...) \
CAFFE_REGISTER_CREATOR(HIPOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_HIP_OPERATOR(name, ...) \
-extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
+CAFFE2_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_HIP##name() { \
CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
} \

@@ -6,7 +6,7 @@

namespace caffe2 {

-class CAFFE2_API StoreSetOp final : public Operator<CPUContext> {
+class StoreSetOp final : public Operator<CPUContext> {
public:
StoreSetOp(const OperatorDef& operator_def, Workspace* ws);
bool RunOnDevice() override;

@@ -17,7 +17,7 @@ class CAFFE2_API StoreSetOp final : public Operator<CPUContext> {
INPUT_TAGS(HANDLER, DATA);
};

-class CAFFE2_API StoreGetOp final : public Operator<CPUContext> {
+class StoreGetOp final : public Operator<CPUContext> {
public:
StoreGetOp(const OperatorDef& operator_def, Workspace* ws);
bool RunOnDevice() override;

@@ -29,7 +29,7 @@ class CAFFE2_API StoreGetOp final : public Operator<CPUContext> {
OUTPUT_TAGS(DATA);
};

-class CAFFE2_API StoreAddOp final : public Operator<CPUContext> {
+class StoreAddOp final : public Operator<CPUContext> {
public:
StoreAddOp(const OperatorDef& operator_def, Workspace* ws);
bool RunOnDevice() override;

@@ -42,7 +42,7 @@ class CAFFE2_API StoreAddOp final : public Operator<CPUContext> {
OUTPUT_TAGS(VALUE);
};

-class CAFFE2_API StoreWaitOp final : public Operator<CPUContext> {
+class StoreWaitOp final : public Operator<CPUContext> {
public:
StoreWaitOp(const OperatorDef& operator_def, Workspace* ws);
bool RunOnDevice() override;

@@ -60,19 +60,19 @@ REGISTER_CUDA_OPERATOR(MPIAllgather, MPIAllgatherOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(MPISendTensor, MPISendTensorOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(MPIReceiveTensor, MPIReceiveTensorOp<CUDAContext>);
#else
-REGISTER_CUDA_OPERATOR(MPIBroadcast, GPUFallbackOp<MPIBroadcastOp<CPUContext>>);
+REGISTER_CUDA_OPERATOR(MPIBroadcast, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIReduce,
-GPUFallbackOp<MPIReduceOp<float, CPUContext>>);
+GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIAllgather,
-GPUFallbackOp<MPIAllgatherOp<float, CPUContext>>);
+GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPISendTensor,
-GPUFallbackOp<MPISendTensorOp<CPUContext>>);
+GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIReceiveTensor,
-GPUFallbackOp<MPIReceiveTensorOp<CPUContext>, SkipIndices<1, 2>>);
+GPUFallbackOpEx<SkipIndices<1, 2>>);
#endif

#if CAFFE2_HAS_CUDA_MPI_ALLREDUCE

@@ -80,7 +80,7 @@ REGISTER_CUDA_OPERATOR(MPIAllreduce, MPIAllreduceOp<float, CUDAContext>);
#else
REGISTER_CUDA_OPERATOR(
MPIAllreduce,
-GPUFallbackOp<MPIAllreduceOp<float, CPUContext>>);
+GPUFallbackOp);
#endif

} // namespace caffe2

@@ -454,9 +454,7 @@ REGISTER_CUDA_OPERATOR(MakeTwoClassGradient,
MakeTwoClassGradientOp<float, CUDAContext>);

//TODO(surya) Add full GPU/CUDA support for the CrossEntropyOp
-REGISTER_CUDA_OPERATOR(CrossEntropy,
-GPUFallbackOp<CrossEntropyOp<float, CPUContext>>);
-REGISTER_CUDA_OPERATOR(CrossEntropyGradient,
-GPUFallbackOp<CrossEntropyGradientOp<float, CPUContext>>);
+REGISTER_CUDA_OPERATOR(CrossEntropy, GPUFallbackOp);
+REGISTER_CUDA_OPERATOR(CrossEntropyGradient, GPUFallbackOp);

} // namespace caffe2

@@ -64,8 +64,6 @@ REGISTER_CUDA_OPERATOR(GaussianFill, GaussianFillOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(XavierFill, XavierFillOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(MSRAFill, MSRAFillOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(RangeFill, RangeFillOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(
-LengthsRangeFill,
-GPUFallbackOp<LengthsRangeFillOp<CPUContext>>);
+REGISTER_CUDA_OPERATOR(LengthsRangeFill, GPUFallbackOp);

} // namespace caffe2

@@ -27,20 +27,22 @@ namespace caffe2 {
* to register the CPU side, you can create its corresponding GPU operator
* (with performance hits of course) via
* REGISTER_CUDA_OPERATOR(MyMagic,
-* GPUFallbackOp<MyMagicOp>);
+* GPUFallbackOp);
* Note that you will need to make sure that the operators actually share the
* same name.
*
* Advanced usage: if you want to have some specific outputs never copied, you
* can use the SkipOutputCopy template argument to do that. For example, if
* MyMagic produces two outputs and the first output is always going to live on
* the CPU, you can do
* REGISTER_CUDA_OPERATOR(MyMagic,
-* GPUFallbackOp<MyMagicOp, SkipIndices<0>>);
+* GPUFallbackOpEx<SkipIndices<0>>);
*/
-template <class CPUOp, typename SkipOutputCopy = SkipIndices<>>
-class GPUFallbackOp final : public Operator<CUDAContext> {
+template <typename SkipOutputCopy>
+class GPUFallbackOpEx final : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
-GPUFallbackOp(const OperatorDef& def, Workspace* ws)
+GPUFallbackOpEx(const OperatorDef& def, Workspace* ws)
: Operator<CUDAContext>(def, ws) {
CAFFE_ENFORCE_EQ(def.device_option().device_type(), PROTO_CUDA);
OperatorDef base_def_(def);

@@ -52,7 +54,7 @@ class GPUFallbackOp final : public Operator<CUDAContext> {
local_input_blobs_.push_back(local_ws_.CreateBlob(name));
CHECK_NOTNULL(local_input_blobs_.back());
}
-base_op_.reset(new CPUOp(base_def_, &local_ws_));
+base_op_ = CreateOperator(base_def_, &local_ws_);
for (const string& name : def.output()) {
local_output_blobs_.push_back(local_ws_.GetBlob(name));
CHECK_NOTNULL(local_output_blobs_.back());

@@ -105,9 +107,11 @@ class GPUFallbackOp final : public Operator<CUDAContext> {
Workspace local_ws_;
vector<Blob*> local_input_blobs_;
vector<Blob*> local_output_blobs_;
-std::unique_ptr<CPUOp> base_op_;
+unique_ptr<OperatorBase> base_op_;
};

+using GPUFallbackOp = GPUFallbackOpEx<SkipIndices<>>;

} // namespace caffe2

#endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
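The fallback-op change above removes the CPU op type from the template signature: the CPU implementation is now looked up by operator name through CreateOperator at construction time, and GPUFallbackOp becomes an alias for GPUFallbackOpEx<SkipIndices<>>. The registration sites in the rest of this diff all follow the same mechanical rewrite; a sketch using the hypothetical MyMagicOp from the doc comment above:

// Before: the CPU op was baked in as a template argument.
REGISTER_CUDA_OPERATOR(MyMagic, GPUFallbackOp<MyMagicOp>);
REGISTER_CUDA_OPERATOR(MyMagic, GPUFallbackOp<MyMagicOp, SkipIndices<0>>);

// After: the registered name alone identifies the CPU op.
REGISTER_CUDA_OPERATOR(MyMagic, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(MyMagic, GPUFallbackOpEx<SkipIndices<0>>);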
@@ -29,8 +29,7 @@ OPERATOR_SCHEMA(IncrementByOne)
.NumInputs(1).NumOutputs(1).AllowInplace({{0, 0}});

REGISTER_CPU_OPERATOR(IncrementByOne, IncrementByOneOp);
-REGISTER_CUDA_OPERATOR(IncrementByOne,
-GPUFallbackOp<IncrementByOneOp>);
+REGISTER_CUDA_OPERATOR(IncrementByOne, GPUFallbackOp);

TEST(OperatorFallbackTest, IncrementByOneOp) {
OperatorDef op_def = CreateOperatorDef(

@@ -5,5 +5,5 @@
namespace caffe2 {
REGISTER_CUDA_OPERATOR(
SparseNormalize,
-GPUFallbackOp<SparseNormalizeOp<float, CPUContext>>);
+GPUFallbackOp);
}

@@ -22,10 +22,10 @@
namespace caffe2 {
namespace python {

-REGISTER_CUDA_OPERATOR(Python, GPUFallbackOp<PythonOp<CPUContext, false>>);
+REGISTER_CUDA_OPERATOR(Python, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
PythonGradient,
-GPUFallbackOp<PythonGradientOp<CPUContext, false>>);
+GPUFallbackOp);

REGISTER_CUDA_OPERATOR(PythonDLPack, PythonOp<CUDAContext, true>);
REGISTER_CUDA_OPERATOR(

@@ -13,10 +13,10 @@
namespace caffe2 {
namespace python {

-REGISTER_HIP_OPERATOR(Python, GPUFallbackOp<PythonOp<CPUContext, false>>);
+REGISTER_HIP_OPERATOR(Python, GPUFallbackOp);
REGISTER_HIP_OPERATOR(
PythonGradient,
-GPUFallbackOp<PythonGradientOp<CPUContext, false>>);
+GPUFallbackOp);

REGISTER_HIP_OPERATOR(PythonDLPack, PythonOp<HIPContext, true>);
REGISTER_HIP_OPERATOR(PythonDLPackGradient, PythonGradientOp<HIPContext, true>);

@@ -9,7 +9,7 @@
namespace caffe2 {

template <typename T, class Context>
-class CAFFE2_API LarsOp final : public Operator<Context> {
+class LarsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
LarsOp(const OperatorDef& operator_def, Workspace* ws)

@@ -1,11 +1,3 @@
# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (MSVC AND BUILD_SHARED_LIBS)
list(APPEND Caffe2_CPU_SRCS utils/proto_wrap.cc)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
return()
endif()

list(APPEND Caffe2_CPU_SRCS
utils/proto_wrap.cc
utils/proto_utils.cc

@@ -76,7 +76,7 @@ namespace math {
// (transpose) if the argument TransA or TransB is set to CblasNoTrans or
// CblasTrans, respectively, for each of A and B.
template <>
-void Gemm<float, CPUContext>(
+CAFFE2_EXPORT void Gemm<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int M,

@@ -134,7 +134,7 @@ void Gemm<float, CPUContext>(
}

template <>
-void GemmEx<float, CPUContext>(
+CAFFE2_EXPORT void GemmEx<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int M,

@@ -206,7 +206,7 @@ void GemmEx<float, CPUContext>(
}

template <>
-void Gemv<float, CPUContext>(
+CAFFE2_EXPORT void Gemv<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const int M,
const int N,

@@ -245,7 +245,7 @@ void Gemv<float, CPUContext>(

#define CAFFE2_SPECIALIZED_DOT(T) \
template <> \
-void Dot<T, CPUContext>( \
+CAFFE2_EXPORT void Dot<T, CPUContext>( \
const int N, const T* a, const T* b, T* y, CPUContext* context) { \
*y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N)); \
}

@@ -254,12 +254,12 @@ CAFFE2_SPECIALIZED_DOT(float)

#define CAFFE2_SPECIALIZED_AXPY(T) \
template <> \
-void Axpy<T, CPUContext>( \
+CAFFE2_EXPORT void Axpy<T, CPUContext>( \
const int N, const T alpha, const T* x, T* Y, CPUContext* context) { \
EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * alpha; \
} \
template <> \
-void Axpy<T, CPUContext>( \
+CAFFE2_EXPORT void Axpy<T, CPUContext>( \
const int N, const T* alpha, const T* x, T* Y, CPUContext* context) { \
EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha); \
}

@@ -268,7 +268,7 @@ CAFFE2_SPECIALIZED_AXPY(float)

#define CAFFE2_SPECIALIZED_AXPBY(T) \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T alpha, \
const T* x, \

@@ -279,7 +279,7 @@ CAFFE2_SPECIALIZED_AXPY(float)
y_arr = y_arr * beta + ConstEigenVectorArrayMap<T>(x, N) * alpha; \
} \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T* alpha, \
const T* x, \

@@ -295,7 +295,7 @@ CAFFE2_SPECIALIZED_AXPBY(float)
#else // CAFFE2_USE_EIGEN_FOR_BLAS

template <>
-void Gemm<float, CPUContext>(
+CAFFE2_EXPORT void Gemm<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int M,

@@ -328,7 +328,7 @@ void Gemm<float, CPUContext>(
}

template <>
-void GemmEx<float, CPUContext>(
+CAFFE2_EXPORT void GemmEx<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int M,

@@ -361,7 +361,7 @@ void GemmEx<float, CPUContext>(
}

template <>
-void Gemv<float, CPUContext>(
+CAFFE2_EXPORT void Gemv<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const int M,
const int N,

@@ -377,7 +377,7 @@ void Gemv<float, CPUContext>(

#define CAFFE2_SPECIALIZED_SCALE(TAlpha, TData, prefix) \
template <> \
-void Scale<TAlpha, TData, CPUContext>( \
+CAFFE2_EXPORT void Scale<TAlpha, TData, CPUContext>( \
const int n, \
const TAlpha alpha, \
const TData* x, \

@@ -391,7 +391,7 @@ void Gemv<float, CPUContext>(
} \
} \
template <> \
-void Scale<TAlpha, TData, CPUContext>( \
+CAFFE2_EXPORT void Scale<TAlpha, TData, CPUContext>( \
const int n, \
const TAlpha* alpha, \
const TData* x, \

@@ -411,7 +411,7 @@ CAFFE2_SPECIALIZED_SCALE(float, double, d)

#define CAFFE2_SPECIALIZED_DOT(T, prefix) \
template <> \
-void Dot<T, CPUContext>( \
+CAFFE2_EXPORT void Dot<T, CPUContext>( \
const int N, const T* a, const T* b, T* y, CPUContext*) { \
*y = cblas_##prefix##dot(N, a, 1, b, 1); \
}

@@ -420,12 +420,12 @@ CAFFE2_SPECIALIZED_DOT(float, s)

#define CAFFE2_SPECIALIZED_AXPY(T, prefix) \
template <> \
-void Axpy<T, CPUContext>( \
+CAFFE2_EXPORT void Axpy<T, CPUContext>( \
const int N, const T alpha, const T* x, T* y, CPUContext*) { \
cblas_##prefix##axpy(N, alpha, x, 1, y, 1); \
} \
template <> \
-void Axpy<T, CPUContext>( \
+CAFFE2_EXPORT void Axpy<T, CPUContext>( \
const int N, const T* alpha, const T* x, T* y, CPUContext*) { \
cblas_##prefix##axpy(N, *alpha, x, 1, y, 1); \
}

@@ -437,7 +437,7 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
#ifdef CAFFE2_USE_MKL
#define CAFFE2_SPECIALIZED_AXPBY(T, prefix) \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T alpha, \
const T* x, \

@@ -447,7 +447,7 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
cblas_##prefix##axpby(N, alpha, x, 1, beta, y, 1); \
} \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T* alpha, \
const T* x, \

@@ -459,7 +459,7 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
#else // CAFFE2_USE_MKL
#define CAFFE2_SPECIALIZED_AXPBY(T, prefix) \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T alpha, \
const T* x, \

@@ -470,7 +470,7 @@ CAFFE2_SPECIALIZED_AXPY(float, s)
cblas_##prefix##axpy(N, alpha, x, 1, y, 1); \
} \
template <> \
-void Axpby<T, T, CPUContext>( \
+CAFFE2_EXPORT void Axpby<T, T, CPUContext>( \
const int N, \
const T* alpha, \
const T* x, \

@@ -488,7 +488,7 @@ CAFFE2_SPECIALIZED_AXPBY(float, s)

#define CAFFE2_SPECIALIZED_SCALE(TAlpha, TData) \
template <> \
-void Scale<TAlpha, TData, CPUContext>( \
+CAFFE2_EXPORT void Scale<TAlpha, TData, CPUContext>( \
const int n, \
const TAlpha alpha, \
const TData* x, \

@@ -498,7 +498,7 @@ CAFFE2_SPECIALIZED_AXPBY(float, s)
ConstEigenVectorMap<TData>(x, n) * static_cast<TData>(alpha); \
} \
template <> \
-void Scale<TAlpha, TData, CPUContext>( \
+CAFFE2_EXPORT void Scale<TAlpha, TData, CPUContext>( \
const int n, \
const TAlpha* alpha, \
const TData* x, \

@@ -517,7 +517,7 @@ CAFFE2_SPECIALIZED_SCALE(std::int64_t, std::int64_t)
#undef CAFFE2_SPECIALIZED_SCALE

template <>
-void GemmBatched<float, CPUContext>(
+CAFFE2_EXPORT void GemmBatched<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int batch_size,

@@ -563,7 +563,7 @@ void GemmBatched<float, CPUContext>(
}

template <>
-void GemmStridedBatched<float, CPUContext>(
+CAFFE2_EXPORT void GemmStridedBatched<float, CPUContext>(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
const int batch_size,

@@ -634,7 +634,7 @@ void GemmStridedBatched<float, CPUContext>(

#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc, ...) \
template <> \
-void Funcname<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Funcname<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
OriginalFunc(N, x, y, ##__VA_ARGS__); \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(

@@ -683,7 +683,7 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(double, Inv, vdInv)

#define DELEGATE_SINCOS_FUNCTION(T, OriginalFunc) \
template <> \
-void SinCos<T, CPUContext>( \
+CAFFE2_EXPORT void SinCos<T, CPUContext>( \
const int N, const T* a, T* ys, T* yc, CPUContext*) { \
OriginalFunc(N, a, ys, yc); \
}

@@ -693,7 +693,7 @@ DELEGATE_SINCOS_FUNCTION(double, vdSinCos)

#define DELEGATE_POWX_FUNCTION(T, OriginalFunc) \
template <> \
-void Powx<T, CPUContext>(const int N, const T* a, T b, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Powx<T, CPUContext>(const int N, const T* a, T b, T* y, CPUContext*) { \
OriginalFunc(N, a, b, y); \
}
DELEGATE_POWX_FUNCTION(float, vsPowx)

@@ -702,7 +702,7 @@ DELEGATE_POWX_FUNCTION(double, vdPowx)

#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Func, FuncImpl) \
template <> \
-void Func<T, CPUContext>( \
+CAFFE2_EXPORT void Func<T, CPUContext>( \
const int N, const T* A, const T* B, T* C, CPUContext*) { \
FuncImpl(N, A, B, C); \
}

@@ -720,7 +720,7 @@ DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv)

#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, expr) \
template <> \
-void Funcname<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Funcname<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = ConstEigenVectorArrayMap<T>(x, N).expr(); \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp)

@@ -750,7 +750,7 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(double, Rsqrt, rsqrt)

#define DELEGATE_SINCOS_FUNCTION(T) \
template <> \
-void SinCos<T, CPUContext>( \
+CAFFE2_EXPORT void SinCos<T, CPUContext>( \
const int N, const T* x, T* ys, T* yc, CPUContext*) { \
EigenVectorMap<T>(ys, N) = ConstEigenVectorArrayMap<T>(x, N).sin(); \
EigenVectorMap<T>(yc, N) = ConstEigenVectorArrayMap<T>(x, N).cos(); \

@@ -761,7 +761,7 @@ DELEGATE_SINCOS_FUNCTION(double)

#define DELEGATE_TANH_FUNCTION(T) \
template <> \
-void Tanh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Tanh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
EigenVectorMap<T>(Y, N) = T(1) - \
((ConstEigenVectorArrayMap<T>(X, N) * T(2)).exp() + T(1)).inverse() * \
T(2); \

@@ -772,7 +772,7 @@ DELEGATE_TANH_FUNCTION(double)

#define DELEGATE_CBRT_FUNCTION(T) \
template <> \
-void Cbrt<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Cbrt<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
std::transform(X, X + N, Y, [](const T x) { return cbrt(x); }); \
}
DELEGATE_CBRT_FUNCTION(float)

@@ -781,7 +781,7 @@ DELEGATE_CBRT_FUNCTION(double)

#define DELEGATE_POWX_FUNCTION(T) \
template <> \
-void Powx<T, CPUContext>( \
+CAFFE2_EXPORT void Powx<T, CPUContext>( \
const int N, const T* a, const T b, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = ConstEigenVectorArrayMap<T>(a, N).pow(b); \
}

@@ -790,7 +790,7 @@ DELEGATE_POWX_FUNCTION(float)

#define DELEGATE_SINH_FUNCTION(T) \
template <> \
-void Sinh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Sinh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
ConstEigenVectorArrayMap<T> X_arr(X, N); \
EigenVectorMap<T>(Y, N) = (X_arr.exp() - (-X_arr).exp()) / 2; \
}

@@ -800,7 +800,7 @@ DELEGATE_SINH_FUNCTION(double)

#define DELEGATE_COSH_FUNCTION(T) \
template <> \
-void Cosh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Cosh<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
ConstEigenVectorArrayMap<T> X_arr(X, N); \
EigenVectorMap<T>(Y, N) = (X_arr.exp() + (-X_arr).exp()) / 2; \
}

@@ -810,7 +810,7 @@ DELEGATE_COSH_FUNCTION(double)

#define DELEGATE_INV_FUNCTION(T) \
template <> \
-void Inv<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Inv<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = ConstEigenVectorArrayMap<T>(x, N).inverse(); \
}
DELEGATE_INV_FUNCTION(float)

@@ -821,7 +821,7 @@ DELEGATE_INV_FUNCTION(double)

#define DELEGATE_NEG_FUNCTION(T) \
template <> \
-void Neg<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Neg<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = -ConstEigenVectorMap<T>(x, N); \
}
DELEGATE_NEG_FUNCTION(float)

@@ -832,7 +832,7 @@ DELEGATE_NEG_FUNCTION(std::int64_t)

#define DELEGATE_SIGN_FUNCTION(T) \
template <> \
-void Sign<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Sign<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = ConstEigenVectorArrayMap<T>(x, N).sign(); \
}
DELEGATE_SIGN_FUNCTION(float)

@@ -843,7 +843,7 @@ DELEGATE_SIGN_FUNCTION(std::int64_t)

#define DELEGATE_ABS_FUNCTION(T) \
template <> \
-void Abs<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
+CAFFE2_EXPORT void Abs<T, CPUContext>(const int N, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = ConstEigenVectorArrayMap<T>(x, N).abs(); \
}
#ifndef CAFFE2_USE_MKL

@@ -856,7 +856,7 @@ DELEGATE_ABS_FUNCTION(std::int64_t)

#define DELEGATE_CUBE_FUNCTION(T) \
template <> \
-void Cube<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Cube<T, CPUContext>(const int N, const T* X, T* Y, CPUContext*) { \
EigenVectorMap<T>(Y, N) = ConstEigenVectorArrayMap<T>(X, N).cube(); \
}
DELEGATE_CUBE_FUNCTION(float)

@@ -867,7 +867,7 @@ DELEGATE_CUBE_FUNCTION(std::int64_t)

#define EIGEN_SIMPLE_BINARY_FUNCTION(T, Func, expr) \
template <> \
-void Func<T, CPUContext>( \
+CAFFE2_EXPORT void Func<T, CPUContext>( \
const int N, const T* A, const T* B, T* C, CPUContext*) { \
EigenVectorMap<T>(C, N) = ConstEigenVectorArrayMap<T>(A, N) \
expr ConstEigenVectorArrayMap<T>(B, N); \

@@ -905,7 +905,7 @@ DEFINE_SIMPLE_BINARY_FUNCTION(Div, /)

#define CAFFE2_SPECIALIZED_SET(T) \
template <> \
-void Set<T, CPUContext>(const size_t N, const T alpha, T* Y, CPUContext*) { \
+CAFFE2_EXPORT void Set<T, CPUContext>(const size_t N, const T alpha, T* Y, CPUContext*) { \
if (N == 0) { \
return; \
} \

@@ -932,7 +932,7 @@ CAFFE2_SPECIALIZED_SET(uint16_t);

#define CAFFE2_SPECIALIZED_REDUCEMIN(T) \
template <> \
-void ReduceMin<T, CPUContext>( \
+CAFFE2_EXPORT void ReduceMin<T, CPUContext>( \
const int N, \
const T* x, \
T* y, \

@@ -945,7 +945,7 @@ CAFFE2_SPECIALIZED_REDUCEMIN(float)

#define CAFFE2_SPECIALIZED_REDUCEMAX(T) \
template <> \
-void ReduceMax<T, CPUContext>( \
+CAFFE2_EXPORT void ReduceMax<T, CPUContext>( \
const int N, \
const T* x, \
T* y, \

@@ -991,7 +991,7 @@ struct SquaredL2NormFunctor {

#define DELEGATE_ROWWISE_REDUCE_FUNCTION(Func, EigenOp) \
template <typename T> \
-void Rowwise##Func( \
+CAFFE2_EXPORT void Rowwise##Func( \
const int rows, const int cols, const T alpha, const T* X, T* Y) { \
EigenVectorMap<T>(Y, rows) = \
ConstEigenMatrixMap<T>(X, cols, rows).colwise().EigenOp() * alpha; \

@@ -1006,7 +1006,7 @@ DELEGATE_ROWWISE_REDUCE_FUNCTION(ReduceL2, norm)

#define DELEGATE_COLWISE_REDUCE_FUNCTION(Func, EigenOp) \
template <typename T> \
-void Colwise##Func( \
+CAFFE2_EXPORT void Colwise##Func( \
const int rows, const int cols, const T alpha, const T* X, T* Y) { \
EigenVectorMap<T>(Y, cols) = \
ConstEigenMatrixMap<T>(X, cols, rows).rowwise().EigenOp() * alpha; \

@@ -1020,7 +1020,7 @@ DELEGATE_COLWISE_REDUCE_FUNCTION(ReduceL2, norm)
#undef DELEGATE_COLWISE_REDUCE_FUNCTION

template <typename T>
-void BothEndsReduceMin(
+CAFFE2_EXPORT void BothEndsReduceMin(
const int pre,
const int mid,
const int nxt,

@@ -1044,7 +1044,7 @@ void BothEndsReduceMin(
}

template <typename T>
-void BothEndsReduceMax(
+CAFFE2_EXPORT void BothEndsReduceMax(
const int pre,
const int mid,
const int nxt,

@@ -1066,7 +1066,7 @@ void BothEndsReduceMax(
}

template <typename T>
-void BothEndsReduceSum(
+CAFFE2_EXPORT void BothEndsReduceSum(
const int pre,
const int mid,
const int nxt,

@@ -1087,7 +1087,7 @@ void BothEndsReduceSum(
}

template <typename T>
-void BothEndsReduceMean(
+CAFFE2_EXPORT void BothEndsReduceMean(
const int pre,
const int mid,
const int nxt,

@@ -1108,7 +1108,7 @@ void BothEndsReduceMean(
}

template <typename T>
-void BothEndsReduceL1(
+CAFFE2_EXPORT void BothEndsReduceL1(
const int pre,
const int mid,
const int nxt,

@@ -1135,7 +1135,7 @@ void BothEndsReduceL1(
}

template <typename T>
-void BothEndsReduceL2(
+CAFFE2_EXPORT void BothEndsReduceL2(
const int pre,
const int mid,
const int nxt,

@@ -1155,7 +1155,7 @@ void BothEndsReduceL2(
}

template <typename T, class Reducer>
-void ReduceTensor(
+CAFFE2_EXPORT void ReduceTensor(
const int ndim,
const int* X_dims,
const int* Y_dims,

@@ -1183,7 +1183,7 @@ void ReduceTensor(

#define DELEGATE_REDUCE_FUNCTION(T, Func, reducer, init, is_norm) \
template <> \
-void Func<T, CPUContext>( \
+CAFFE2_EXPORT void Func<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \

@@ -1325,7 +1325,7 @@ DELEGATE_REDUCE_FUNCTION(

#define CAFFE2_SPECIALIZED_REDUCE_MEAN(T) \
template <> \
-void ReduceMean<T, CPUContext>( \
+CAFFE2_EXPORT void ReduceMean<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \

@@ -1392,7 +1392,7 @@ CAFFE2_SPECIALIZED_REDUCE_MEAN(double)

#define CAFFE2_SPECIALIZED_REDUCE_L2(T) \
template <> \
-void ReduceL2<T, CPUContext>( \
+CAFFE2_EXPORT void ReduceL2<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \

@@ -1462,7 +1462,7 @@ CAFFE2_SPECIALIZED_REDUCE_L2(double)
namespace {

template <typename T>
-void BroadcastImpl(
+CAFFE2_EXPORT void BroadcastImpl(
const int X_ndim,
const int* X_dims,
const int Y_ndim,

@@ -1495,7 +1495,7 @@ void BroadcastImpl(

#define CAFFE2_SPECIALIZED_BROADCAST(T) \
template <> \
-void Broadcast<T, CPUContext>( \
+CAFFE2_EXPORT void Broadcast<T, CPUContext>( \
const int X_ndim, \
const int* X_dims, \
const int Y_ndim, \

@@ -1515,7 +1515,7 @@ CAFFE2_SPECIALIZED_BROADCAST(double)
namespace {

template <typename T>
-void RowwiseMoments(
+CAFFE2_EXPORT void RowwiseMoments(
const int rows,
const int cols,
const T* X,

@@ -1530,7 +1530,7 @@ void RowwiseMoments(
}

template <typename T>
-void ColwiseMoments(
+CAFFE2_EXPORT void ColwiseMoments(
const int rows,
const int cols,
const T* X,

@@ -1545,7 +1545,7 @@ void ColwiseMoments(
}

template <typename T>
-void BothEndsMoments(
+CAFFE2_EXPORT void BothEndsMoments(
const int pre,
const int mid,
const int nxt,

@@ -1572,7 +1572,7 @@ void BothEndsMoments(
}

template <typename T>
-void MomentsImpl(
+CAFFE2_EXPORT void MomentsImpl(
const int num_dims,
const int* dims,
const int num_axes,

@@ -1640,7 +1640,7 @@ void MomentsImpl(

#define CAFFE2_SPECIALIZED_MOMENTS(T) \
template <> \
-void Moments<T, CPUContext>( \
+CAFFE2_EXPORT void Moments<T, CPUContext>( \
const int num_dims, \
const int* dims, \
const int num_axes, \

@@ -1657,7 +1657,7 @@ CAFFE2_SPECIALIZED_MOMENTS(float)

#define CAFFE2_SPECIALIZED_ROWWISEMAX(T) \
template <> \
-void RowwiseMax<T, CPUContext>( \
+CAFFE2_EXPORT void RowwiseMax<T, CPUContext>( \
const int N, const int D, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, N) = \
ConstEigenMatrixMap<T>(x, D, N).colwise().maxCoeff(); \

@@ -1667,7 +1667,7 @@ CAFFE2_SPECIALIZED_ROWWISEMAX(float)

#define CAFFE2_SPECIALIZED_COLWISEMAX(T) \
template <> \
-void ColwiseMax<T, CPUContext>( \
+CAFFE2_EXPORT void ColwiseMax<T, CPUContext>( \
const int N, const int D, const T* x, T* y, CPUContext*) { \
EigenVectorMap<T>(y, D) = \
ConstEigenMatrixMap<T>(x, D, N).rowwise().maxCoeff(); \

@@ -1677,7 +1677,7 @@ CAFFE2_SPECIALIZED_COLWISEMAX(float)

#define CAFFE2_SPECIALIZED_ELEMWISEMAX(T) \
template <> \
-void ElemwiseMax<T, CPUContext>( \
+CAFFE2_EXPORT void ElemwiseMax<T, CPUContext>( \
const int N, const T* x, const T* y, T* z, CPUContext* /*context*/) { \
std::transform(x, x + N, y, z, [](const T& x_i, const T& y_i) { \
return std::max(x_i, y_i); \

@@ -1688,7 +1688,7 @@ CAFFE2_SPECIALIZED_ELEMWISEMAX(float)

#define CAFFE2_SPECIALIZED_MAXIMUM(T) \
template <> \
-void Maximum<T, CPUContext>( \
+CAFFE2_EXPORT void Maximum<T, CPUContext>( \
const int N, const float alpha, const T* x, T* y, CPUContext* context) { \
std::transform( \
x, x + N, y, [&alpha](const T& x_i) { return std::max(x_i, alpha); }); \

@@ -1701,7 +1701,7 @@ CAFFE2_SPECIALIZED_MAXIMUM(float)

#define DELEGATE_EIGEN_2D_BROADCAST_1ST_BINARY_FUNCTION(T, Func, expr) \
template <> \
-void Rowwise##Func<T, CPUContext, true>( \
+CAFFE2_EXPORT void Rowwise##Func<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1718,7 +1718,7 @@ CAFFE2_SPECIALIZED_MAXIMUM(float)
} \
} \
template <> \
-void Colwise##Func<T, CPUContext, true>( \
+CAFFE2_EXPORT void Colwise##Func<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1738,7 +1738,7 @@ CAFFE2_SPECIALIZED_MAXIMUM(float)

#define DELEGATE_EIGEN_2D_BROADCAST_2ND_BINARY_FUNCTION(T, Func, expr) \
template <> \
-void Rowwise##Func<T, CPUContext, false>( \
+CAFFE2_EXPORT void Rowwise##Func<T, CPUContext, false>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1755,7 +1755,7 @@ CAFFE2_SPECIALIZED_MAXIMUM(float)
} \
} \
template <> \
-void Colwise##Func<T, CPUContext, false>( \
+CAFFE2_EXPORT void Colwise##Func<T, CPUContext, false>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1791,7 +1791,7 @@ DEFINE_EIGEN_2D_BROADCAST_BINARY_FUNCTION(Mul, *)

#define DEFINE_EIGEN_2D_BROADCAST_SUB_FUNCTION(T) \
template <> \
-void RowwiseSub<T, CPUContext, true>( \
+CAFFE2_EXPORT void RowwiseSub<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1803,7 +1803,7 @@ DEFINE_EIGEN_2D_BROADCAST_BINARY_FUNCTION(Mul, *)
ConstEigenVectorArrayMap<T>(A, cols); \
} \
template <> \
-void ColwiseSub<T, CPUContext, true>( \
+CAFFE2_EXPORT void ColwiseSub<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1825,7 +1825,7 @@ DEFINE_EIGEN_2D_BROADCAST_SUB_FUNCTION(std::int64_t)

#define DEFINE_EIGEN_2D_BROADCAST_DIV_FUNCTION(T) \
template <> \
-void RowwiseDiv<T, CPUContext, true>( \
+CAFFE2_EXPORT void RowwiseDiv<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1837,7 +1837,7 @@ DEFINE_EIGEN_2D_BROADCAST_SUB_FUNCTION(std::int64_t)
ConstEigenVectorArrayMap<T>(A, cols); \
} \
template <> \
-void ColwiseDiv<T, CPUContext, true>( \
+CAFFE2_EXPORT void ColwiseDiv<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -1861,7 +1861,7 @@ DELEGATE_EIGEN_2D_BROADCAST_2ND_BINARY_FUNCTION(std::int64_t, Div, /)
#undef DELEGATE_EIGEN_2D_BROADCAST_2ND_BINARY_FUNCTION

template <>
-void Not<bool, CPUContext>(
+CAFFE2_EXPORT void Not<bool, CPUContext>(
const int N,
const bool* x,
bool* y,

@@ -1876,7 +1876,7 @@ void Not<bool, CPUContext>(

#define CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(T) \
template <> \
-void AddStripedBatch( \
+CAFFE2_EXPORT void AddStripedBatch( \
const int N, \
const T* first, \
T* y, \

@@ -1894,7 +1894,7 @@ CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(float);
namespace {

template <typename TIn, typename TOut, class BinaryOperator, bool kBroadcast1st>
-void RowwiseBinaryOp(
+CAFFE2_EXPORT void RowwiseBinaryOp(
const int rows,
const int cols,
const BinaryOperator& op,

@@ -1912,7 +1912,7 @@ void RowwiseBinaryOp(
}

template <typename TIn, typename TOut, class BinaryOperator, bool kBroadcast1st>
-void ColwiseBinaryOp(
+CAFFE2_EXPORT void ColwiseBinaryOp(
const int rows,
const int cols,
const BinaryOperator& op,

@@ -1930,7 +1930,7 @@ void ColwiseBinaryOp(
}

template <typename TIn, typename TOut, class BinaryOperator>
-void BroadcastBinaryOpImpl(
+CAFFE2_EXPORT void BroadcastBinaryOpImpl(
const int ndim,
const int* A_dims,
const int* B_dims,

@@ -1954,7 +1954,7 @@ void BroadcastBinaryOpImpl(

#define DELEGATE_1D_BINARY_FUNCTION(TIn, TOut, Func, Op) \
template <> \
-void Func<TIn, CPUContext>( \
+CAFFE2_EXPORT void Func<TIn, CPUContext>( \
const int N, const TIn* A, const TIn* B, TOut* C, CPUContext*) { \
std::transform(A, A + N, B, C, Op<TIn>()); \
}

@@ -1994,7 +1994,7 @@ DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)

#define DELEGATE_2D_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op) \
template <> \
-void Rowwise##Func<TIn, CPUContext, true>( \
+CAFFE2_EXPORT void Rowwise##Func<TIn, CPUContext, true>( \
const int rows, \
const int cols, \
const TIn* A, \

@@ -2004,7 +2004,7 @@ DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
RowwiseBinaryOp<TIn, TOut, Op<TIn>, true>(rows, cols, Op<TIn>(), A, B, C); \
} \
template <> \
-void Rowwise##Func<TIn, CPUContext, false>( \
+CAFFE2_EXPORT void Rowwise##Func<TIn, CPUContext, false>( \
const int rows, \
const int cols, \
const TIn* A, \

@@ -2015,7 +2015,7 @@ DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
rows, cols, Op<TIn>(), A, B, C); \
} \
template <> \
-void Colwise##Func<TIn, CPUContext, true>( \
+CAFFE2_EXPORT void Colwise##Func<TIn, CPUContext, true>( \
const int rows, \
const int cols, \
const TIn* A, \

@@ -2025,7 +2025,7 @@ DEFINE_1D_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
ColwiseBinaryOp<TIn, TOut, Op<TIn>, true>(rows, cols, Op<TIn>(), A, B, C); \
} \
template <> \
-void Colwise##Func<TIn, CPUContext, false>( \
+CAFFE2_EXPORT void Colwise##Func<TIn, CPUContext, false>( \
const int rows, \
const int cols, \
const TIn* A, \

@@ -2071,7 +2071,7 @@ DEFINE_2D_BROADCAST_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)

#define DEFINE_2D_BROADCAST_1ST_DIV_FUNCTION(T) \
template <> \
-void RowwiseDiv<T, CPUContext, true>( \
+CAFFE2_EXPORT void RowwiseDiv<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -2082,7 +2082,7 @@ DEFINE_2D_BROADCAST_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)
rows, cols, std::divides<T>(), A, B, C); \
} \
template <> \
-void ColwiseDiv<T, CPUContext, true>( \
+CAFFE2_EXPORT void ColwiseDiv<T, CPUContext, true>( \
const int rows, \
const int cols, \
const T* A, \

@@ -2098,7 +2098,7 @@ DEFINE_2D_BROADCAST_1ST_DIV_FUNCTION(std::int64_t)

#define DELEGATE_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op) \
template <> \
-void Func<TIn, CPUContext>( \
+CAFFE2_EXPORT void Func<TIn, CPUContext>( \
const int A_ndim, \
const int* A_dims, \
const int B_ndim, \

@@ -2241,7 +2241,7 @@ DEFINE_BROADCAST_BITWISE_BINARY_FUNCTION(BitwiseXor, std::bit_xor)

#define CAFFE2_RAND_UNIFORM_REAL(T) \
template <> \
-void RandUniform<T, CPUContext>( \
+CAFFE2_EXPORT void RandUniform<T, CPUContext>( \
const size_t n, const T a, const T b, T* r, CPUContext* context) { \
std::uniform_real_distribution<T> distribution(a, b); \
for (size_t i = 0; i < n; ++i) { \

@@ -2254,7 +2254,7 @@ CAFFE2_RAND_UNIFORM_REAL(double);

#define CAFFE2_RAND_UNIFORM_CHAR(T) \
template <> \
-void RandUniform<T, CPUContext>( \
+CAFFE2_EXPORT void RandUniform<T, CPUContext>( \
const size_t n, const T a, const T b, T* r, CPUContext* context) { \
std::uniform_int_distribution<short> distribution((short)a, (short)b); \
for (size_t i = 0; i < n; ++i) { \

@@ -2267,7 +2267,7 @@ CAFFE2_RAND_UNIFORM_CHAR(uint8_t);

#define CAFFE2_RAND_UNIFORM_INT(T) \
template <> \
-void RandUniform<T, CPUContext>( \
+CAFFE2_EXPORT void RandUniform<T, CPUContext>( \
const size_t n, const T a, const T b, T* r, CPUContext* context) { \
std::uniform_int_distribution<T> distribution(a, b); \
for (size_t i = 0; i < n; ++i) { \

@@ -2293,7 +2293,7 @@ CAFFE2_RAND_UNIFORM_INT(uint64_t);
// each value.
#define CAFFE2_RAND_FIXED_SUM(T) \
template <> \
-void RandFixedSum<T, CPUContext>( \
+CAFFE2_EXPORT void RandFixedSum<T, CPUContext>( \
const size_t n, \
const T a, \
const T b, \

@@ -2387,7 +2387,7 @@ Ind_t generate_stack_distance(
}

template <class Type, class Val_t, class Ind_t, class Context_t, bool cdf_app>
-void generate_trace_lru(
+CAFFE2_EXPORT void generate_trace_lru(
std::vector<Ind_t>& uni_ref,
std::vector<Ind_t>& cum_val,
std::vector<Val_t>& cum_dis,

@@ -2464,7 +2464,7 @@ void generate_trace_lru(
// case we need to know the table id, to sample from the right distribution
#define CAFFE2_RAND_SYNTHETIC_DATA(T) \
template <> \
-void RandSyntheticData<T, CPUContext>( \
+CAFFE2_EXPORT void RandSyntheticData<T, CPUContext>( \
const size_t n, const T a, const T b, T* r, CPUContext* context) { \
/* unique memory references */ \
std::vector<int> mem_ref = {1, 2, 3, 4, 5, 6}; \

@@ -2503,7 +2503,7 @@ CAFFE2_RAND_SYNTHETIC_DATA(uint64_t);

#define CAFFE2_SPECIALIZED_RAND_UNIFORM_UNIQUE(T) \
template <> \
-void RandUniformUnique<T, CPUContext>( \
+CAFFE2_EXPORT void RandUniformUnique<T, CPUContext>( \
const size_t n, \
const T a, \
const T b, \

@@ -2516,7 +2516,7 @@ CAFFE2_RAND_SYNTHETIC_DATA(uint64_t);
std::unordered_set<T> avoid_set(n); \
if (m) { \
avoid_set.insert(avoid, avoid + m); \
-CAFFE_ENFORCE_EQ(m, avoid_set.size(), "Avoid should be unique"); \
+CAFFE_ENFORCE_EQ(m, avoid_set.size(), "ACAFFE2_EXPORT void should be unique"); \
} \
std::uniform_int_distribution<T> distribution(a, b); \
T v = 0; \

@@ -2534,7 +2534,7 @@ CAFFE2_SPECIALIZED_RAND_UNIFORM_UNIQUE(int64_t);
#undef CAFFE2_SPECIALIZED_RAND_UNIFORM_UNIQUE

template <>
-void RandGaussian<float, CPUContext>(
+CAFFE2_EXPORT void RandGaussian<float, CPUContext>(
const size_t n,
const float mean,
const float std,

@@ -2548,7 +2548,7 @@ void RandGaussian<float, CPUContext>(

#define CAFFE2_SPECIALIZED_SUM(T) \
template <> \
-void Sum<T, CPUContext>( \
+CAFFE2_EXPORT void Sum<T, CPUContext>( \
const int N, \
const T* x, \
T* y, \

@@ -2564,7 +2564,7 @@ CAFFE2_SPECIALIZED_SUM(int64_t);
#undef CAFFE2_SPECIALIZED_SUM

template <>
-void SumSqr<float, CPUContext>(
+CAFFE2_EXPORT void SumSqr<float, CPUContext>(
const int N,
const float* x,
float* y,

@@ -2574,7 +2574,7 @@ void SumSqr<float, CPUContext>(
}

template <>
-void Select<float, CPUContext>(
+CAFFE2_EXPORT void Select<float, CPUContext>(
const int N,
const int D,
const float* x,

@@ -2588,7 +2588,7 @@ void Select<float, CPUContext>(
}

template <>
-void CopyMatrix<CPUContext>(
+CAFFE2_EXPORT void CopyMatrix<CPUContext>(
const size_t itemsize,
const int M,
const int N,

@@ -2631,7 +2631,7 @@ void CopyMatrix<CPUContext>(

#define DELEGATE_COPY_MATRIX_FUNCTION(T, Func) \
template <> \
-void CopyMatrix<T, CPUContext>( \
+CAFFE2_EXPORT void CopyMatrix<T, CPUContext>( \
const int M, \
const int N, \
const T* A, \

@@ -2642,7 +2642,7 @@ void CopyMatrix<CPUContext>(
Func('R', 'N', M, N, T(1), A, lda, B, ldb); \
} \
template <> \
-void CopyMatrix<T, CPUContext>( \
+CAFFE2_EXPORT void CopyMatrix<T, CPUContext>( \
const int M, \
const int N, \
const T* A, \

@@ -2673,7 +2673,7 @@ DELEGATE_COPY_MATRIX_FUNCTION(double, mkl_domatcopy)

#define CAFFE2_SPECIALIZED_COPY_MATRIX(T) \
template <> \
-void CopyMatrix<T, CPUContext>( \
+CAFFE2_EXPORT void CopyMatrix<T, CPUContext>( \
const int M, \
const int N, \
const T* A, \

@@ -2703,7 +2703,7 @@ DELEGATE_COPY_MATRIX_FUNCTION(double, mkl_domatcopy)
} \
} \
template <> \
-void CopyMatrix<T, CPUContext>( \
+CAFFE2_EXPORT void CopyMatrix<T, CPUContext>( \
const int M, \
const int N, \
const T* A, \

@@ -2742,7 +2742,7 @@ CAFFE2_SPECIALIZED_COPY_MATRIX(std::uint16_t)
namespace {

template <typename T>
-void Im2ColZeroPaddingAndNoDilationNCHW(
+CAFFE2_EXPORT void Im2ColZeroPaddingAndNoDilationNCHW(
const int C,
const int H,
const int W,

@@ -2789,7 +2789,7 @@ void Im2ColZeroPaddingAndNoDilationNCHW(
}

template <typename T>
-void Col2ImZeroPaddingAndNoDilationNCHW(
+CAFFE2_EXPORT void Col2ImZeroPaddingAndNoDilationNCHW(
const int C,
const int H,
const int W,

@@ -2825,7 +2825,7 @@ void Col2ImZeroPaddingAndNoDilationNCHW(
}

template <typename T>
-void Im2ColZeroPaddingAndNoDilationNHWC(
+CAFFE2_EXPORT void Im2ColZeroPaddingAndNoDilationNHWC(
const int C,
const int H,
const int W,

@@ -2850,7 +2850,7 @@ void Im2ColZeroPaddingAndNoDilationNHWC(
}

template <typename T>
-void Col2ImZeroPaddingAndNoDilationNHWC(
+CAFFE2_EXPORT void Col2ImZeroPaddingAndNoDilationNHWC(
const int C,
const int H,
const int W,

@@ -2877,7 +2877,7 @@ void Col2ImZeroPaddingAndNoDilationNHWC(
}

template <typename T, bool kCol2Im>
-void Im2ColNdNCHWImpl(
+CAFFE2_EXPORT void Im2ColNdNCHWImpl(
const int N,
const int img_size,
const int col_size,

@@ -2933,7 +2933,7 @@ void Im2ColNdNCHWImpl(
} // namespace

template <>
-void Im2ColNd<float, CPUContext, StorageOrder::NCHW>(
+CAFFE2_EXPORT void Im2ColNd<float, CPUContext, StorageOrder::NCHW>(
const int N,
const int img_size,
const int col_size,

@@ -2961,7 +2961,7 @@ void Im2ColNd<float, CPUContext, StorageOrder::NCHW>(
}

template <>
-void Col2ImNd<float, CPUContext, StorageOrder::NCHW>(
+CAFFE2_EXPORT void Col2ImNd<float, CPUContext, StorageOrder::NCHW>(
const int N,
const int img_size,
const int col_size,

@@ -2989,7 +2989,7 @@ void Col2ImNd<float, CPUContext, StorageOrder::NCHW>(
}

template <>
-void Im2Col<float, CPUContext, StorageOrder::NCHW>(
+CAFFE2_EXPORT void Im2Col<float, CPUContext, StorageOrder::NCHW>(
const int C,
const int H,
const int W,

@@ -3055,7 +3055,7 @@ void Im2Col<float, CPUContext, StorageOrder::NCHW>(
}

template <>
-void Im2Col<float, CPUContext, StorageOrder::NHWC>(
+CAFFE2_EXPORT void Im2Col<float, CPUContext, StorageOrder::NHWC>(
const int C,
const int H,
const int W,

@@ -3155,7 +3155,7 @@ void Im2Col<float, CPUContext, StorageOrder::NHWC>(
}

template <>
-void Col2Im<float, CPUContext, StorageOrder::NCHW>(
+CAFFE2_EXPORT void Col2Im<float, CPUContext, StorageOrder::NCHW>(
const int C,
const int H,
const int W,

@@ -3222,7 +3222,7 @@ void Col2Im<float, CPUContext, StorageOrder::NCHW>(
}

template <>
-void Col2Im<float, CPUContext, StorageOrder::NHWC>(
+CAFFE2_EXPORT void Col2Im<float, CPUContext, StorageOrder::NHWC>(
const int C,
const int H,
const int W,

@@ -3318,7 +3318,7 @@ void Col2Im<float, CPUContext, StorageOrder::NHWC>(
}

template <>
-void BiasCHW<float, CPUContext>(
+CAFFE2_EXPORT void BiasCHW<float, CPUContext>(
const float* bias,
const float* /*bias_multiplier*/,
const int bias_channels,

@@ -3403,7 +3403,7 @@ void BiasCHW<float, CPUContext>(

#define CAFFE2_SPECIALIZED_COPYVECTOR(T) \
template <> \
-void CopyVector<T, CPUContext>( \
+CAFFE2_EXPORT void CopyVector<T, CPUContext>( \
const int N, const T* src, T* dst, CPUContext* /*context*/) { \
if (src != dst && N > 0) { \
memcpy(dst, src, sizeof(T) * N); \

@@ -3616,7 +3616,7 @@ void TransposeCPUImpl(

#define CAFFE2_SPECIALIZED_TRANSPOSE(T) \
template <> \
-void Transpose<T, CPUContext>( \
+CAFFE2_EXPORT void Transpose<T, CPUContext>( \
const int ndim, \
const int* dims, \
const int* axes, \

@@ -757,9 +757,6 @@ if (USE_NNAPI AND NOT ANDROID)
caffe2_update_option(USE_NNAPI OFF)
endif()

# TODO(orionr): Enable all of this for Windows DLL when we
# can figure out how to get it to build
if (NOT (MSVC AND BUILD_SHARED_LIBS))
if (NOT BUILD_ATEN_MOBILE)
if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
list(APPEND Caffe2_DEPENDENCY_LIBS aten_op_header_gen)

@@ -769,7 +766,6 @@ if (NOT BUILD_ATEN_MOBILE)
include_directories(${PROJECT_BINARY_DIR}/caffe2/contrib/aten)
endif()
endif()
endif()

if (USE_ZSTD)
list(APPEND Caffe2_DEPENDENCY_LIBS libzstd_static)

@@ -11,6 +11,13 @@ string(
content
"${content}")

string(
REPLACE
"PROTOBUF_CONSTEXPR"
""
content
"${content}")

foreach(ns ${NAMESPACES})
# Insert "const ::std::string& GetEmptyStringAlreadyInited();" within
# the namespace and make sure we only do it once in the file. Unfortunately