[NOOP][clangformat][codemod] Enable CLANGFORMAT for caffe2/caffe2/* (#67624)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67624

Test Plan: Visual inspection. Sandcastle.

Reviewed By: malfet

Differential Revision: D31986628

fbshipit-source-id: c872bded7325997a2945dbf5d4d052628dcb3659
Parent: e86a5a3a1a
Commit: 06d1be2447
========================================================================
@@ -7,14 +7,14 @@
 #include <cuda.h>
 #include <nvrtc.h>

-#define NVRTC_CHECK(condition) \
-  do { \
-    nvrtcResult result = condition; \
-    if (result != NVRTC_SUCCESS) { \
-      LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \
-                 << nvrtcGetErrorString(result); \
-    } \
-  } while(0)
+#define NVRTC_CHECK(condition) \
+  do { \
+    nvrtcResult result = condition; \
+    if (result != NVRTC_SUCCESS) { \
+      LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \
+                 << nvrtcGetErrorString(result); \
+    } \
+  } while (0)

 namespace caffe2 {

@@ -39,15 +39,14 @@ class CudaRTCFunction {
     VLOG(1) << "function src:\n" << src;
     // Actually do the compiling.
     nvrtcProgram prog;
-    NVRTC_CHECK(nvrtcCreateProgram(
-        &prog, src.c_str(), nullptr, 0, nullptr, nullptr));
+    NVRTC_CHECK(
+        nvrtcCreateProgram(&prog, src.c_str(), nullptr, 0, nullptr, nullptr));
     // Compile the program.
     // TODO(Yangqing): how to find the current gpu architecture instead of hard
     // coding it?
-    const char *nvrtc_opts[] = {"--gpu-architecture=compute_35",
-                                "--use_fast_math"};
-    nvrtcResult compile_result = nvrtcCompileProgram(
-        prog, 2, nvrtc_opts);
+    const char* nvrtc_opts[] = {
+        "--gpu-architecture=compute_35", "--use_fast_math"};
+    nvrtcResult compile_result = nvrtcCompileProgram(prog, 2, nvrtc_opts);
     if (compile_result != NVRTC_SUCCESS) {
       size_t log_size;
       NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size));
@@ -74,21 +73,33 @@ class CudaRTCFunction {
   }

   template <typename... Args>
-  void Launch(unsigned int gx, unsigned int gy, unsigned int gz,
-              unsigned int bx, unsigned int by, unsigned int bz,
-              unsigned int shared_mem, cudaStream_t stream,
-              Args... args) {
+  void Launch(
+      unsigned int gx,
+      unsigned int gy,
+      unsigned int gz,
+      unsigned int bx,
+      unsigned int by,
+      unsigned int bz,
+      unsigned int shared_mem,
+      cudaStream_t stream,
+      Args... args) {
     CAFFE_ENFORCE(
         module_loaded_, "Cannot call Launch before a module is loaded.");
-    void * args_voidp[] = {&args...};
+    void* args_voidp[] = {&args...};
     CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
         kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0));
   }

-  void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz,
-                unsigned int bx, unsigned int by, unsigned int bz,
-                unsigned int shared_mem, cudaStream_t stream,
-                void** extra) {
+  void LaunchEx(
+      unsigned int gx,
+      unsigned int gy,
+      unsigned int gz,
+      unsigned int bx,
+      unsigned int by,
+      unsigned int bz,
+      unsigned int shared_mem,
+      cudaStream_t stream,
+      void** extra) {
     CAFFE_ENFORCE(
         module_loaded_, "Cannot call Launch before a module is loaded.");
     CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
@@ -115,6 +126,6 @@ inline std::string GetUniqueName() {
   return ss.str();
 }

-} // namepsace caffe2
+} // namespace caffe2

-#endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_
+#endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_
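A minimal, self-contained sketch (not part of this commit) of how the NVRTC_CHECK pattern reformatted above is typically driven: create a program, compile it, and check every nvrtcResult. LOG(FATAL) is replaced by fprintf/abort so the sketch has no caffe2 dependency, and the compute_35 flag mirrors the source; newer CUDA toolkits may reject it.

#include <cstdio>
#include <cstdlib>
#include <nvrtc.h>

#define NVRTC_CHECK(condition)                                       \
  do {                                                               \
    nvrtcResult result = condition;                                  \
    if (result != NVRTC_SUCCESS) {                                   \
      fprintf(stderr, "Error at: %s:%d: %s\n", __FILE__, __LINE__,   \
              nvrtcGetErrorString(result));                          \
      abort();                                                       \
    }                                                                \
  } while (0)

int main() {
  // Compile a trivial kernel at runtime, as CudaRTCFunction does for its
  // generated source strings.
  const char* src = "extern \"C\" __global__ void noop() {}";
  nvrtcProgram prog;
  NVRTC_CHECK(nvrtcCreateProgram(&prog, src, nullptr, 0, nullptr, nullptr));
  const char* opts[] = {"--gpu-architecture=compute_35", "--use_fast_math"};
  NVRTC_CHECK(nvrtcCompileProgram(prog, 2, opts));
  size_t ptx_size;
  NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
  printf("compiled to %zu bytes of PTX\n", ptx_size);
  NVRTC_CHECK(nvrtcDestroyProgram(&prog));
  return 0;
}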
========================================================================
@@ -5,8 +5,7 @@

 namespace caffe2 {
 namespace {
-class ElementwiseRTCFunction
-    : public CudaRTCFunction<ElementwiseRTCFunction> {
+class ElementwiseRTCFunction : public CudaRTCFunction<ElementwiseRTCFunction> {
  public:
  ElementwiseRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}

@@ -22,22 +21,21 @@ class ElementwiseRTCFunction
  string name_;
 };

-template<>
+template <>
 string ElementwiseRTCFunction::GetSource(
-    int input_size, int output_size,
+    int input_size,
+    int output_size,
     const string command_string) {
   std::stringstream ss;
-  ss << "extern \"C\" __global__ void " << name_ <<
-      "(const size_t nthreads, \n";
+  ss << "extern \"C\" __global__ void " << name_
+     << "(const size_t nthreads, \n";
   // Insert the parameter list.
   int remain_params = input_size + output_size;
   for (int i = 0; i < input_size; ++i) {
-    ss << "const float* in" << i
-       << ((remain_params--) ? ", \n" : "");
+    ss << "const float* in" << i << ((remain_params--) ? ", \n" : "");
   }
   for (int i = 0; i < output_size; ++i) {
-    ss << "float* out" << i
-       << ((remain_params--) ? ", \n" : "");
+    ss << "float* out" << i << ((remain_params--) ? ", \n" : "");
   }
   ss << ") {\n"
         "for (int index = blockIdx.x * blockDim.x + threadIdx.x;\n"
@@ -46,7 +44,7 @@ string ElementwiseRTCFunction::GetSource(
      << "}\n}";
   return ss.str();
 }
-}  // namespace
+} // namespace

 /**
  * A GPU operator that can generate limited elementwise operations.
@@ -75,17 +73,17 @@ class ElementwiseRTCOp final : public Operator<CUDAContext> {
  public:
  ElementwiseRTCOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws) {
-    const string src = OperatorBase::GetSingleArgument<string>(
-        "rtc_src", "");
+    const string src = OperatorBase::GetSingleArgument<string>("rtc_src", "");
     CAFFE_ENFORCE(src.size(), "Op should have a non-zero source code size.");
     func_.Compile(InputSize(), OutputSize(), src);
   }
   ~ElementwiseRTCOp() override {}

   bool RunOnDevice() override {
-    static_assert(sizeof(void*) == sizeof(size_t),
-                  "The argbuffer relies on the assumption that void* and "
-                  "size_t have the same size.");
+    static_assert(
+        sizeof(void*) == sizeof(size_t),
+        "The argbuffer relies on the assumption that void* and "
+        "size_t have the same size.");
     vector<size_t> argBuffer_vec(InputSize() + OutputSize() + 1);
     size_t* argBuffer = argBuffer_vec.data();
     CAFFE_ENFORCE(
@@ -102,10 +100,11 @@ class ElementwiseRTCOp final : public Operator<CUDAContext> {
     }
     size_t argBufferSize = sizeof(argBuffer);
     void* config[] = {
-        CU_LAUNCH_PARAM_BUFFER_POINTER, argBuffer,
-        CU_LAUNCH_PARAM_BUFFER_SIZE, &argBufferSize,
-        CU_LAUNCH_PARAM_END
-    };
+        CU_LAUNCH_PARAM_BUFFER_POINTER,
+        argBuffer,
+        CU_LAUNCH_PARAM_BUFFER_SIZE,
+        &argBufferSize,
+        CU_LAUNCH_PARAM_END};
     func_.LaunchEx(
         CAFFE_GET_BLOCKS(Input(0).numel()),
         1,
@@ -127,4 +126,4 @@ namespace {
 REGISTER_CUDA_OPERATOR_WITH_ENGINE(ElementwiseRTC, NVRTC, ElementwiseRTCOp);
 }

-}  // namespace caffe2
+} // namespace caffe2
========================================================================
@@ -2,14 +2,14 @@

 #include "caffe2/core/common_gpu.h"
 #include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/pool_op.h"
 #include "caffe2/cuda_rtc/common_rtc.h"
+#include "caffe2/operators/pool_op.h"

 namespace caffe2 {
 namespace {
 class AveragePool {};
 class MaxPool {};
-}  // namespace
+} // namespace

 namespace {

@@ -98,7 +98,6 @@ __global__ void %s(
 }
 )";

-
 class MaxPoolRTCFunction : public CudaRTCFunction<MaxPoolRTCFunction> {
  public:
  MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}
@@ -132,7 +131,6 @@ class MaxPoolGradientRTCFunction
  string name_;
 };

-
 template <>
 string MaxPoolRTCFunction::GetSource(
     const int output_size,
@@ -149,9 +147,22 @@ string MaxPoolRTCFunction::GetSource(
     const int pad_l) {
   char buffer[65536];
   int nbytes = snprintf(
-      buffer, 65536, kMaxPoolForwardNCHWSource, name_.c_str(), output_size,
-      channels, height, width, pooled_height, pooled_width, kernel_h, kernel_w,
-      stride_h, stride_w, pad_t, pad_l);
+      buffer,
+      65536,
+      kMaxPoolForwardNCHWSource,
+      name_.c_str(),
+      output_size,
+      channels,
+      height,
+      width,
+      pooled_height,
+      pooled_width,
+      kernel_h,
+      kernel_w,
+      stride_h,
+      stride_w,
+      pad_t,
+      pad_l);
   DCHECK_GE(nbytes, 0);
   DCHECK_LT(nbytes, 65536);
   return string(buffer);
@@ -174,16 +185,29 @@ string MaxPoolGradientRTCFunction::GetSource(
     const int pad_l) {
   char buffer[65536];
   int nbytes = snprintf(
-      buffer, 65536, kMaxPoolBackwardNCHWSource, name_.c_str(), output_size,
-      num, channels, height, width, pooled_height, pooled_width, kernel_h,
-      kernel_w, stride_h, stride_w, pad_t, pad_l);
+      buffer,
+      65536,
+      kMaxPoolBackwardNCHWSource,
+      name_.c_str(),
+      output_size,
+      num,
+      channels,
+      height,
+      width,
+      pooled_height,
+      pooled_width,
+      kernel_h,
+      kernel_w,
+      stride_h,
+      stride_w,
+      pad_t,
+      pad_l);
   DCHECK_GE(nbytes, 0);
   DCHECK_LT(nbytes, 65536);
   return string(buffer);
 }

-}  // namespace
+} // namespace

 class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
  public:
@@ -196,7 +220,8 @@ class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {

   bool RunOnDeviceWithOrderNCHW() override {
     auto& X = Input(0);
-    auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
+    auto output_sizes =
+        ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
     auto* Y = Output(0, output_sizes, at::dtype<float>());

     if (input_dims_ != X.sizes()) {
@@ -307,7 +332,9 @@ class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {

 namespace {
 REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp);
-REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPoolGradient, NVRTC,
-                                   MaxPoolGradientRTCOp);
-}  // namespace
-}  // namespace caffe2
+REGISTER_CUDA_OPERATOR_WITH_ENGINE(
+    MaxPoolGradient,
+    NVRTC,
+    MaxPoolGradientRTCOp);
+} // namespace
+} // namespace caffe2
========================================================================
@@ -6,4 +6,4 @@ REGISTER_CPU_OPERATOR(CreateDB, CreateDBOp<CPUContext>);
 OPERATOR_SCHEMA(CreateDB).NumInputs(0).NumOutputs(1);

 NO_GRADIENT(CreateDB);
-}  // namespace caffe2
+} // namespace caffe2
========================================================================
@@ -1,6 +1,6 @@
 #include "caffe2/core/db.h"
-#include "caffe2/core/logging.h"
 #include "caffe2/core/flags.h"
+#include "caffe2/core/logging.h"
 #include "leveldb/db.h"
 #include "leveldb/write_batch.h"

@@ -19,13 +19,27 @@ class LevelDBCursor : public Cursor {
     SeekToFirst();
   }
   ~LevelDBCursor() override {}
-  void Seek(const string& key) override { iter_->Seek(key); }
-  bool SupportsSeek() override { return true; }
-  void SeekToFirst() override { iter_->SeekToFirst(); }
-  void Next() override { iter_->Next(); }
-  string key() override { return iter_->key().ToString(); }
-  string value() override { return iter_->value().ToString(); }
-  bool Valid() override { return iter_->Valid(); }
+  void Seek(const string& key) override {
+    iter_->Seek(key);
+  }
+  bool SupportsSeek() override {
+    return true;
+  }
+  void SeekToFirst() override {
+    iter_->SeekToFirst();
+  }
+  void Next() override {
+    iter_->Next();
+  }
+  string key() override {
+    return iter_->key().ToString();
+  }
+  string value() override {
+    return iter_->value().ToString();
+  }
+  bool Valid() override {
+    return iter_->Valid();
+  }

  private:
  std::unique_ptr<leveldb::Iterator> iter_;
@@ -47,8 +61,7 @@ class LevelDBTransaction : public Transaction {
     leveldb::Status status = db_->Write(leveldb::WriteOptions(), batch_.get());
     batch_.reset(new leveldb::WriteBatch());
     CAFFE_ENFORCE(
-        status.ok(),
-        "Failed to write batch to leveldb. ", status.ToString());
+        status.ok(), "Failed to write batch to leveldb. ", status.ToString());
   }

  private:
@@ -71,12 +84,17 @@ class LevelDB : public DB {
     leveldb::Status status = leveldb::DB::Open(options, source, &db_temp);
     CAFFE_ENFORCE(
         status.ok(),
-        "Failed to open leveldb ", source, ". ", status.ToString());
+        "Failed to open leveldb ",
+        source,
+        ". ",
+        status.ToString());
     db_.reset(db_temp);
     VLOG(1) << "Opened leveldb " << source;
   }

-  void Close() override { db_.reset(); }
+  void Close() override {
+    db_.reset();
+  }
   unique_ptr<Cursor> NewCursor() override {
     return make_unique<LevelDBCursor>(db_.get());
   }
@@ -92,5 +110,5 @@ REGISTER_CAFFE2_DB(LevelDB, LevelDB);
 // For lazy-minded, one can also call with lower-case name.
 REGISTER_CAFFE2_DB(leveldb, LevelDB);

-}  // namespace db
-}  // namespace caffe2
+} // namespace db
+} // namespace caffe2
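A hedged sketch (not from this commit) of the Cursor interface that the LevelDB wrapper above, and the LMDB wrapper below, both implement: callers drive every backend with the same SeekToFirst/Valid/Next scan loop. InMemoryCursor is a hypothetical stand-in so the loop compiles without a database.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Cursor { // shape of caffe2::db::Cursor, reduced to the scan methods
  virtual ~Cursor() {}
  virtual void SeekToFirst() = 0;
  virtual void Next() = 0;
  virtual std::string key() = 0;
  virtual std::string value() = 0;
  virtual bool Valid() = 0;
};

class InMemoryCursor : public Cursor {
 public:
  explicit InMemoryCursor(std::vector<std::pair<std::string, std::string>> kv)
      : kv_(std::move(kv)), iter_(0) {}
  void SeekToFirst() override { iter_ = 0; }
  void Next() override { ++iter_; }
  std::string key() override { return kv_[iter_].first; }
  std::string value() override { return kv_[iter_].second; }
  bool Valid() override { return iter_ < kv_.size(); }

 private:
  std::vector<std::pair<std::string, std::string>> kv_;
  size_t iter_;
};

int main() {
  InMemoryCursor cursor({{"k0", "v0"}, {"k1", "v1"}});
  // The canonical full scan used against any caffe2::db cursor.
  for (cursor.SeekToFirst(); cursor.Valid(); cursor.Next()) {
    std::cout << cursor.key() << " -> " << cursor.value() << "\n";
  }
  return 0;
}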
========================================================================
@@ -1,4 +1,4 @@
-#include "lmdb.h"  // NOLINT
+#include "lmdb.h" // NOLINT

 #if defined(_MSC_VER)
 #include <direct.h>
@@ -14,7 +14,7 @@
 namespace caffe2 {
 namespace db {

-constexpr size_t LMDB_MAP_SIZE = 1099511627776;  // 1 TB
+constexpr size_t LMDB_MAP_SIZE = 1099511627776; // 1 TB

 inline void MDB_CHECK(int mdb_status) {
   CAFFE_ENFORCE_EQ(mdb_status, MDB_SUCCESS, mdb_strerror(mdb_status));
@@ -22,8 +22,7 @@ inline void MDB_CHECK(int mdb_status) {

 class LMDBCursor : public Cursor {
  public:
-  explicit LMDBCursor(MDB_env* mdb_env)
-      : mdb_env_(mdb_env), valid_(false) {
+  explicit LMDBCursor(MDB_env* mdb_env) : mdb_env_(mdb_env), valid_(false) {
     MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn_));
     MDB_CHECK(mdb_dbi_open(mdb_txn_, NULL, 0, &mdb_dbi_));
     MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_));
@@ -43,8 +42,8 @@ class LMDBCursor : public Cursor {
     // a key of 16k size should be enough? I am not sure though.
     mdb_key_.mv_size = key.size();
     mdb_key_.mv_data = const_cast<char*>(key.c_str());
-    int mdb_status = mdb_cursor_get(
-        mdb_cursor_, &mdb_key_, &mdb_value_, MDB_SET_RANGE);
+    int mdb_status =
+        mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_SET_RANGE);
     if (mdb_status == MDB_NOTFOUND) {
       valid_ = false;
     } else {
@@ -53,22 +52,30 @@ class LMDBCursor : public Cursor {
     }
   }

-  bool SupportsSeek() override { return true; }
+  bool SupportsSeek() override {
+    return true;
+  }

-  void SeekToFirst() override { SeekLMDB(MDB_FIRST); }
+  void SeekToFirst() override {
+    SeekLMDB(MDB_FIRST);
+  }

-  void Next() override { SeekLMDB(MDB_NEXT); }
+  void Next() override {
+    SeekLMDB(MDB_NEXT);
+  }

   string key() override {
     return string(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
   }

   string value() override {
-    return string(static_cast<const char*>(mdb_value_.mv_data),
-                  mdb_value_.mv_size);
+    return string(
+        static_cast<const char*>(mdb_value_.mv_data), mdb_value_.mv_size);
   }

-  bool Valid() override { return valid_; }
+  bool Valid() override {
+    return valid_;
+  }

  private:
  void SeekLMDB(MDB_cursor_op op) {
@@ -91,8 +98,7 @@ class LMDBCursor : public Cursor {

 class LMDBTransaction final : public Transaction {
  public:
-  explicit LMDBTransaction(MDB_env* mdb_env)
-      : mdb_env_(mdb_env) {
+  explicit LMDBTransaction(MDB_env* mdb_env) : mdb_env_(mdb_env) {
     MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn_));
     MDB_CHECK(mdb_dbi_open(mdb_txn_, NULL, 0, &mdb_dbi_));
   }
@@ -171,5 +177,5 @@ void LMDBTransaction::Put(const string& key, string&& value) {
 REGISTER_CAFFE2_DB(LMDB, LMDB);
 REGISTER_CAFFE2_DB(lmdb, LMDB);

-}  // namespace db
-}  // namespace caffe2
+} // namespace db
+} // namespace caffe2
========================================================================
@@ -1,16 +1,15 @@
 #include <unordered_set>

 #include "caffe2/core/db.h"
-#include "caffe2/utils/proto_utils.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/utils/proto_utils.h"

 namespace caffe2 {
 namespace db {

 class ProtoDBCursor : public Cursor {
  public:
-  explicit ProtoDBCursor(const TensorProtos* proto)
-      : proto_(proto), iter_(0) {}
+  explicit ProtoDBCursor(const TensorProtos* proto) : proto_(proto), iter_(0) {}
   // NOLINTNEXTLINE(modernize-use-equals-default)
   ~ProtoDBCursor() override {}

@@ -18,14 +17,22 @@ class ProtoDBCursor : public Cursor {
     CAFFE_THROW("ProtoDB is not designed to support seeking.");
   }

-  void SeekToFirst() override { iter_ = 0; }
-  void Next() override { ++iter_; }
-  string key() override { return proto_->protos(iter_).name(); }
-  string value() override {
-    return
-        SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor");
+  void SeekToFirst() override {
+    iter_ = 0;
+  }
+  void Next() override {
+    ++iter_;
+  }
+  string key() override {
+    return proto_->protos(iter_).name();
+  }
+  string value() override {
+    return SerializeAsString_EnforceCheck(
+        proto_->protos(iter_), "ProtoDBCursor");
+  }
+  bool Valid() override {
+    return iter_ < proto_->protos_size();
   }
-  bool Valid() override { return iter_ < proto_->protos_size(); }

  private:
  const TensorProtos* proto_;
@@ -108,5 +115,5 @@ REGISTER_CAFFE2_DB(ProtoDB, ProtoDB);
 // For lazy-minded, one can also call with lower-case name.
 REGISTER_CAFFE2_DB(protodb, ProtoDB);

-}  // namespace db
-}  // namespace caffe2
+} // namespace db
+} // namespace caffe2
========================================================================
@@ -1,11 +1,11 @@
 #include <atomic>
 #include <condition_variable>
 #include <mutex>
-#include <thread>  // NOLINT
+#include <thread> // NOLINT

 #include "caffe2/core/db.h"
-#include "caffe2/utils/zmq_helper.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/utils/zmq_helper.h"

 namespace caffe2 {
 namespace db {

@@ -13,12 +13,13 @@ namespace db {
 class ZmqDBCursor : public Cursor {
  public:
  explicit ZmqDBCursor(const string& source)
-      : source_(source), socket_(ZMQ_PULL),
-        prefetched_(false), finalize_(false) {
+      : source_(source),
+        socket_(ZMQ_PULL),
+        prefetched_(false),
+        finalize_(false) {
     socket_.Connect(source_);
     // Start prefetching thread.
-    prefetch_thread_.reset(
-        new std::thread([this] { this->Prefetch(); }));
+    prefetch_thread_.reset(new std::thread([this] { this->Prefetch(); }));
     // obtain the first value.
     Next();
   }
@@ -35,27 +36,35 @@ class ZmqDBCursor : public Cursor {
   void Seek(const string& /*key*/) override { /* do nothing */
   }

-  void SeekToFirst() override { /* do nothing */ }
+  void SeekToFirst() override { /* do nothing */
+  }

   void Next() override {
     std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
-    while (!prefetched_) consumer_.wait(lock);
+    while (!prefetched_)
+      consumer_.wait(lock);
     key_ = prefetch_key_;
     value_ = prefetch_value_;
     prefetched_ = false;
     producer_.notify_one();
   }

-  string key() override { return key_; }
-  string value() override { return value_; }
-  bool Valid() override { return true; }
+  string key() override {
+    return key_;
+  }
+  string value() override {
+    return value_;
+  }
+  bool Valid() override {
+    return true;
+  }

  private:

  void Prefetch() {
    while (!finalize_) {
      std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
-      while (prefetched_) producer_.wait(lock);
+      while (prefetched_)
+        producer_.wait(lock);
      if (finalize_) {
        return;
      }
@@ -86,8 +95,7 @@ class ZmqDBCursor : public Cursor {

 class ZmqDB : public DB {
  public:
-  ZmqDB(const string& source, Mode mode)
-      : DB(source, mode), source_(source) {
+  ZmqDB(const string& source, Mode mode) : DB(source, mode), source_(source) {
     CAFFE_ENFORCE(mode == READ, "ZeroMQ DB only supports read mode.");
   }

@@ -101,7 +109,7 @@ class ZmqDB : public DB {

   unique_ptr<Transaction> NewTransaction() override {
     CAFFE_THROW("ZeroMQ DB does not support writing with a transaction.");
-    return nullptr;  // dummy placeholder to suppress old compiler warnings.
+    return nullptr; // dummy placeholder to suppress old compiler warnings.
   }

  private:
@@ -112,5 +120,5 @@ REGISTER_CAFFE2_DB(ZmqDB, ZmqDB);
 // For lazy-minded, one can also call with lower-case name.
 REGISTER_CAFFE2_DB(zmqdb, ZmqDB);

-}  // namespace db
-}  // namespace caffe2
+} // namespace db
+} // namespace caffe2
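A hedged, runnable sketch (not from this commit) of the single-slot producer/consumer handoff that ZmqDBCursor::Next() and Prefetch() above implement with prefetched_, producer_, and consumer_. A plain thread stands in for the ZeroMQ socket so the pattern runs on its own.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

int main() {
  std::mutex m;
  std::condition_variable producer, consumer;
  bool prefetched = false;
  std::string slot;

  std::thread prefetcher([&] { // mirrors ZmqDBCursor::Prefetch()
    for (int i = 0; i < 3; ++i) {
      std::unique_lock<std::mutex> lock(m);
      while (prefetched) // wait until the consumer drained the slot
        producer.wait(lock);
      slot = "value-" + std::to_string(i);
      prefetched = true;
      consumer.notify_one();
    }
  });

  for (int i = 0; i < 3; ++i) { // mirrors ZmqDBCursor::Next()
    std::unique_lock<std::mutex> lock(m);
    while (!prefetched)
      consumer.wait(lock);
    std::cout << slot << "\n";
    prefetched = false;
    producer.notify_one();
  }
  prefetcher.join();
  return 0;
}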
========================================================================
@@ -181,4 +181,4 @@ void FileStoreHandler::wait(
     std::this_thread::sleep_for(std::chrono::milliseconds(10));
   }
 }
-}
+} // namespace caffe2
========================================================================
@@ -9,19 +9,19 @@ class TORCH_API FileStoreHandler : public StoreHandler {
  explicit FileStoreHandler(const std::string& path, const std::string& prefix);
  virtual ~FileStoreHandler();

-  void set(const std::string& name, const std::string& data) override;
+  void set(const std::string& name, const std::string& data) override;

  virtual std::string get(
      const std::string& name,
      const std::chrono::milliseconds& timeout = kDefaultTimeout) override;

-  int64_t add(const std::string& name, int64_t value) override;
+  int64_t add(const std::string& name, int64_t value) override;

-  bool deleteKey(const std::string& key) override;
+  bool deleteKey(const std::string& key) override;

-  int64_t getNumKeys() override;
+  int64_t getNumKeys() override;

-  bool check(const std::vector<std::string>& names) override;
+  bool check(const std::vector<std::string>& names) override;

  virtual void wait(
      const std::vector<std::string>& names,
========================================================================
@@ -15,19 +15,19 @@ class TORCH_API RedisStoreHandler : public StoreHandler {
  explicit RedisStoreHandler(std::string& host, int port, std::string& prefix);
  virtual ~RedisStoreHandler();

-  void set(const std::string& name, const std::string& data) override;
+  void set(const std::string& name, const std::string& data) override;

  virtual std::string get(
      const std::string& name,
      const std::chrono::milliseconds& timeout = kDefaultTimeout) override;

-  int64_t add(const std::string& name, int64_t value) override;
+  int64_t add(const std::string& name, int64_t value) override;

-  int64_t getNumKeys() override;
+  int64_t getNumKeys() override;

-  bool deleteKey(const std::string& key) override;
+  bool deleteKey(const std::string& key) override;

-  bool check(const std::vector<std::string>& names) override;
+  bool check(const std::vector<std::string>& names) override;

  virtual void wait(
      const std::vector<std::string>& names,
========================================================================
@@ -67,8 +67,7 @@ class TORCH_API StoreHandler {
 /*
  * The backing store is no longer available. It may have been deleted.
  */
-struct TORCH_API StoreHandlerNotAvailableException
-    : public std::runtime_error {
+struct TORCH_API StoreHandlerNotAvailableException : public std::runtime_error {
  explicit StoreHandlerNotAvailableException(const std::string& msg)
      : std::runtime_error(msg) {}
 };
========================================================================
@@ -124,4 +124,4 @@ either as an input blob with blob names or as an argument.
     .Arg("blob_names", "names of the blobs to wait for (optional)")
     .Input(0, "handler", "unique_ptr<StoreHandler>")
     .Input(1, "names", "names of the blobs to wait for (optional)");
-}
+} // namespace caffe2
========================================================================
@@ -52,4 +52,4 @@ class StoreWaitOp final : public Operator<CPUContext> {

 INPUT_TAGS(HANDLER);
 };
-}
+} // namespace caffe2
========================================================================
@@ -19,8 +19,9 @@
 namespace caffe2 {

 REGISTER_CPU_OPERATOR(FC_Decomp, FullyConnectedOpDecomp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(FCGradient_Decomp,
-                      FullyConnectedDecompGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(
+    FCGradient_Decomp,
+    FullyConnectedDecompGradientOp<float, CPUContext>);

 OPERATOR_SCHEMA(FC_Decomp).NumInputs(4).NumOutputs(1);
 OPERATOR_SCHEMA(FCGradient_Decomp).NumInputs(4).NumOutputs(3, 4);
@@ -31,10 +32,11 @@ class GetFCDecompGradient : public GradientMakerBase {
     CAFFE_ENFORCE_EQ(def_.input_size(), 4);
     // TODO(wyiming): Check whether it is right? Let's move fast first.
     return SingleGradientDef(
-        "FCGradient_Decomp", "",
+        "FCGradient_Decomp",
+        "",
         vector<string>{I(0), I(1), I(2), GO(0)},
         vector<string>{GI(1), GI(2), GI(3), GI(0)});
   }
 };
 REGISTER_GRADIENT(FC_Decomp, GetFCDecompGradient);
-}  // namespace caffe2
+} // namespace caffe2
========================================================================
@@ -30,7 +30,7 @@ namespace caffe2 {
  * W(N * K) = U(N * middle) * trans(V(K * middle))
  * */
 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
-template <typename T, class Context, class Engine=DefaultEngine>
+template <typename T, class Context, class Engine = DefaultEngine>
 class FullyConnectedOpDecomp final : public Operator<Context> {
  public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -44,15 +44,15 @@ class FullyConnectedOpDecomp final : public Operator<Context> {
     const auto& V = Input(2);
     const auto& b = Input(3);

-    //auto* buffer_ptr = Output(1);
+    // auto* buffer_ptr = Output(1);
     // Size M * middle;
-    //auto& multi_buffer_ = *buffer_ptr;
+    // auto& multi_buffer_ = *buffer_ptr;
     CAFFE_ENFORCE_GE(X.dim(), 1);
     CAFFE_ENFORCE_GE(U.dim(), 2);
     CAFFE_ENFORCE_GE(V.dim(), 2);
     if (X.dim() > 2 || U.dim() > 2 || V.dim() > 2) {
       VLOG(1) << "Using legacy support for arbitrary input and weight "
-                 "dimensions.";
+              "dimensions.";
     }
     CAFFE_ENFORCE_EQ(b.dim(), 1);
     // batch size
@@ -79,25 +79,51 @@ class FullyConnectedOpDecomp final : public Operator<Context> {
     T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
     // X * V * tans(U)
     math::Gemm<T, Context, Engine>(
-        CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
-        V.template data<T>(), 0, multi_buffer_data,
+        CblasNoTrans,
+        CblasNoTrans,
+        M,
+        middle,
+        K,
+        1,
+        X.template data<T>(),
+        V.template data<T>(),
+        0,
+        multi_buffer_data,
         &context_);
     math::Gemm<T, Context, Engine>(
-        CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
-        U.template data<T>(), 0, Y->template mutable_data<T>(),
+        CblasNoTrans,
+        CblasTrans,
+        M,
+        N,
+        middle,
+        1,
+        multi_buffer_data,
+        U.template data<T>(),
+        0,
+        Y->template mutable_data<T>(),
         &context_);
     // Add bias term
     if (bias_multiplier_.numel() != M) {
       // If the helper bias multiplier is not M, reshape and fill it with one.
       bias_multiplier_.Resize(M);
       math::Set<T, Context>(
-          M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
+          M,
+          static_cast<T>(1),
+          bias_multiplier_.template mutable_data<T>(),
           &context_);
     }
     math::Gemm<T, Context, Engine>(
-        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
-        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
-        Y->template mutable_data<T>(), &context_);
+        CblasNoTrans,
+        CblasNoTrans,
+        M,
+        N,
+        1,
+        1,
+        bias_multiplier_.template data<T>(),
+        b.template data<T>(),
+        1,
+        Y->template mutable_data<T>(),
+        &context_);
     return true;
   }

@@ -106,7 +132,7 @@ class FullyConnectedOpDecomp final : public Operator<Context> {
   Tensor multi_buffer_{Context::GetDeviceType()};
 };

-template <typename T, class Context, class Engine=DefaultEngine>
+template <typename T, class Context, class Engine = DefaultEngine>
 class FullyConnectedDecompGradientOp : public Operator<Context> {
  public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -148,41 +174,75 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
     du_buffer_.Resize(N, middle);
     T* du_buffer_data = du_buffer_.template mutable_data<T>();
     math::Gemm<T, Context, Engine>(
-        CblasNoTrans, CblasNoTrans, M, middle, K, 1,
-        X.template data<T>(), V.template data<T>(),
-        0, du_buffer_data,
+        CblasNoTrans,
+        CblasNoTrans,
+        M,
+        middle,
+        K,
+        1,
+        X.template data<T>(),
+        V.template data<T>(),
+        0,
+        du_buffer_data,
         &context_);
     math::Gemm<T, Context, Engine>(
-        CblasTrans, CblasNoTrans, N, middle, M, 1,
-        dY.template data<T>(), du_buffer_data,
-        0, dU->template mutable_data<T>(),
+        CblasTrans,
+        CblasNoTrans,
+        N,
+        middle,
+        M,
+        1,
+        dY.template data<T>(),
+        du_buffer_data,
+        0,
+        dU->template mutable_data<T>(),
         &context_);
     // Compute dV
     // first compute dY * U
     dv_buffer_.Resize(M, middle);
     T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
     math::Gemm<T, Context, Engine>(
-        CblasNoTrans, CblasNoTrans, M, middle, N, 1,
-        dY.template data<T>(), U.template data<T>(),
-        0, dv_buffer_data,
+        CblasNoTrans,
+        CblasNoTrans,
+        M,
+        middle,
+        N,
+        1,
+        dY.template data<T>(),
+        U.template data<T>(),
+        0,
+        dv_buffer_data,
         &context_);
     math::Gemm<T, Context, Engine>(
-        CblasTrans, CblasNoTrans, K, middle, M, 1,
-        dY.template data<T>(), du_buffer_data,
-        0, dV->template mutable_data<T>(),
+        CblasTrans,
+        CblasNoTrans,
+        K,
+        middle,
+        M,
+        1,
+        dY.template data<T>(),
+        du_buffer_data,
+        0,
+        dV->template mutable_data<T>(),
         &context_);
     if (bias_multiplier_.numel() != M) {
       // If the helper bias multiplier is not M, reshape and fill it with one.
       bias_multiplier_.Resize(M);
       math::Set<T, Context>(
-          M, static_cast<T>(1),
+          M,
+          static_cast<T>(1),
           bias_multiplier_.template mutable_data<T>(),
           &context_);
     }
     // Compute dB
     math::Gemv<T, Context>(
-        CblasTrans, M, N, 1, dY.template data<T>(),
-        bias_multiplier_.template data<T>(), 0,
+        CblasTrans,
+        M,
+        N,
+        1,
+        dY.template data<T>(),
+        bias_multiplier_.template data<T>(),
+        0,
         db->template mutable_data<T>(),
         &context_);
     // Compute dX if necessary.
@@ -191,14 +251,28 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
       dx_buffer_.Resize(M, middle);
       T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
       math::Gemm<T, Context, Engine>(
-          CblasNoTrans, CblasNoTrans, M, middle, N, 1,
-          dY.template data<T>(), U.template data<T>(),
-          0, dx_buffer_data,
+          CblasNoTrans,
+          CblasNoTrans,
+          M,
+          middle,
+          N,
+          1,
+          dY.template data<T>(),
+          U.template data<T>(),
+          0,
+          dx_buffer_data,
           &context_);
       math::Gemm<T, Context, Engine>(
-          CblasNoTrans, CblasTrans, M, K, middle, 1,
-          dx_buffer_data, V.template data<T>(),
-          0, dX->template mutable_data<T>(),
+          CblasNoTrans,
+          CblasTrans,
+          M,
+          K,
+          middle,
+          1,
+          dx_buffer_data,
+          V.template data<T>(),
+          0,
+          dX->template mutable_data<T>(),
           &context_);
     }

@@ -212,6 +286,6 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
   Tensor dx_buffer_{Context::GetDeviceType()};
 };

-}  // namespace caffe2
+} // namespace caffe2

-#endif  // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
+#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
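A hedged sketch (not from this commit) of the two-GEMM forward pass that FullyConnectedOpDecomp above performs with math::Gemm, written as naive loops. With W (N x K) factored as U (N x middle) times the transpose of V (K x middle), Y = X * V * trans(U) costs M*middle*(K + N) multiply-adds instead of M*K*N for the dense layer. The tensor values below are made up for illustration.

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 4, N = 3, middle = 2;
  std::vector<float> X(M * K, 1.0f), V(K * middle, 0.5f), U(N * middle, 2.0f);
  std::vector<float> buf(M * middle, 0.0f), Y(M * N, 0.0f);
  // buf = X * V  (M x middle), the multi_buffer_ GEMM above
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < middle; ++j)
      for (int k = 0; k < K; ++k)
        buf[i * middle + j] += X[i * K + k] * V[k * middle + j];
  // Y = buf * trans(U)  (M x N), matching the CblasNoTrans/CblasTrans GEMM
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < middle; ++k)
        Y[i * N + j] += buf[i * middle + k] * U[j * middle + k];
  printf("Y[0][0] = %f\n", Y[0][0]); // 4 * 0.5 = 2 per middle dim, so 8
  return 0;
}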
========================================================================
@@ -20,7 +20,8 @@
 namespace caffe2 {

 REGISTER_CUDA_OPERATOR(FC_Decomp, FullyConnectedOpDecomp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(FCGradient_Decomp,
-                       FullyConnectedDecompGradientOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(
+    FCGradient_Decomp,
+    FullyConnectedDecompGradientOp<float, CUDAContext>);

-}  // namespace caffe2
+} // namespace caffe2
========================================================================
@@ -20,25 +20,29 @@ namespace caffe2 {
 namespace {

 REGISTER_CPU_OPERATOR(FC_Prune, FullyConnectedOpPrune<float, CPUContext>);
-REGISTER_CPU_OPERATOR(FCGradient_Prune,
-                      FullyConnectedPruneGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(
+    FCGradient_Prune,
+    FullyConnectedPruneGradientOp<float, CPUContext>);
 /* 8 Inputs:
  * X W Mask bias Ag_dw Mask_seq thres comp_lb
  * */
 OPERATOR_SCHEMA(FC_Prune).NumInputs(8).NumOutputs(1, 2);
-OPERATOR_SCHEMA(FCGradient_Prune).NumInputs(8).NumOutputs(6, 7)
-    .AllowInplace({{1, 2}, {2, 3}, {4, 4}, {5, 5}});
+OPERATOR_SCHEMA(FCGradient_Prune)
+    .NumInputs(8)
+    .NumOutputs(6, 7)
+    .AllowInplace({{1, 2}, {2, 3}, {4, 4}, {5, 5}});

 class GetFCPruneGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
     CAFFE_ENFORCE_EQ(def_.input_size(), 8);
     return SingleGradientDef(
-        "FCGradient_Prune", "",
+        "FCGradient_Prune",
+        "",
         vector<string>{I(0), I(1), I(2), GO(0), I(4), I(5), I(6), I(7)},
         vector<string>{GI(1), GI(3), I(1), I(2), I(4), I(5), GI(0)});
   }
 };
 REGISTER_GRADIENT(FC_Prune, GetFCPruneGradient);
-}  // namespace
-}  // namespace caffe2
+} // namespace
+} // namespace caffe2
|
|||
|
|
@ -23,337 +23,384 @@
|
|||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
namespace {
|
||||
|
||||
template<int N>
|
||||
using Shape = std::array<int, N>;
|
||||
template <int N>
|
||||
using Shape = std::array<int, N>;
|
||||
|
||||
template<int N>
|
||||
const std::vector<int64_t>& shape(Shape<N> vs) {
|
||||
static thread_local std::vector<int64_t> cache;
|
||||
cache.resize(vs.size());
|
||||
for (auto i = 0; i < vs.size(); ++i) {
|
||||
cache[i] = vs[i];
|
||||
}
|
||||
return cache;
|
||||
}
|
||||
template <int N>
|
||||
const std::vector<int64_t>& shape(Shape<N> vs) {
|
||||
static thread_local std::vector<int64_t> cache;
|
||||
cache.resize(vs.size());
|
||||
for (auto i = 0; i < vs.size(); ++i) {
|
||||
cache[i] = vs[i];
|
||||
}
|
||||
return cache;
|
||||
}
|
||||
|
||||
inline const std::vector<int64_t>& shape(int i) {
|
||||
return shape<1>(Shape<1>({i}));
|
||||
inline const std::vector<int64_t>& shape(int i) {
|
||||
return shape<1>(Shape<1>({i}));
|
||||
}
|
||||
|
||||
inline const std::vector<int64_t>& shape(int i, int j) {
|
||||
return shape<2>(Shape<2>({i, j}));
|
||||
}
|
||||
|
||||
template <typename T, class Context>
|
||||
void MaskMatrix(const T* mask, T* mat, int M, int N);
|
||||
|
||||
template <typename T, class Context>
|
||||
void MaskMatrix_Inc(T* mask_seq, T* mat, int M, int N, int seq_len, T target);
|
||||
|
||||
template <typename T, class Context>
|
||||
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);
|
||||
|
||||
template <typename T>
|
||||
int MatrixCompare_LT(const T* mat, float thres, T* mask_seq, int M, int N);
|
||||
|
||||
// TODO(wyiming): write an incremental Mask
|
||||
// Incremental Mask: only give the new mask positions;
|
||||
// Assuming that weights masked will not be mask again;
|
||||
// The incremental mask can also be used to update mask matrix;
|
||||
// But this will include template for bool and float;
|
||||
template <>
|
||||
void MaskMatrix<float, CPUContext>(
|
||||
const float* mask,
|
||||
float* mat,
|
||||
int M,
|
||||
int N) {
|
||||
int offset = 0;
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
mat[offset] = mask[offset] ? mat[offset] : 0;
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline const std::vector<int64_t>& shape(int i, int j) {
|
||||
return shape<2>(Shape<2>({i, j}));
|
||||
template <>
|
||||
void MaskMatrix_Inc<float, CPUContext>(
|
||||
float* mask_seq,
|
||||
float* mat,
|
||||
int /*M*/,
|
||||
int /*N*/,
|
||||
int seq_len,
|
||||
float target) {
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
// assume that the mask_seq is smaller than size
|
||||
// Although it seems that random access gets bad performance,
|
||||
// we make sure that seq is in order;
|
||||
mat[static_cast<int>(mask_seq[i])] = target;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void AggrDW<float, CPUContext>(
|
||||
float* ag_dw,
|
||||
const float* dw,
|
||||
int N,
|
||||
int K,
|
||||
CPUContext* context) {
|
||||
math::Add<float, CPUContext>(N * K, dw, ag_dw, ag_dw, context);
|
||||
}
|
||||
|
||||
template <>
|
||||
int MatrixCompare_LT<float>(
|
||||
const float* mat,
|
||||
float thres,
|
||||
float* mask_seq,
|
||||
int M,
|
||||
int N) {
|
||||
int seq_len = 0;
|
||||
int offset = 0;
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
if (mat[offset] != 0 && (mat[offset] < thres && mat[offset] > -thres)) {
|
||||
mask_seq[seq_len++] = static_cast<float>(offset);
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
return seq_len;
|
||||
}
|
||||
|
||||
template <typename T, class Context>
|
||||
void MaskMatrix(const T* mask, T* mat,
|
||||
int M, int N);
|
||||
} // namespace
|
||||
|
||||
template <typename T, class Context>
|
||||
void MaskMatrix_Inc(T* mask_seq, T* mat,
|
||||
int M, int N, int seq_len, T target);
|
||||
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
|
||||
template <typename T, class Context, class Engine = DefaultEngine>
|
||||
class FullyConnectedOpPrune final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
FullyConnectedOpPrune(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
~FullyConnectedOpPrune() {}
|
||||
|
||||
template <typename T, class Context>
|
||||
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);
|
||||
|
||||
template <typename T>
|
||||
int MatrixCompare_LT(const T* mat, float thres,
|
||||
T* mask_seq, int M, int N);
|
||||
|
||||
// TODO(wyiming): write an incremental Mask
|
||||
// Incremental Mask: only give the new mask positions;
|
||||
// Assuming that weights masked will not be mask again;
|
||||
// The incremental mask can also be used to update mask matrix;
|
||||
// But this will include template for bool and float;
|
||||
template <>
|
||||
void MaskMatrix<float, CPUContext>(
|
||||
const float* mask, float* mat, int M, int N) {
|
||||
int offset = 0;
|
||||
for (int i = 0; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
mat[offset] = mask[offset]? mat[offset] : 0;
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void MaskMatrix_Inc<float, CPUContext>(
|
||||
float* mask_seq,
|
||||
float* mat,
|
||||
int /*M*/,
|
||||
int /*N*/,
|
||||
int seq_len,
|
||||
float target) {
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
// assume that the mask_seq is smaller than size
|
||||
// Although it seems that random access gets bad performance,
|
||||
// we make sure that seq is in order;
|
||||
mat[static_cast<int>(mask_seq[i])] = target;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void AggrDW<float, CPUContext>(
|
||||
float* ag_dw, const float* dw,
|
||||
int N, int K, CPUContext* context) {
|
||||
math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
|
||||
}
|
||||
|
||||
template <>
|
||||
int MatrixCompare_LT<float>(
|
||||
const float* mat, float thres,
|
||||
float* mask_seq, int M, int N) {
|
||||
int seq_len = 0;
|
||||
int offset = 0;
|
||||
for (int i = 0 ; i < M; ++i) {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
if (mat[offset] != 0 &&
|
||||
(mat[offset] < thres && mat[offset] > -thres)) {
|
||||
mask_seq[seq_len++] = static_cast<float>(offset);
|
||||
}
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
return seq_len;
|
||||
}
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
const auto& W = Input(1);
|
||||
const auto& Mask = Input(2);
|
||||
const auto& b = Input(3);
|
||||
|
||||
CAFFE_ENFORCE_GE(X.dim(), 1);
|
||||
CAFFE_ENFORCE_GE(W.dim(), 2);
|
||||
if (X.dim() > 2 || W.dim() > 2) {
|
||||
VLOG(1) << "Using legacy support for arbitrary input and weight "
|
||||
"dimensions.";
|
||||
}
|
||||
CAFFE_ENFORCE_EQ(b.dim(), 1);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
int K = X.numel() / M;
|
||||
// number of outputs.
|
||||
int N = W.dim32(0);
|
||||
CAFFE_ENFORCE_EQ(K, W.numel() / W.dim32(0));
|
||||
CAFFE_ENFORCE_EQ(N, b.dim32(0));
|
||||
std::vector<int64_t> dims;
|
||||
if (X.dim() > 1) {
|
||||
dims = {M, N};
|
||||
} else {
|
||||
dims = {N};
|
||||
}
|
||||
auto* Y = Output(0, dims, at::dtype<T>());
|
||||
// W * x
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasNoTrans,
|
||||
CblasTrans,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1,
|
||||
X.template data<T>(),
|
||||
W.template data<T>(),
|
||||
0,
|
||||
Y->template mutable_data<T>(),
|
||||
&context_);
|
||||
// Add bias term
|
||||
if (bias_multiplier_.numel() != M) {
|
||||
// If the helper bias multiplier is not M,
|
||||
// reshape and fill it with one.
|
||||
bias_multiplier_.Resize(M);
|
||||
math::Set<T, Context>(
|
||||
M,
|
||||
static_cast<T>(1),
|
||||
bias_multiplier_.template mutable_data<T>(),
|
||||
&context_);
|
||||
}
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasNoTrans,
|
||||
CblasNoTrans,
|
||||
M,
|
||||
N,
|
||||
1,
|
||||
1,
|
||||
bias_multiplier_.template data<T>(),
|
||||
b.template data<T>(),
|
||||
1,
|
||||
Y->template mutable_data<T>(),
|
||||
&context_);
|
||||
if (OutputSize() == 2) {
|
||||
auto* Comp_rate = Output(1, vector<int64_t>(), at::dtype<T>());
|
||||
T* comp_data = Comp_rate->template mutable_data<T>();
|
||||
math::Sum<T, Context>(
|
||||
Mask.numel(), Mask.template data<T>(), comp_data, &context_);
|
||||
math::Scale<float, T, Context>(
|
||||
1,
|
||||
static_cast<T>(1.) / Mask.numel(),
|
||||
comp_data,
|
||||
comp_data,
|
||||
&context_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
|
||||
template <typename T, class Context, class Engine=DefaultEngine>
|
||||
class FullyConnectedOpPrune final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
FullyConnectedOpPrune(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
~FullyConnectedOpPrune() {}
|
||||
protected:
|
||||
Tensor bias_multiplier_{Context::GetDeviceType()};
|
||||
};
|
||||
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
const auto& W = Input(1);
|
||||
const auto& Mask = Input(2);
|
||||
const auto& b = Input(3);
|
||||
template <typename T, class Context, class Engine = DefaultEngine>
|
||||
class FullyConnectedPruneGradientOp : public Operator<Context> {
|
||||
public:
|
||||
int iter_offset;
|
||||
|
||||
CAFFE_ENFORCE_GE(X.dim(), 1);
|
||||
CAFFE_ENFORCE_GE(W.dim(), 2);
|
||||
if (X.dim() > 2 || W.dim() > 2) {
|
||||
VLOG(1) << "Using legacy support for arbitrary input and weight "
|
||||
"dimensions.";
|
||||
}
|
||||
CAFFE_ENFORCE_EQ(b.dim(), 1);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
int K = X.numel() / M;
|
||||
// number of outputs.
|
||||
int N = W.dim32(0);
|
||||
CAFFE_ENFORCE_EQ(K, W.numel() / W.dim32(0));
|
||||
CAFFE_ENFORCE_EQ(N, b.dim32(0));
|
||||
std::vector<int64_t> dims;
|
||||
if (X.dim() > 1) {
|
||||
dims = {M, N};
|
||||
} else {
|
||||
dims = {N};
|
||||
}
|
||||
auto* Y = Output(0, dims, at::dtype<T>());
|
||||
// W * x
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
|
||||
W.template data<T>(), 0, Y->template mutable_data<T>(),
|
||||
&context_);
|
||||
// Add bias term
|
||||
if (bias_multiplier_.numel() != M) {
|
||||
// If the helper bias multiplier is not M,
|
||||
// reshape and fill it with one.
|
||||
bias_multiplier_.Resize(M);
|
||||
math::Set<T, Context>(
|
||||
M, static_cast<T>(1),
|
||||
bias_multiplier_.template mutable_data<T>(),
|
||||
&context_);
|
||||
}
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasNoTrans, CblasNoTrans, M, N, 1, 1,
|
||||
bias_multiplier_.template data<T>(), b.template data<T>(), 1,
|
||||
Y->template mutable_data<T>(), &context_);
|
||||
if (OutputSize() == 2){
|
||||
auto* Comp_rate = Output(1, vector<int64_t>(), at::dtype<T>());
|
||||
T* comp_data = Comp_rate->template mutable_data<T>();
|
||||
math::Sum<T, Context>(
|
||||
Mask.numel(), Mask.template data<T>(), comp_data, &context_);
|
||||
math::Scale<float, T, Context>(
|
||||
1,
|
||||
static_cast<T>(1.) / Mask.numel(),
|
||||
comp_data,
|
||||
comp_data,
|
||||
&context_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
FullyConnectedPruneGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
iter_offset = 0;
|
||||
}
|
||||
~FullyConnectedPruneGradientOp() {}
|
||||
|
||||
protected:
|
||||
Tensor bias_multiplier_{Context::GetDeviceType()};
|
||||
};
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
// const auto& W = Input(1);
|
||||
auto* W_ptr = Output(2);
|
||||
auto& W = *W_ptr;
|
||||
// const auto& Mask = Input(2);
|
||||
auto* Mask_ptr = Output(3);
|
||||
auto& Mask = *Mask_ptr;
|
||||
const auto& dY = Input(3);
|
||||
// const auto& Ag_dW = Input(4);
|
||||
auto* Ag_dW_ptr = Output(4);
|
||||
auto& Ag_dW = *Ag_dW_ptr;
|
||||
// it is also the Input(5)
|
||||
|
||||
template <typename T, class Context, class Engine=DefaultEngine>
|
||||
class FullyConnectedPruneGradientOp : public Operator<Context> {
|
||||
public:
|
||||
int iter_offset;
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
FullyConnectedPruneGradientOp
|
||||
(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) { iter_offset = 0; }
|
||||
~FullyConnectedPruneGradientOp() {}
|
||||
// how about get threshold
|
||||
auto& thres = Input(6);
|
||||
// TODO(wyiming): check comp_lb is a float
|
||||
auto& comp_lb = Input(7);
|
||||
DCHECK_GE(X.dim(), 1);
|
||||
DCHECK_GE(W.dim(), 2);
|
||||
DCHECK_LE(dY.dim(), 2);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
int K = X.numel() / M;
|
||||
// number of outputs.
|
||||
int N = W.dim32(0);
|
||||
// TODO(wyiming): add this window_size to workspace?
|
||||
int window_size = 100;
|
||||
// TODO(wyiming): this threshold should be
|
||||
// based on distribution of the layer weight
|
||||
float thr = 0.01;
|
||||
DCHECK_EQ(Mask.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Mask.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(K, W.numel() / W.dim32(0));
|
||||
if (dY.dim() > 1) {
|
||||
DCHECK_EQ(M, dY.dim32(0));
|
||||
DCHECK_EQ(N, dY.dim32(1));
|
||||
} else {
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
//const auto& W = Input(1);
|
||||
auto* W_ptr = Output(2);
|
||||
auto& W = *W_ptr;
|
||||
//const auto& Mask = Input(2);
|
||||
auto* Mask_ptr = Output(3);
|
||||
auto& Mask = *Mask_ptr;
|
||||
const auto& dY = Input(3);
|
||||
//const auto& Ag_dW = Input(4);
|
||||
auto* Ag_dW_ptr = Output(4);
|
||||
auto& Ag_dW = *Ag_dW_ptr;
|
||||
// it is also the Input(5)
|
||||
auto* dW = Output(0, W.sizes(), at::dtype<T>());
|
||||
auto* db = Output(1, {N}, at::dtype<T>());
|
||||
|
||||
// how about get threshold
|
||||
auto& thres = Input(6);
|
||||
//TODO(wyiming): check comp_lb is a float
|
||||
auto& comp_lb = Input(7);
|
||||
DCHECK_GE(X.dim(), 1);
|
||||
DCHECK_GE(W.dim(), 2);
|
||||
DCHECK_LE(dY.dim(), 2);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
int K = X.numel() / M;
|
||||
// number of outputs.
|
||||
int N = W.dim32(0);
|
||||
// TODO(wyiming): add this window_size to workspace?
|
||||
int window_size = 100;
|
||||
// TODO(wyiming): this threshold should be
|
||||
// based on distribution of the layer weight
|
||||
float thr = 0.01;
|
||||
DCHECK_EQ(Mask.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Mask.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(K, W.numel() / W.dim32(0));
|
||||
if (dY.dim() > 1) {
|
||||
DCHECK_EQ(M, dY.dim32(0));
|
||||
DCHECK_EQ(N, dY.dim32(1));
|
||||
} else {
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
// Compute dW
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasTrans,
|
||||
CblasNoTrans,
|
||||
N,
|
||||
K,
|
||||
M,
|
||||
1,
|
||||
dY.template data<T>(),
|
||||
X.template data<T>(),
|
||||
0,
|
||||
dW->template mutable_data<T>(),
|
||||
&context_);
|
||||
|
||||
auto* dW = Output(0, W.sizes(), at::dtype<T>());
|
||||
auto* db = Output(1, {N}, at::dtype<T>());
|
||||
comp_r_buf_.Resize(vector<int64_t>());
|
||||
T* comp_data = comp_r_buf_.template mutable_data<T>();
|
||||
math::Sum<T, Context>(
|
||||
Mask.numel(), Mask.template data<T>(), comp_data, &context_);
|
||||
math::Scale<float, T, Context>(
|
||||
1, static_cast<T>(1.) / Mask.numel(), comp_data, comp_data, &context_);
|
||||
// update W size window
|
||||
// Notice here we need to maintain state in OP.
|
||||
// This is new in Caffe2.
|
||||
// And this is something we might need to discuss in the future.
|
||||
// at most mask half of the matrix at time
|
||||
// 1. mask dw with previous mask
|
||||
MaskMatrix<T, Context>(
|
||||
Mask.template mutable_data<T>(), dW->template mutable_data<T>(), N, K);
|
||||
if (*comp_data > *(comp_lb.template data<T>())) {
|
||||
iter_offset++;
|
||||
if (iter_offset % window_size == 0) {
|
||||
// TODO(wyiming):do the prune here;
|
||||
sum_buffer_.ResizeLike(W);
|
||||
math::Add<T, Context>(
|
||||
W.numel(),
|
||||
W.template mutable_data<T>(),
|
||||
Ag_dW.template mutable_data<T>(),
|
||||
sum_buffer_.template mutable_data<T>(),
|
||||
&context_);
|
||||
auto* mask_seq_auto = Output(5, W.sizes(), at::dtype<T>());
|
||||
T* mask_seq = mask_seq_auto->template mutable_data<T>();
|
||||
math::Set<T, Context>(
|
||||
N * K,
|
||||
static_cast<T>(0),
|
||||
mask_seq_auto->template mutable_data<T>(),
|
||||
&context_);
|
||||
// 2. find dw below thres but not eq 0
|
||||
int seq_len = MatrixCompare_LT<T>(
|
||||
Ag_dW_ptr->template mutable_data<T>(),
|
||||
*thres.template data<T>(),
|
||||
mask_seq,
|
||||
N,
|
||||
K);
|
// 3. use the mask_seq to update W and dw
MaskMatrix_Inc<T, Context>(
mask_seq, dW->template mutable_data<T>(), N, K, seq_len, 0);
MaskMatrix_Inc<T, Context>(
mask_seq, W.template mutable_data<T>(), N, K, seq_len, 0);
MaskMatrix_Inc<T, Context>(
mask_seq, Mask.template mutable_data<T>(), N, K, seq_len, 0);
math::Set<T, Context>(
N * K,
static_cast<T>(0),
Ag_dW.template mutable_data<T>(),
&context_);
} else {
// add dW to Aggregate dW.
AggrDW<T, Context>(
Ag_dW.template mutable_data<T>(),
dW->template mutable_data<T>(),
N,
K,
&context_);
}
}
if (bias_multiplier_.numel() != M) {
// If the helper bias multiplier is not M,
// reshape and fill it with one.
bias_multiplier_.Resize(M);
math::Set<T, Context>(
M,
static_cast<T>(1),
bias_multiplier_.template mutable_data<T>(),
&context_);
}
// Compute dB
math::Gemv<T, Context>(
CblasTrans,
M,
N,
1,
dY.template data<T>(),
bias_multiplier_.template data<T>(),
0,
db->template mutable_data<T>(),
&context_);
// Compute dX if necessary.
if (OutputSize() == 7) {
auto* dX = Output(6, X.sizes(), at::dtype<T>());
math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
K,
N,
1,
dY.template data<T>(),
W.template data<T>(),
0,
dX->template mutable_data<T>(),
&context_);
}

// Compute dW
math::Gemm<T, Context, Engine>(
CblasTrans, CblasNoTrans, N, K, M, 1,
dY.template data<T>(), X.template data<T>(),
0, dW->template mutable_data<T>(),
&context_);
return true;
}

comp_r_buf_.Resize(vector<int64_t>());
T* comp_data = comp_r_buf_.template mutable_data<T>();
math::Sum<T, Context>(
Mask.numel(), Mask.template data<T>(), comp_data, &context_);
math::Scale<float, T, Context>(
1,
static_cast<T>(1.) / Mask.numel(),
comp_data,
comp_data,
&context_);
// update W size window
// Notice here we need to maintain state in OP.
// This is new in Caffe2.
// And this is something we might need to discuss in the future.
// at most mask half of the matrix at time
// 1. mask dw with previous mask
MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
dW->template mutable_data<T>(), N, K);
if(*comp_data > *(comp_lb.template data<T>())){
iter_offset++;
if (iter_offset % window_size == 0) {
// TODO(wyiming):do the prune here;
sum_buffer_.ResizeLike(W);
math::Add<T, Context>(
W.numel(),
W.template mutable_data<T>(),
Ag_dW.template mutable_data<T>(),
sum_buffer_.template mutable_data<T>(),
&context_);
auto* mask_seq_auto = Output(5, W.sizes(), at::dtype<T>());
T* mask_seq = mask_seq_auto->template mutable_data<T>();
math::Set<T, Context>(N*K, static_cast<T>(0),
mask_seq_auto->template mutable_data<T>(), &context_);
// 2. find dw below thres but not eq 0
int seq_len = MatrixCompare_LT<T>(
Ag_dW_ptr->template mutable_data<T>(),
*thres.template data<T>(), mask_seq, N, K);
// 3. use the mask_seq to update W and dw
MaskMatrix_Inc<T, Context>(mask_seq,
dW->template mutable_data<T>(),
N, K, seq_len, 0);
MaskMatrix_Inc<T, Context>(mask_seq,
W.template mutable_data<T>(),
N, K, seq_len, 0);
MaskMatrix_Inc<T, Context>(mask_seq,
Mask.template mutable_data<T>(),
N, K, seq_len, 0);
math::Set<T, Context>(N*K, static_cast<T>(0),
Ag_dW.template mutable_data<T>(),
&context_);
} else {
// add dW to Aggregate dW.
AggrDW<T, Context>(
Ag_dW.template mutable_data<T>(),
dW->template mutable_data<T>(),
N, K, &context_);
}
}
if (bias_multiplier_.numel() != M) {
// If the helper bias multiplier is not M,
// reshape and fill it with one.
bias_multiplier_.Resize(M);
math::Set<T, Context>(
M, static_cast<T>(1),
bias_multiplier_.template mutable_data<T>(),
&context_);
}
// Compute dB
math::Gemv<T, Context>(
CblasTrans, M, N, 1, dY.template data<T>(),
bias_multiplier_.template data<T>(), 0,
db->template mutable_data<T>(),
&context_);
// Compute dX if necessary.
if (OutputSize() == 7) {
auto* dX = Output(6, X.sizes(), at::dtype<T>());
math::Gemm<T, Context, Engine>(
CblasNoTrans, CblasNoTrans, M, K, N, 1,
dY.template data<T>(), W.template data<T>(),
0, dX->template mutable_data<T>(),
&context_);
}

return true;
}

protected:
Tensor bias_multiplier_{Context::GetDeviceType()};
Tensor sum_buffer_{Context::GetDeviceType()};
Tensor comp_r_buf_{Context::GetDeviceType()};
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
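For orientation before the next file: the pruning path above works on a schedule. Each backward pass masks dW with the current Mask and accumulates it into Ag_dW; while the fraction of surviving weights (the mean of Mask) is still above comp_lb, every window_size-th iteration finds the aggregated gradients below thres, zeroes the matching entries of W, dW, and Mask, and resets Ag_dW. A minimal standalone sketch of that idea, with plain arrays and hypothetical names in place of the operator's tensors and MaskMatrix/AggrDW helpers:

#include <cmath>
#include <cstddef>

// Accumulate this iteration's gradient into the running aggregate
// (the AggrDW step in the else-branch above).
void aggregate_gradient(float* agg_dw, const float* dw, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    agg_dw[i] += dw[i];
  }
}

// Every window_size iterations: zero out weights whose aggregated gradient
// stayed below the threshold, record them in the mask so they remain pruned,
// and reset the aggregate (steps 1-3 in the comments above).
void prune_below_threshold(
    float* w,
    float* dw,
    float* agg_dw,
    float* mask,
    std::size_t n,
    float thres) {
  for (std::size_t i = 0; i < n; ++i) {
    if (mask[i] != 0.f && std::fabs(agg_dw[i]) < thres) {
      mask[i] = 0.f;
      w[i] = 0.f;
      dw[i] = 0.f;
    }
    agg_dw[i] = 0.f; // start a new aggregation window
  }
}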
@@ -22,5 +22,5 @@ namespace {
REGISTER_CPU_OPERATOR(FC_Sparse, FullyConnectedOp_SPARSE<float, CPUContext>);

OPERATOR_SCHEMA(FC_Sparse).NumInputs(5).NumOutputs(1);
} // namespace
} // namespace caffe2

@@ -22,16 +22,16 @@
#include "caffe2/utils/math.h"
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL

namespace caffe2 {

namespace {

template<int N>
template <int N>
using Shape = std::array<int, N>;

template<int N>
template <int N>
const std::vector<int64_t>& shape(Shape<N> vs) {
static thread_local std::vector<int64_t> cache;
cache.resize(vs.size());

@@ -50,10 +50,18 @@ inline const std::vector<int64_t>& shape(int i, int j) {
}

template <typename T, class Context>
void Sparse_mm(const T* acsr, const int* ia, const int* ja,
int m, int k, int n, const T* b, T* c, Context* context);
void Sparse_mm(
const T* acsr,
const int* ia,
const int* ja,
int m,
int k,
int n,
const T* b,
T* c,
Context* context);

template<typename T, class Context>
template <typename T, class Context>
void trans_mat(const T* o, T* t, int m, int n, Context* context);

template <>

@@ -63,9 +71,9 @@ void trans_mat<float, CPUContext>(
int m,
int n,
CPUContext* /*context*/) {
for(int i = 0; i < m; ++i){
for(int j = 0; j < n; ++j){
t[j*m+i]=o[i*n+j];
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
t[j * m + i] = o[i * n + j];
}
}
}

@@ -83,22 +91,35 @@ void Sparse_mm<float, CPUContext>(
const float* b,
float* c,
CPUContext* /*context*/) {

#ifdef CAFFE2_USE_MKL

float alpha = 1.0, beta = 0.;
mkl_scsrmm("N", &m, &n, &k, &alpha, "GLNC",
acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
mkl_scsrmm(
"N",
&m,
&n,
&k,
&alpha,
"GLNC",
acsr,
ja,
ia,
ia + 1,
b,
&n,
&beta,
c,
&n);

#else
throw std::runtime_error("Not compiled with MKL");
#endif
}

}
} // namespace
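The MKL call above computes c = acsr * b for an m-by-k CSR matrix against a k-by-n dense matrix; the "GLNC" descriptor requests a general matrix with zero-based ia/ja indexing. For readers without MKL, a plain reference loop with the same convention (a sketch, not the code path the operator takes):

// c[m x n] = a[m x k] * b[k x n], with a in zero-based CSR form:
// ia[i]..ia[i+1] delimit row i's entries, ja holds their column indices.
void sparse_mm_ref(
    const float* acsr,
    const int* ia,
    const int* ja,
    int m,
    int n,
    const float* b,
    float* c) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      c[i * n + j] = 0.f;
    }
    for (int p = ia[i]; p < ia[i + 1]; ++p) {
      const float v = acsr[p];
      const int col = ja[p]; // column index in [0, k)
      for (int j = 0; j < n; ++j) {
        c[i * n + j] += v * b[col * n + j];
      }
    }
  }
}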

// This is Caffe's InnerProductOp, with a name that fits its purpose better.
template <typename T, class Context, class Engine=DefaultEngine>
template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedOp_SPARSE final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;

@@ -122,27 +143,43 @@ class FullyConnectedOp_SPARSE final : public Operator<Context> {
// Feature dimension
int M = Xt.numel() / K;
// number of outputs.
int N = iw.dim32(0)-1;
int N = iw.dim32(0) - 1;
CAFFE_ENFORCE_EQ(N, b.dim32(0));
auto* Yt = Output(0, shape(N, M), at::dtype<T>());

// Y' = W * X';
Sparse_mm<T, Context>(
Wcsr.template data<T>(), iw.template data<int>(),
jw.template data<int>(), N, K, M, Xt.template data<T>(),
Yt->template mutable_data<T>(), &context_);
Wcsr.template data<T>(),
iw.template data<int>(),
jw.template data<int>(),
N,
K,
M,
Xt.template data<T>(),
Yt->template mutable_data<T>(),
&context_);
// Add bias term
if (bias_multiplier_.numel() != M) {
// If the helper bias multiplier is not M, reshape and fill it with one.
bias_multiplier_.Resize(shape(M));
math::Set<T, Context>(
M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
M,
static_cast<T>(1),
bias_multiplier_.template mutable_data<T>(),
&context_);
}
math::Gemm<T, Context, Engine>(
CblasNoTrans, CblasNoTrans, N, M, 1, 1,
b.template data<T>(), bias_multiplier_.template data<T>(), 1,
Yt->template mutable_data<T>(), &context_);
CblasNoTrans,
CblasNoTrans,
N,
M,
1,
1,
b.template data<T>(),
bias_multiplier_.template data<T>(),
1,
Yt->template mutable_data<T>(),
&context_);
return true;
}


@@ -150,7 +187,6 @@ class FullyConnectedOp_SPARSE final : public Operator<Context> {
Tensor bias_multiplier_{Context::GetDeviceType()};
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
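A note on the two BLAS tricks this operator family leans on. The bias Gemm above with shapes (N, M, 1) and beta = 1 is a rank-one update, Yt += b * ones^T, using the all-ones bias_multiplier as the second factor; the dB Gemv in the gradient code earlier is the transpose of the same trick, db = dY^T * ones, i.e. a column-sum of dY. Both identities spelled out as plain loops (assumed shapes: Yt is N x M, b has length N, dY is M x N):

// Yt(N x M) += b(N) * ones(M)^T, what Gemm(NoTrans, NoTrans, N, M, 1, ...) does.
void add_bias(float* yt, const float* b, int n_rows, int m_cols) {
  for (int i = 0; i < n_rows; ++i) {
    for (int j = 0; j < m_cols; ++j) {
      yt[i * m_cols + j] += b[i];
    }
  }
}

// db(N) = dY(M x N)^T * ones(M), what Gemv(CblasTrans, M, N, ...) computes:
// each bias gradient is the column-sum of dY.
void bias_grad(float* db, const float* dy, int m_rows, int n_cols) {
  for (int j = 0; j < n_cols; ++j) {
    db[j] = 0.f;
    for (int i = 0; i < m_rows; ++i) {
      db[j] += dy[i * n_cols + j];
    }
  }
}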
@@ -46,33 +46,39 @@ randomly mapped from the weight vector, provided the input
)DOC")
.Input(0, "scalars", "Values of the non-zero entries of the sparse data.")
.Input(1, "indices", "Indices to the non-zero valued features.")
.Input(2, "segment_ids",
.Input(
2,
"segment_ids",
"Segment IDs corresponding to the non-zero entries.")
.Input(3, "weight", "Weight vector")
.Input(4, "alpha",
.Input(
4,
"alpha",
"Optional coefficients for linear combination of hashed weights.")
.Output(0, "output",
.Output(
0,
"output",
"Output tensor with the first dimension equal to the number "
"of segments.")
.Arg("num_outputs", "Number of outputs")
.Arg("num_segments", "Number of segments");

OPERATOR_SCHEMA(FunHashGradient)
.NumInputs(5, 6)
.NumOutputs(1, 2);
OPERATOR_SCHEMA(FunHashGradient).NumInputs(5, 6).NumOutputs(1, 2);

class GetFunHashGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
if (def_.input_size() == 4) {
return SingleGradientDef(
"FunHashGradient", "",
"FunHashGradient",
"",
vector<string>{GO(0), I(0), I(1), I(2), I(3)},
vector<string>{GI(3)});
}
// def_.input_size() == 5
return SingleGradientDef(
"FunHashGradient", "",
"FunHashGradient",
"",
vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
vector<string>{GI(3), GI(4)});
}

@@ -70,7 +70,10 @@ class TTContractionOp final : public Operator<Context> {
math::Gemm<T, Context, Engine>(
CblasTrans,
CblasNoTrans,
M_, N_, K_, 1,
M_,
N_,
K_,
1,
A_data,
B_data + B_index,
0,

@@ -126,7 +129,10 @@ class TTContractionGradientOp final : public Operator<Context> {
math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasTrans,
K_, M_, N_, 1,
K_,
M_,
N_,
1,
B_data + B_index,
G_ptr,
B_index == 0 ? 0 : 1,

@@ -140,7 +146,10 @@ class TTContractionGradientOp final : public Operator<Context> {
math::Gemm<T, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
K_, N_, M_, 1,
K_,
N_,
M_,
1,
A_data,
G_ptr,
0,

@@ -175,4 +175,4 @@ void MPISetupPeers(
<< MPICommSize(GlobalMPIComm());
}

} // namespace caffe2

@@ -1,9 +1,9 @@
#include "caffe2/core/init.h"
#include <gtest/gtest.h>
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/init.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
#include <gtest/gtest.h>

C10_DEFINE_string(
caffe_test_root,

@@ -47,8 +47,7 @@ const char kBcastNet[] = R"NET(

TEST(MPITest, TestMPIBroadcast) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kBcastNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kBcastNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -108,8 +107,7 @@ const char kReduceNet[] = R"NET(

TEST(MPITest, TestMPIReduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kReduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kReduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -174,8 +172,7 @@ const char kMPIAllgatherNet[] = R"NET(

TEST(MPITest, TestMPIAllgather) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kMPIAllgatherNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kMPIAllgatherNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -237,8 +234,7 @@ const char kMPIAllreduceNet[] = R"NET(

TEST(MPITest, TestMPIAllreduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kMPIAllreduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kMPIAllreduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -299,8 +295,7 @@ const char kInPlaceMPIAllreduceNet[] = R"NET(

TEST(MPITest, TestInPlaceMPIAllreduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kInPlaceMPIAllreduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kInPlaceMPIAllreduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -323,10 +318,9 @@ TEST(MPITest, TestInPlaceMPIAllreduce) {
}
}

} // namespace caffe2


GTEST_API_ int main(int argc, char **argv) {
GTEST_API_ int main(int argc, char** argv) {
int mpi_ret;
MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mpi_ret);
testing::InitGoogleTest(&argc, argv);

@@ -2,23 +2,14 @@

namespace caffe2 {

OPERATOR_SCHEMA(MPICreateCommonWorld)
.NumInputs(0)
.NumOutputs(1);
OPERATOR_SCHEMA(MPICreateCommonWorld).NumInputs(0).NumOutputs(1);
OPERATOR_SCHEMA(MPIBroadcast)
.NumInputs(2)
.NumOutputs(1)
.EnforceInplace({{1, 0}});
OPERATOR_SCHEMA(MPIReduce)
.NumInputs(2)
.NumOutputs(1);
OPERATOR_SCHEMA(MPIAllgather)
.NumInputs(2)
.NumOutputs(1);
OPERATOR_SCHEMA(MPIAllreduce)
.NumInputs(2)
.NumOutputs(1)
.AllowInplace({{1, 0}});
.NumInputs(2)
.NumOutputs(1)
.EnforceInplace({{1, 0}});
OPERATOR_SCHEMA(MPIReduce).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(MPIAllgather).NumInputs(2).NumOutputs(1);
OPERATOR_SCHEMA(MPIAllreduce).NumInputs(2).NumOutputs(1).AllowInplace({{1, 0}});
OPERATOR_SCHEMA(MPISendTensor);
OPERATOR_SCHEMA(MPIReceiveTensor);
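Collapsing the chained schema calls onto one line is purely cosmetic, but the two inplace annotations that survive the reformat are worth distinguishing: AllowInplace({{1, 0}}) permits input 1 and output 0 to share a buffer, while EnforceInplace({{1, 0}}) requires them to, which is why MPIBroadcast, an inherently in-place operation, uses the stricter form. A hypothetical illustration (not an operator from this diff):

// OPERATOR_SCHEMA(MyScale)
//     .NumInputs(1)
//     .NumOutputs(1)
//     .AllowInplace({{0, 0}}); // output 0 may alias input 0
// OPERATOR_SCHEMA(MyInPlaceOnly)
//     .NumInputs(1)
//     .NumOutputs(1)
//     .EnforceInplace({{0, 0}}); // output 0 must alias input 0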
@@ -30,4 +21,4 @@ REGISTER_CPU_OPERATOR(MPIAllreduce, MPIAllreduceOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(MPISendTensor, MPISendTensorOp<CPUContext>);
REGISTER_CPU_OPERATOR(MPIReceiveTensor, MPIReceiveTensorOp<CPUContext>);

} // namespace caffe2

@@ -1,8 +1,7 @@
#include "caffe2/mpi/mpi_ops.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/mpi/mpi_ops.h"
#include "caffe2/operators/operator_fallback_gpu.h"


namespace caffe2 {

// Here is a bunch of MPI macro definitions that allow us to see if the MPI

@@ -61,26 +60,16 @@ REGISTER_CUDA_OPERATOR(MPISendTensor, MPISendTensorOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(MPIReceiveTensor, MPIReceiveTensorOp<CUDAContext>);
#else
REGISTER_CUDA_OPERATOR(MPIBroadcast, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIReduce,
GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIAllgather,
GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPISendTensor,
GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
MPIReceiveTensor,
GPUFallbackOpEx<SkipIndices<1, 2>>);
REGISTER_CUDA_OPERATOR(MPIReduce, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(MPIAllgather, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(MPISendTensor, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(MPIReceiveTensor, GPUFallbackOpEx<SkipIndices<1, 2>>);
#endif

#if CAFFE2_HAS_CUDA_MPI_ALLREDUCE
REGISTER_CUDA_OPERATOR(MPIAllreduce, MPIAllreduceOp<float, CUDAContext>);
#else
REGISTER_CUDA_OPERATOR(
MPIAllreduce,
GPUFallbackOp);
REGISTER_CUDA_OPERATOR(MPIAllreduce, GPUFallbackOp);
#endif

} // namespace caffe2

@@ -1,8 +1,8 @@
#include <gtest/gtest.h>
#include "caffe2/core/init.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
#include <gtest/gtest.h>

C10_DEFINE_string(
caffe_test_root,

@@ -43,8 +43,7 @@ const char kBcastNet[] = R"NET(

TEST(MPITest, TestMPIBroadcast) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kBcastNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kBcastNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -101,8 +100,7 @@ const char kReduceNet[] = R"NET(

TEST(MPITest, TestMPIReduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kReduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kReduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -163,8 +161,7 @@ const char kMPIAllgatherNet[] = R"NET(

TEST(MPITest, TestMPIAllgather) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kMPIAllgatherNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kMPIAllgatherNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -221,8 +218,7 @@ const char kMPIAllreduceNet[] = R"NET(

TEST(MPITest, TestMPIAllreduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kMPIAllreduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kMPIAllreduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -278,8 +274,7 @@ const char kInPlaceMPIAllreduceNet[] = R"NET(

TEST(MPITest, TestInPlaceMPIAllreduce) {
NetDef net_def;
CHECK(TextFormat::ParseFromString(
string(kInPlaceMPIAllreduceNet), &net_def));
CHECK(TextFormat::ParseFromString(string(kInPlaceMPIAllreduceNet), &net_def));
// Let's set the network's constant fill value to be the mpi rank.
auto* arg = net_def.mutable_op(1)->mutable_arg(1);
CAFFE_ENFORCE_EQ(arg->name(), "value");

@@ -301,10 +296,9 @@ TEST(MPITest, TestInPlaceMPIAllreduce) {
}
}

} // namespace caffe2


GTEST_API_ int main(int argc, char **argv) {
GTEST_API_ int main(int argc, char** argv) {
int mpi_ret;
MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &mpi_ret);
testing::InitGoogleTest(&argc, argv);
@@ -95,8 +95,8 @@ class TORCH_API ProfileOperatorObserver final
};

class TORCH_API ProfileObserver final : public OperatorAttachingNetObserver<
ProfileOperatorObserver,
ProfileObserver> {
public:
explicit ProfileObserver(NetBase* subject)
: OperatorAttachingNetObserver<ProfileOperatorObserver, ProfileObserver>(

@@ -27,10 +27,9 @@ class TORCH_API RunCountOperatorObserver final
RunCountNetObserver* netObserver_;
};

class TORCH_API RunCountNetObserver final
: public OperatorAttachingNetObserver<
RunCountOperatorObserver,
RunCountNetObserver> {
class TORCH_API RunCountNetObserver final : public OperatorAttachingNetObserver<
RunCountOperatorObserver,
RunCountNetObserver> {
public:
explicit RunCountNetObserver(NetBase* subject_)
: OperatorAttachingNetObserver<

@@ -28,9 +28,8 @@ class TORCH_API TimeCounter {
int iterations_ = 0;
};

class TORCH_API TimeOperatorObserver final
: public TimeCounter,
public ObserverBase<OperatorBase> {
class TORCH_API TimeOperatorObserver final : public TimeCounter,
public ObserverBase<OperatorBase> {
public:
explicit TimeOperatorObserver(OperatorBase* subject) = delete;
explicit TimeOperatorObserver(
@@ -1,6 +1,6 @@
#include "caffe2/onnx/backend.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/onnx/backend.h"
#include "caffe2/onnx/device.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/utils/map_utils.h"

@@ -81,7 +81,9 @@ U LookUpWithDefault(
}
}

void UpdateNames(std::shared_ptr<DummyName> dummy, const caffe2::OperatorDef& op) {
void UpdateNames(
std::shared_ptr<DummyName> dummy,
const caffe2::OperatorDef& op) {
for (const auto& n : op.input()) {
dummy->AddName(n);
}

@@ -198,8 +200,8 @@ OnnxAttributes::get(const std::string& key) const {
}

template <>
::google::protobuf::RepeatedField<float>
OnnxAttributes::get(const std::string& key) const {
::google::protobuf::RepeatedField<float> OnnxAttributes::get(
const std::string& key) const {
::google::protobuf::RepeatedField<float> value;
const auto it = onnx_attrs_.find(key);
if (it != onnx_attrs_.end()) {

@@ -305,11 +307,12 @@ const std::
Caffe2Backend::get_per_op_renamed_attrs() const {
const static std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>
kPerOpRenamedAttrs = {{"Squeeze", {{"axes", "dims"}}},
{"Unsqueeze", {{"axes", "dims"}}},
{"Transpose", {{"perm", "axes"}}},
{"ConvTranspose", {{"output_padding", "adjs"}}},
{"Selu", {{"gamma", "scale"}}}};
kPerOpRenamedAttrs = {
{"Squeeze", {{"axes", "dims"}}},
{"Unsqueeze", {{"axes", "dims"}}},
{"Transpose", {{"perm", "axes"}}},
{"ConvTranspose", {{"output_padding", "adjs"}}},
{"Selu", {{"gamma", "scale"}}}};

return kPerOpRenamedAttrs;
}

@@ -462,7 +465,8 @@ Caffe2Ops Caffe2Backend::CreateConstantOfShape(
auto* c2_op = ret.ops.Add();
const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
if (value) {
BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
BuildTensorFillingOp(
c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
} else {
c2_op->set_type("ConstantFill");
c2_op->add_input(onnx_node->node.input(0));

@@ -604,7 +608,8 @@ Caffe2Ops Caffe2Backend::CreateNonZeroOp(
auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx);

auto* c2_transpose = ret.ops.Add();
BuildOperator(c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)});
BuildOperator(
c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)});
return ret;
}


@@ -612,7 +617,8 @@ Caffe2Ops Caffe2Backend::CreateMultinomialOp(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
// Fallback to ATen.
// ATen::Multinomial takes probabilities as input, ONNX Multinomial expects input to be log probabilities.
// ATen::Multinomial takes probabilities as input, ONNX Multinomial expects
// input to be log probabilities.
Caffe2Ops ret;
auto c2_exp_output = dummy_->NewDummyName();
auto* c2_exp = ret.ops.Add();

@@ -631,10 +637,10 @@ Caffe2Ops Caffe2Backend::CreateMultinomialOp(
c2_arg_num.set_name("num_samples");
c2_arg_num.set_i(onnx_attributes.get<int64_t>("sample_size"));

// ONNX Multinomial has attribute dtype in {int64, int32}, which specifies output datatype.
// ATen::Multinomial output dtype is always int64.
// ONNX Multinomial has attribute dtype in {int64, int32}, which specifies
// output datatype. ATen::Multinomial output dtype is always int64.
auto onnx_dtype =
onnx_attributes.get<int64_t>("dtype", TensorProto::UNDEFINED);
if (onnx_dtype == ::ONNX_NAMESPACE::TensorProto::INT64) {
BuildOperator(
c2_multinomial,

@@ -655,9 +661,16 @@ Caffe2Ops Caffe2Backend::CreateMultinomialOp(
caffe2::Argument to;
to.set_name("to");
to.set_i(caffe2::TensorProto::INT32);
BuildOperator(c2_cast, "Cast", {c2_multinomial_output}, {onnx_node->node.output(0)}, {to});
BuildOperator(
c2_cast,
"Cast",
{c2_multinomial_output},
{onnx_node->node.output(0)},
{to});
} else {
CAFFE_THROW("ONNX does not support dtype other than int32/int64 in Multinomial, but get ", onnx_dtype);
CAFFE_THROW(
"ONNX does not support dtype other than int32/int64 in Multinomial, but get ",
onnx_dtype);
}

return ret;

@@ -750,9 +763,8 @@ Caffe2Ops Caffe2Backend::CreateGemm(
auto trans_a = onnx_node->attributes.get<int64_t>("transA", 0L);
auto trans_b = onnx_node->attributes.get<int64_t>("transB", 0L);
// Support broadcast by default when opset_version > 6.
auto broadcast =
onnx_node->attributes.get<int64_t>("broadcast",
(ctx.opset_version() > 6) ? 1L : 0L);
auto broadcast = onnx_node->attributes.get<int64_t>(
"broadcast", (ctx.opset_version() > 6) ? 1L : 0L);

// If the c's shape information is available and c is a 1d tensor(except
// c is a scalar), use FC aggressively.

@@ -796,7 +808,8 @@ Caffe2Ops Caffe2Backend::CreateGemm(
if (trans_b) {
BuildOperator(c2_op, "FC", {input_a, input_b, input_c}, {output});
} else {
BuildOperator(c2_op, "FCTransposed", {input_a, input_b, input_c}, {output});
BuildOperator(
c2_op, "FCTransposed", {input_a, input_b, input_c}, {output});
}
} else {
auto ab = dummy_->NewDummyName();

@@ -1057,14 +1070,15 @@ Caffe2Ops Caffe2Backend::CreateSlice(
// the behavior of Caffe2's slice operator not matching that of ONNX's slice
// 2) Fully expand the index tensor out to the rank of the data tensor.
// pseudocode: indices_full = zeros(rank); indices_full[axes] = indices.int()
std::string Caffe2Backend::PreprocessSliceIndexTensor(OnnxNode* onnx_node,
Caffe2Ops& ret,
std::string indices_tensor,
std::string axes_tensor,
std::string rank_tensor,
std::string zero_tensor,
std::string one_tensor,
int default_value) {
std::string Caffe2Backend::PreprocessSliceIndexTensor(
OnnxNode* onnx_node,
Caffe2Ops& ret,
std::string indices_tensor,
std::string axes_tensor,
std::string rank_tensor,
std::string zero_tensor,
std::string one_tensor,
int default_value) {
auto indices_tensor_full = dummy_->NewDummyName();

{

@@ -1078,8 +1092,12 @@ std::string Caffe2Backend::PreprocessSliceIndexTensor(OnnxNode* onnx_node,
input_as_shape.set_name("input_as_shape");
input_as_shape.set_i(1);
auto c2_op = ret.ops.Add();
BuildOperator(c2_op, "ConstantFill", {rank_tensor}, {indices_tensor_full},
{value, dtype, input_as_shape});
BuildOperator(
c2_op,
"ConstantFill",
{rank_tensor},
{indices_tensor_full},
{value, dtype, input_as_shape});
}

// Subtract 1 from each element of the indices tensor that is negative

@@ -1089,7 +1107,8 @@ std::string Caffe2Backend::PreprocessSliceIndexTensor(OnnxNode* onnx_node,
broadcast.set_name("broadcast");
broadcast.set_i(1);
auto c2_op = ret.ops.Add();
BuildOperator(c2_op, "LT", {indices_tensor, zero_tensor}, {lt_tensor}, {broadcast});
BuildOperator(
c2_op, "LT", {indices_tensor, zero_tensor}, {lt_tensor}, {broadcast});
}

auto sub_one_tensor = dummy_->NewDummyName();

@@ -1098,18 +1117,30 @@ std::string Caffe2Backend::PreprocessSliceIndexTensor(OnnxNode* onnx_node,
broadcast.set_name("broadcast");
broadcast.set_i(1);
auto c2_op = ret.ops.Add();
BuildOperator(c2_op, "Sub", {indices_tensor, one_tensor}, {sub_one_tensor}, {broadcast});
BuildOperator(
c2_op,
"Sub",
{indices_tensor, one_tensor},
{sub_one_tensor},
{broadcast});
}

auto indices_tensor_adjusted = dummy_->NewDummyName();
auto c2_op = ret.ops.Add();
BuildOperator(c2_op, "Conditional", {lt_tensor, sub_one_tensor, indices_tensor}, {indices_tensor_adjusted}, {});
BuildOperator(
c2_op,
"Conditional",
{lt_tensor, sub_one_tensor, indices_tensor},
{indices_tensor_adjusted},
{});

// Fill in values specified from the partially-specified ONNX indices tensor
c2_op = ret.ops.Add();
BuildOperator(c2_op, "ScatterAssign",
{indices_tensor_full, axes_tensor, indices_tensor_adjusted},
{indices_tensor_full});
BuildOperator(
c2_op,
"ScatterAssign",
{indices_tensor_full, axes_tensor, indices_tensor_adjusted},
{indices_tensor_full});

return indices_tensor_full;
}
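The pseudocode comment above (indices_full = zeros(rank); indices_full[axes] = indices.int()) is what this chain of ConstantFill, LT, Sub, Conditional, and ScatterAssign ops builds at graph level. The same transformation as a host-side sketch, assuming in-memory vectors in place of graph tensors:

#include <cstdint>
#include <vector>

// Expand a partially specified ONNX slice index list to the full rank of
// the data tensor. Unspecified axes keep default_value; negative indices
// are shifted down by one (the LT/Sub/Conditional trio above) to match
// Caffe2's slice convention.
std::vector<int64_t> expand_slice_indices(
    const std::vector<int64_t>& indices,
    const std::vector<int64_t>& axes,
    int64_t rank,
    int64_t default_value) {
  std::vector<int64_t> full(rank, default_value);
  for (std::size_t i = 0; i < axes.size(); ++i) {
    int64_t idx = indices[i];
    if (idx < 0) {
      idx -= 1; // what the Sub/Conditional pair does for negative entries
    }
    full[axes[i]] = idx; // the ScatterAssign step
  }
  return full;
}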
@@ -1164,31 +1195,32 @@ Caffe2Ops Caffe2Backend::CreateDynamicSlice(
shape.add_ints(1);
auto c2_op = ret.ops.Add();
auto name = dummy_->NewDummyName();
BuildOperator(c2_op, "ConstantFill", {}, {name},
{value, dtype, shape});
BuildOperator(c2_op, "ConstantFill", {}, {name}, {value, dtype, shape});
return name;
};

auto zero_tensor = define_integer_constant(0);
auto one_tensor = define_integer_constant(1);

auto starts_tensor_full = PreprocessSliceIndexTensor(onnx_node,
ret,
onnx_node->node.input(1), // starts
axes_tensor,
rank_tensor,
zero_tensor,
one_tensor,
0);
auto starts_tensor_full = PreprocessSliceIndexTensor(
onnx_node,
ret,
onnx_node->node.input(1), // starts
axes_tensor,
rank_tensor,
zero_tensor,
one_tensor,
0);

auto ends_tensor_full = PreprocessSliceIndexTensor(onnx_node,
ret,
onnx_node->node.input(2), // ends
axes_tensor,
rank_tensor,
zero_tensor,
one_tensor,
-1);
auto ends_tensor_full = PreprocessSliceIndexTensor(
onnx_node,
ret,
onnx_node->node.input(2), // ends
axes_tensor,
rank_tensor,
zero_tensor,
one_tensor,
-1);

// attach the original op at the end
c2_op = ret.ops.Add();

@@ -1219,7 +1251,8 @@ Caffe2Ops Caffe2Backend::CreateBatchNormalization(
attr->set_i(1);
}

if (attributes.HasAttribute("spatial") && attributes.get<int64_t>("spatial") == 1) {
if (attributes.HasAttribute("spatial") &&
attributes.get<int64_t>("spatial") == 1) {
attributes.remove("spatial");
}


@@ -1263,10 +1296,12 @@ Caffe2Ops Caffe2Backend::CreateUpsample(
attributes.remove("mode");

if (ctx.opset_version() >= 7 && ctx.opset_version() < 9) {
const auto& scales = attributes.get<::google::protobuf::RepeatedField<float>>("scales");
const auto& scales =
attributes.get<::google::protobuf::RepeatedField<float>>("scales");
if (scales.size() != 4) {
CAFFE_THROW("The scales argument should have size 4");
} else if (!AlmostEqual(scales.Get(0), 1) || !AlmostEqual(scales.Get(1), 1)) {
} else if (
!AlmostEqual(scales.Get(0), 1) || !AlmostEqual(scales.Get(1), 1)) {
CAFFE_THROW("The first two elements in the scales argument must be 1");
}
attributes.remove("scales");

@@ -1332,14 +1367,14 @@ Caffe2Ops Caffe2Backend::CreateLRN(
auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
const auto& attributes = onnx_node->attributes;
if (!attributes.HasAttribute("alpha")) {
auto* arg = c2_op.ops.Mutable(0)->add_arg();
arg->set_name("alpha");
arg->set_f(1e-4);
}
if (!attributes.HasAttribute("beta")) {
auto* arg = c2_op.ops.Mutable(0)->add_arg();
arg->set_name("beta");
arg->set_f(0.75);
}
return c2_op;
}

@@ -1347,8 +1382,8 @@ Caffe2Ops Caffe2Backend::CreateLRN(
//==============================================
// Rest of the member functions for Caffe2Backend
//==============================================
std::unordered_set<std::string>
Caffe2Backend::AllNamesInGraph(const GraphProto &graph) {
std::unordered_set<std::string> Caffe2Backend::AllNamesInGraph(
const GraphProto& graph) {
std::unordered_set<std::string> names;

for (const auto& input : graph.input()) {

@@ -1406,8 +1441,7 @@ Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops(
c2_op->set_type(
caffe2::get_default(get_renamed_operators(), onnx_op_type, onnx_op_type));
if (!IsOperator(c2_op->type())) {
CAFFE_THROW(
"Don't know how to translate op ", onnx_op_type);
CAFFE_THROW("Don't know how to translate op ", onnx_op_type);
}

auto mapper = [&, this](const std::string& k) {

@@ -1446,7 +1480,7 @@ void Caffe2Backend::CheckOpSchemaArguments(
const caffe2::OpSchema& schema,
const caffe2::OperatorDef& op) {
const auto& schema_args = schema.args();
if (schema_args.size() > 0){
if (schema_args.size() > 0) {
std::vector<std::string> argnames;
std::transform(
schema_args.begin(),

@@ -1460,12 +1494,16 @@ void Caffe2Backend::CheckOpSchemaArguments(
"Don't know how to map unexpected argument ",
arg.name(),
" (from operator ",
op.type(), ")");
op.type(),
")");
}
}
} else {
// A number of C2 operators do not declare proper arguments. Let's log the error
VLOG(2) << "Operator " << op.type() << " does not declare arguments in its schema. Please file a Caffe2 issue.";
// A number of C2 operators do not declare proper arguments. Let's log the
// error
VLOG(2)
<< "Operator " << op.type()
<< " does not declare arguments in its schema. Please file a Caffe2 issue.";
}
}


@@ -1482,12 +1520,14 @@ Caffe2Ops Caffe2Backend::OnnxNodeToCaffe2Ops(
res = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}

for (const auto& result_op: res.ops){
for (const auto& result_op : res.ops) {
const auto* schema = OpSchemaRegistry::Schema(result_op.type());
if (schema) {
CheckOpSchemaArguments(*schema, result_op);
} else {
CAFFE_THROW("Caffe2 has no such operator, could not find schema for ", result_op.type());
CAFFE_THROW(
"Caffe2 has no such operator, could not find schema for ",
result_op.type());
}
}
return res;

@@ -1660,23 +1700,23 @@ Caffe2BackendRep* Caffe2Backend::Prepare(
}

template <typename T>
void ConvertIntegralValueToCaffe2(caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
void ConvertIntegralValueToCaffe2(
caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
c2_op->set_type(
onnx_tensor.data_type() == TensorProto::BOOL ? "GivenTensorBoolFill"
: "GivenTensorIntFill");
::google::protobuf::RepeatedField<T> tmp;
const ::google::protobuf::RepeatedField<T>* src =
&tmp;
const ::google::protobuf::RepeatedField<T>* src = &tmp;
bool converted = TryConvertingTensorRawValues<T>(onnx_tensor, &tmp);
if (converted) {
for (const auto i : *src) {
c2_values->add_ints(i);
}
} else {
const ::google::protobuf::RepeatedField<::google::protobuf::int32> *int32_src = \
&onnx_tensor.int32_data();
const ::google::protobuf::RepeatedField<::google::protobuf::int32>*
int32_src = &onnx_tensor.int32_data();
for (const auto i : *int32_src) {
c2_values->add_ints(i);
}

@@ -1684,9 +1724,10 @@ void ConvertIntegralValueToCaffe2(caffe2::OperatorDef* c2_op,
}

template <>
void ConvertIntegralValueToCaffe2<::google::protobuf::int64>(caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
void ConvertIntegralValueToCaffe2<::google::protobuf::int64>(
caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
c2_op->set_type("GivenTensorInt64Fill");
auto* ints = c2_values->mutable_ints();
if (!TryConvertingTensorRawValues<::google::protobuf::int64>(

@@ -1696,9 +1737,10 @@ void ConvertIntegralValueToCaffe2<::google::protobuf::int64>(caffe2::OperatorDef
}

template <>
void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(
caffe2::OperatorDef* c2_op,
caffe2::Argument* c2_values,
const TensorProto& onnx_tensor) {
c2_op->set_type("GivenTensorInt64Fill");
::google::protobuf::RepeatedField<::google::protobuf::uint64> tmp;
const ::google::protobuf::RepeatedField<::google::protobuf::uint64>* src =

@@ -1747,22 +1789,30 @@ void Caffe2Backend::BuildTensorFillingOp(
c2_values->add_floats(i);
}
} else if (onnx_tensor.data_type() == TensorProto::INT64) {
ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::int64>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::UINT32) {
ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
// NOLINTNEXTLINE(bugprone-branch-clone)
ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(
c2_op, c2_values, onnx_tensor);
// NOLINTNEXTLINE(bugprone-branch-clone)
} else if (onnx_tensor.data_type() == TensorProto::BOOL) {
ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::int8>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::UINT8) {
ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::INT8) {
ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::int8>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::UINT16) {
ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::INT16) {
ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::int16>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::INT32) {
ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
ConvertIntegralValueToCaffe2<::google::protobuf::int32>(
c2_op, c2_values, onnx_tensor);
} else if (onnx_tensor.data_type() == TensorProto::STRING) {
c2_op->set_type("GivenTensorStringFill");
auto* strings = c2_values->mutable_strings();
@@ -1,9 +1,10 @@
#include "caffe2/core/common.h"
#include "caffe2/onnx/backend_rep.h"
#include "caffe2/core/common.h"

#include <iostream>

namespace caffe2 { namespace onnx {
namespace caffe2 {
namespace onnx {

void Caffe2BackendRep::CheckInit() {
if (!predictor_) {

@@ -28,4 +29,5 @@ void Caffe2BackendRep::RunMap(
(*predictor_)(inputs, outputs);
}

}}
} // namespace onnx
} // namespace caffe2

@@ -3,15 +3,16 @@
#include <cstdlib>
#include <unordered_map>

namespace caffe2 { namespace onnx {
namespace caffe2 {
namespace onnx {
static const std::unordered_map<std::string, DeviceType> kDeviceMap = {
{"CPU", DeviceType::CPU},
{"CUDA", DeviceType::CUDA}
};
{"CPU", DeviceType::CPU},
{"CUDA", DeviceType::CUDA}};

Device::Device(const std::string &spec) {
Device::Device(const std::string& spec) {
auto pos = spec.find_first_of(':');
type = kDeviceMap.at(spec.substr(0, pos - 1));
device_id = atoi(spec.substr(pos + 1).c_str());
}
}}
} // namespace onnx
} // namespace caffe2

@@ -3,7 +3,8 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 { namespace onnx {
namespace caffe2 {
namespace onnx {

std::string DummyName::NewDummyName() {
while (true) {

@@ -16,7 +17,7 @@ std::string DummyName::NewDummyName() {
}
}

void DummyName::Reset(const std::unordered_set<std::string> &used_names) {
void DummyName::Reset(const std::unordered_set<std::string>& used_names) {
used_names_ = used_names;
counter_ = 0;
}

@@ -44,15 +45,16 @@ NodeProto MakeNode(
node.set_name(name);
}
node.set_op_type(type);
for (const auto& input: inputs) {
for (const auto& input : inputs) {
node.add_input(input);
}
for (const auto& output: outputs) {
for (const auto& output : outputs) {
node.add_output(output);
}
for (const auto& attr: attributes) {
for (const auto& attr : attributes) {
node.add_attribute()->CopyFrom(attr);
}
return node;
}
}}
} // namespace onnx
} // namespace caffe2
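MakeNode above is a small convenience for assembling an ONNX NodeProto: it sets the optional node name, the op type, and copies the inputs, outputs, and attributes in order. A hypothetical call, with the argument order inferred from the body shown rather than taken from the header, so treat it as a sketch:

// NodeProto relu = MakeNode("Relu", {"X"}, {"Y"}, {}, "my_relu");
// relu.op_type() == "Relu", relu.input(0) == "X", relu.output(0) == "Y"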
@@ -469,11 +469,12 @@ const std::
OnnxExporter::get_per_op_renamed_attrs() const {
const static std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>
kPerOpRenamedAttrs = {{"Squeeze", {{"dims", "axes"}}},
{"Unsqueeze", {{"dims", "axes"}}},
{"Transpose", {{"axes", "perm"}}},
{"ConvTranspose", {{"adjs", "output_padding"}}},
{"Selu", {{"scale", "gamma"}}}};
kPerOpRenamedAttrs = {
{"Squeeze", {{"dims", "axes"}}},
{"Unsqueeze", {{"dims", "axes"}}},
{"Transpose", {{"axes", "perm"}}},
{"ConvTranspose", {{"adjs", "output_padding"}}},
{"Selu", {{"scale", "gamma"}}}};

return kPerOpRenamedAttrs;
}

@@ -556,11 +557,12 @@ bool OnnxExporter::IsBlockListed(const caffe2::Argument& arg) {
const static std::unordered_map<std::string, std::unordered_set<std::string>>
kBlockListString = {{"order", {"NCHW"}}};
const static std::unordered_map<std::string, std::unordered_set<int64_t>>
kBlockListInt = {{"cudnn_exhaustive_search", {0, 1}},
{"use_cudnn", {0, 1}},
{"exhaustive_search", {0, 1}},
{"is_test", {0, 1}},
{"broadcast", {0, 1}}};
kBlockListInt = {
{"cudnn_exhaustive_search", {0, 1}},
{"use_cudnn", {0, 1}},
{"exhaustive_search", {0, 1}},
{"is_test", {0, 1}},
{"broadcast", {0, 1}}};

if (arg.has_i()) {
const auto it = kBlockListInt.find(arg.name());

@@ -133,8 +133,8 @@ void adagrad_update_prefetch(

// Version with prefetching for embeddings and
// momentum using fp16
decltype(
adagrad_fp16_update_prefetch__base) adagrad_fp16_update_prefetch__avx2_fma;
decltype(adagrad_fp16_update_prefetch__base)
adagrad_fp16_update_prefetch__avx2_fma;
void adagrad_fp16_update_prefetch(
int N,
const at::Half* w,

@@ -105,12 +105,12 @@ In foo.cc, do:
#endif // CAFFE2_PERF_WITH_AVX2

#ifdef CAFFE2_PERF_WITH_AVX
#define AVX_DO(funcname, ...) \
{ \
static const bool isDo = cpuinfo_initialize() && cpuinfo_has_x86_avx(); \
if (isDo) { \
return funcname##__avx(__VA_ARGS__); \
} \
}
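These macros implement one level of runtime CPU dispatch: the static bool probes the CPU exactly once via cpuinfo, and when the feature is present the wrapper returns early into the suffixed kernel; otherwise control falls through to the next AVX*_DO level or the __base implementation. A sketch of the intended use for a hypothetical kernel foo:

// void foo__base(int n, const float* x); // always compiled
// void foo__avx(int n, const float* x);  // built with AVX enabled
//
// void foo(int n, const float* x) {
//   AVX_DO(foo, n, x); // returns foo__avx(n, x) when AVX is available
//   foo__base(n, x);   // generic fallback
// }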
#define AVX_F16C_DO(funcname, ...) \
{ \

@@ -1,7 +1,9 @@
#pragma once

// Apple clang was fixed in 8.1
#if defined(__apple_build_version__) && ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1)))
#if defined(__apple_build_version__) && \
((__clang_major__ < 8) || \
((__clang_major__ == 8) && (__clang_minor__ < 1)))
#define CAFFE2_INTERNAL_APPLE_NEED_FIX 1
#endif


@@ -10,7 +12,8 @@
#define CAFFE2_INTERNAL_CLANG_NEED_FIX 1
#endif

#if defined(CAFFE2_INTERNAL_APPLE_NEED_FIX) || defined(CAFFE2_INTERNAL_CLANG_NEED_FIX)
#if defined(CAFFE2_INTERNAL_APPLE_NEED_FIX) || \
defined(CAFFE2_INTERNAL_CLANG_NEED_FIX)

#include <c10/util/Half.h>
#include <emmintrin.h>

@@ -19,8 +22,7 @@
// https://reviews.llvm.org/D16177
static __inline float
__attribute__((__always_inline__, __nodebug__, __target__("f16c")))
_cvtsh_ss(unsigned short a)
{
_cvtsh_ss(unsigned short a) {
__v8hi v = {(short)a, 0, 0, 0, 0, 0, 0, 0};
__v4sf r = __builtin_ia32_vcvtph2ps(v);
return r[0];

@@ -28,7 +30,7 @@ _cvtsh_ss(unsigned short a)

static __inline unsigned short
__attribute__((__always_inline__, __nodebug__, __target__("f16c")))
_cvtss_sh(float a, int imm8) {
unsigned short ret;
*reinterpret_cast<at::Half*>(&ret) = a;
return ret;

@@ -72,6 +72,7 @@ static bool EmbeddingLookupGenericSlow(
return current == index_size;
}

// clang-format off
// Proxy back to generic implementation
#define EMBEDDING_SPECIALIZATION( \
IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \

@@ -204,6 +205,7 @@ static bool EmbeddingLookupGenericSlow(
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on

EMBEDDING_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, float, float, float, false);

@@ -75,6 +75,7 @@ static bool EmbeddingLookupGenericSlowIdx(
return current == index_size;
}

// clang-format off
// Proxy back to generic implementation
#define EMBEDDING_IDX_SPECIALIZATION( \
IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \

@@ -207,6 +208,7 @@ static bool EmbeddingLookupGenericSlowIdx(
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on

EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);

@@ -78,6 +78,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlow(
return current == index_size;
}

// clang-format off
// Proxy back to generic implementation
#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \
bool \

@@ -201,6 +202,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlow(
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on

FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float);
FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float);

@@ -80,6 +80,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
return current == index_size;
}

// clang-format off
// Proxy back to generic implementation
#define FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(IndexType, OutType) \
bool \

@@ -203,6 +204,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on

FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int32_t, float);
FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int64_t, float);

@@ -1,7 +1,7 @@
#pragma once
#include <string.h>
#include <cmath>
#include <cstdint>
#include <string.h>
#include "caffe2/utils/conversions.h"

#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG)

@@ -123,7 +123,7 @@ class ThreadLocalPtrImpl {
template <typename T>
class ThreadLocalPtr {
public:
auto* operator-> () {
auto* operator->() {
return get();
}


@@ -135,7 +135,7 @@ class ThreadLocalPtr {
return impl_.get<T>();
}

auto* operator-> () const {
auto* operator->() const {
return get();
}


@@ -63,7 +63,7 @@ DataRandomFiller::DataRandomFiller(

// load op inputs and outputs
std::unordered_set<std::string> output_names;
for (auto i: c10::irange(run_net.op_size())) {
for (auto i : c10::irange(run_net.op_size())) {
const auto& op = run_net.op(i);
const auto& op_dims = input_dims[i];
const auto& op_types = input_types[i];

@@ -78,7 +78,7 @@ DataRandomFiller::DataRandomFiller(
" inputs; while the input type size is " +
c10::to_string(op_types.size()));

for (auto j: c10::irange(op.input_size())) {
for (auto j : c10::irange(op.input_size())) {
inputs_[op.input(j)] =
std::make_pair(get_tensor_filler(op, j, op_dims), op_types[j]);
}

@@ -99,19 +99,19 @@ DataRandomFiller::DataRandomFiller(
}
}

for (auto j: c10::irange(op.output_size())) {
for (auto j : c10::irange(op.output_size())) {
output_names.emplace(op.output(j));
}
}
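The only change in these loops is the spacing around the colon; c10::irange(n) itself yields the half-open integer range [0, n), so each range-for reads the same as a classic counted loop:

// for (auto i : c10::irange(run_net.op_size())) { ... }
// is equivalent to:
// for (int i = 0; i < run_net.op_size(); ++i) { ... }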
|
||||
|
||||
// load parameters
|
||||
std::unordered_set<std::string> parameters;
|
||||
for (auto i: c10::irange(run_net.arg_size())) {
|
||||
for (auto i : c10::irange(run_net.arg_size())) {
|
||||
const auto& arg = run_net.arg(i);
|
||||
// TODO: replace "PredictorParameters" with the constant in OSS bbp
|
||||
if (arg.has_name() && arg.name() == "PredictorParameters") {
|
||||
parameters.reserve(arg.strings_size());
|
||||
for (auto j: c10::irange(arg.strings_size())) {
|
||||
for (auto j : c10::irange(arg.strings_size())) {
|
||||
parameters.emplace(arg.strings(j));
|
||||
}
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ Predictor::Predictor(
|
|||
|
||||
Predictor::Predictor(PredictorConfig config) : config_(std::move(config)) {
|
||||
const auto& initialized_vec = config_.ws->Blobs();
|
||||
const std::unordered_set<std::string> initialized{initialized_vec.begin(),
|
||||
initialized_vec.end()};
|
||||
const std::unordered_set<std::string> initialized{
|
||||
initialized_vec.begin(), initialized_vec.end()};
|
||||
for (const auto& name : config_.predict_net->external_input()) {
|
||||
if (!initialized.count(name)) {
|
||||
auto* blob = config_.ws->CreateBlob(name);
|
||||
|
|
@ -70,7 +70,7 @@ bool Predictor::operator()(const TensorList& inputs, TensorList* outputs) {
|
|||
return false;
|
||||
}
|
||||
outputs->clear();
|
||||
for (auto i: c10::irange(config_.predict_net->external_output_size())) {
|
||||
for (auto i : c10::irange(config_.predict_net->external_output_size())) {
|
||||
outputs->emplace_back(
|
||||
getTensor(config_.ws.get(), config_.predict_net->external_output(i))
|
||||
.UnsafeSharedInstance());
|
||||
|
|
@ -104,7 +104,7 @@ bool Predictor::operator()(const TensorMap& inputs, TensorList* outputs) {
|
|||
return false;
|
||||
}
|
||||
outputs->clear();
|
||||
for (auto i: c10::irange(config_.predict_net->external_output_size())) {
|
||||
for (auto i : c10::irange(config_.predict_net->external_output_size())) {
|
||||
outputs->push_back(
|
||||
getTensor(config_.ws.get(), config_.predict_net->external_output(i))
|
||||
.UnsafeSharedInstance());
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ struct TORCH_API PredictorConfig {
|
|||
std::shared_ptr<Workspace> ws;
|
||||
};
|
||||
|
||||
TORCH_API Workspace makeWorkspace(std::shared_ptr<PredictorParameters> parameters);
|
||||
TORCH_API Workspace
|
||||
makeWorkspace(std::shared_ptr<PredictorParameters> parameters);
|
||||
|
||||
TORCH_API PredictorConfig makePredictorConfig(
|
||||
const MetaNetDef& net,
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ MetaNetDef parseMetaNetDef(const std::string& value) {
|
|||
value);
|
||||
return def;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class PredictorTest : public testing::Test {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@
|
|||
namespace caffe2 {
|
||||
namespace predictor_utils {
|
||||
|
||||
TORCH_API const NetDef& getNet(
|
||||
const MetaNetDef& def,
|
||||
const std::string& name) {
|
||||
TORCH_API const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
|
||||
for (const auto& n : def.nets()) {
|
||||
if (n.key() == name) {
|
||||
return n.value();
|
||||
|
|
|
|||
|
|
@ -21,26 +21,26 @@ message TensorProto {
    UNDEFINED = 0;

    // Basic types
    FLOAT = 1;  // float
    INT32 = 2;  // int
    BYTE = 3;  // byte, when deserialized, is going to be restored as uint8
    STRING = 4;  // string
    FLOAT = 1; // float
    INT32 = 2; // int
    BYTE = 3; // byte, when deserialized, is going to be restored as uint8
    STRING = 4; // string

    // Less-commonly used data types
    BOOL = 5;  // bool
    UINT8 = 6;  // uint8_t
    INT8 = 7;  // int8_t
    UINT16 = 8;  // uint16_t
    INT16 = 9;  // int16_t
    INT64 = 10;  // int64_t
    FLOAT16 = 12;  // at::Half
    DOUBLE = 13;  // double
    BOOL = 5; // bool
    UINT8 = 6; // uint8_t
    INT8 = 7; // int8_t
    UINT16 = 8; // uint16_t
    INT16 = 9; // int16_t
    INT64 = 10; // int64_t
    FLOAT16 = 12; // at::Half
    DOUBLE = 13; // double

    ZERO_COLLISION_HASH = 14;  // zero-collision hash state
    REBATCHING_BUFFER= 15; // rebatching buffer
    ZERO_COLLISION_HASH = 14; // zero-collision hash state
    REBATCHING_BUFFER = 15; // rebatching buffer
  }
  // The type of the deserialized tensor data
  optional DataType data_type = 2 [default = FLOAT];
  optional DataType data_type = 2 [ default = FLOAT ];

  // The format of the serialized data.
  enum SerializationFormat {

@ -58,10 +58,10 @@ message TensorProto {
  // new messages that have a SerializationFormat value that we don't
  // understand. If we stored this as an enum then protobuf would deserialize
  // both of these cases the same way.
  optional uint32 data_format = 15 [default = 0];
  optional uint32 data_format = 15 [ default = 0 ];

  // For float
  repeated float float_data = 3 [packed = true];
  repeated float float_data = 3 [ packed = true ];
  // For int32, uint8, int8, uint16, int16, bool, and float16
  // Note about float16: in storage we will basically convert float16 byte-wise
  // to unsigned short and then store them in the int32_data field.
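The comments above explain why data_format is a plain uint32: an enum field
would fold values this reader does not recognize back to the default, so a
newer format could not be told apart from no format at all. A minimal sketch
of the round trip, assuming the generated caffe2 proto headers:

    #include <string>

    #include "caffe2/proto/caffe2.pb.h"

    bool UnknownFormatRoundTrips() {
      caffe2::TensorProto proto;
      proto.set_data_format(999); // a format this reader does not understand
      std::string wire;
      proto.SerializeToString(&wire);

      caffe2::TensorProto parsed;
      parsed.ParseFromString(wire);
      // Because the field is a plain uint32, the value survives unchanged;
      // an unknown enum value would have read back as the default instead.
      return parsed.data_format() == 999;
    }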
@ -69,15 +69,15 @@ message TensorProto {
  // larger serialized data than necessary, as protobuf's varint encoding
  // scheme requires 2 bytes to represent int8 and uint8 values that have the
  // MSB set.
  repeated int32 int32_data = 4 [packed = true];
  repeated int32 int32_data = 4 [ packed = true ];
  // For bytes
  optional bytes byte_data = 5;
  // For strings
  repeated bytes string_data = 6;
  // For double
  repeated double double_data = 9 [packed = true];
  repeated double double_data = 9 [ packed = true ];
  // For int64
  repeated int64 int64_data = 10 [packed = true];
  repeated int64 int64_data = 10 [ packed = true ];
  // store the raw data, contents are serialized as little-endian
  optional bytes raw_data = 13;
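Two details from the comments above, in a minimal sketch: float16 values are
stored byte-wise as unsigned shorts inside int32_data, and a varint spends a
second byte on any uint8 value with the MSB set, since each byte carries only
seven payload bits:

    #include <cstdint>
    #include <vector>

    // Pack 16-bit half-precision bit patterns into int32_data entries.
    std::vector<int32_t> PackHalfBits(const std::vector<uint16_t>& half_bits) {
      return std::vector<int32_t>(half_bits.begin(), half_bits.end());
    }

    // Bytes needed to varint-encode one value: 7 payload bits per byte.
    int VarintBytes(uint64_t v) {
      int n = 1;
      while (v >= 0x80) {
        v >>= 7;
        ++n;
      }
      return n; // VarintBytes(0x7f) == 1, VarintBytes(0x80) == 2
    }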
@ -107,9 +107,9 @@ message QTensorProto {
  required double scale = 3;
  required double bias = 4;
  required bool is_signed = 5;
  repeated int32 data = 6 [packed = true];
  repeated int32 data = 6 [ packed = true ];
  optional string name = 7;
  optional TensorProto.DataType data_type = 8 [default = INT32];
  optional TensorProto.DataType data_type = 8 [ default = INT32 ];

  // Multi-group quantization params
  repeated double scales = 9;

@ -120,7 +120,7 @@ message QTensorProto {
  optional int32 axis = 11;

  // It should be true if it is a multi-group quantization proto
  optional bool is_multiparam = 12 [default = false];
  optional bool is_multiparam = 12 [ default = false ];
}

// TensorProtos stores multiple TensorProto objects in one single proto. This

@ -132,9 +132,9 @@ message TensorProtos {

message TensorShape {
  repeated int64 dims = 1;
  optional TensorProto.DataType data_type = 2 [default = FLOAT];
  optional TensorProto.DataType data_type = 2 [ default = FLOAT ];
  repeated int32 unknown_dims = 3;
  optional bool unknown_shape = 4 [default = false];
  optional bool unknown_shape = 4 [ default = false ];
  optional string name = 5;
}

@ -150,8 +150,8 @@ message TensorShapes {
message TensorBoundShape {
  optional TensorShape shape = 1;
  enum DimType {
    UNKNOWN = 0;  // unknown
    CONSTANT = 1;  // constant
    UNKNOWN = 0; // unknown
    CONSTANT = 1; // constant
    // batch, corresponding dimension is batch_size
    BATCH = 2;
    // batch_of_feature_max,
@ -164,9 +164,8 @@ message TensorBoundShape {
    FEATURE_MAX = 5;
    // feature_max_default, corresponding shape is default_feature_length
    FEATURE_MAX_DEFAULT = 6;

  }
  repeated DimType dim_type = 2;  // dim_type.size() == shape.dims.size()
  repeated DimType dim_type = 2; // dim_type.size() == shape.dims.size()
  optional string name = 3;
  // a flag to indicate whether the shape is final and cannot be changed
  // eg: input/output of in-place ops
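The trailing comment above pins down an invariant worth making explicit:
dim_type is a parallel array over shape.dims, one DimType per dimension. With
the standard generated proto accessors it can be checked directly (a sketch,
assuming the caffe2 proto headers):

    #include "caffe2/proto/caffe2.pb.h"

    bool BoundShapeIsConsistent(const caffe2::TensorBoundShape& bound) {
      return bound.dim_type_size() == bound.shape().dims_size();
    }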
@ -211,17 +210,17 @@ message Argument {
  // line in the DeviceTypeName() function in caffe2/utils/proto_utils.cc
  // and update c10/core/DeviceType.h
  enum DeviceTypeProto {
    PROTO_CPU = 0;  // In default, we will use CPU.
    PROTO_CUDA = 1;  // CUDA.
    PROTO_MKLDNN = 2;  // Reserved for explicit MKLDNN
    PROTO_OPENGL = 3;  // OpenGL
    PROTO_OPENCL = 4;  // OpenCL
    PROTO_IDEEP = 5;  // IDEEP.
    PROTO_HIP = 6;  // AMD HIP
    PROTO_FPGA = 7;  // FPGA
    PROTO_ORT = 8;  // ONNX Runtime
    PROTO_XLA = 9;  // XLA / TPU
    PROTO_MLC = 10;  // ML Compute
    PROTO_CPU = 0; // In default, we will use CPU.
    PROTO_CUDA = 1; // CUDA.
    PROTO_MKLDNN = 2; // Reserved for explicit MKLDNN
    PROTO_OPENGL = 3; // OpenGL
    PROTO_OPENCL = 4; // OpenCL
    PROTO_IDEEP = 5; // IDEEP.
    PROTO_HIP = 6; // AMD HIP
    PROTO_FPGA = 7; // FPGA
    PROTO_ORT = 8; // ONNX Runtime
    PROTO_XLA = 9; // XLA / TPU
    PROTO_MLC = 10; // ML Compute
    // Change the following number if you add more devices in the code.
    PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
  }

@ -270,7 +269,6 @@ message OperatorDef {
  // type.
  optional string engine = 7;


  // Additional 'fake' inputs used for expressing control dependencies
  // in the operator graph. This can be used to ensure that an
  // operator does not run until another operator is ready, for e.g.

@ -281,7 +279,7 @@ message OperatorDef {

  // is_gradient_op argument is only used as a hint in shape inference
  // and has no runtime significance
  optional bool is_gradient_op = 9 [default = false];
  optional bool is_gradient_op = 9 [ default = false ];

  // debug information associated with the construction of the operator.
  // This is an optional string with no assumed characteristics as

@ -387,7 +385,6 @@ message NetDef {
  repeated PartitionInfo partition_info = 9;
}


// ExecutionStep is actually a sort-of-hacky way we simulate iteration right
// now.
message ExecutionStep {

@ -415,7 +412,7 @@ message ExecutionStep {
  // Criteria network specifies a single output (TensorCPU<bool>) of
  // size (1), is run on every iteration by the executor, and
  // execution terminates when the output[0] is `false`.
  optional string criteria_network = 5 [deprecated=true];
  optional string criteria_network = 5 [ deprecated = true ];

  // DEPRECATED. Use `run_every_ms`.
  optional string report_net = 7;
@ -440,7 +437,8 @@ message ExecutionStep {
  // 2) the first substep decide which of the rest of the steps should be run.
  // 3) external control
  //
  // ** It is the user's responsibility to not to put this blob in race conditions.
  // ** It is the user's responsibility to not to put this blob in race
  // conditions.
  // ** For example when setting this blob in concurrent substeps
  optional string should_stop_blob = 9;
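The should_stop_blob mechanism above is cooperative cancellation: the runner
re-reads a shared boolean after each iteration, and any substep may set it. A
stripped-down sketch of that control flow (plain C++ with illustrative names,
not the Caffe2 executor itself):

    #include <atomic>

    std::atomic<bool> should_stop{false}; // stands in for the stop blob

    void RunLoop() {
      while (!should_stop.load()) {
        // Run one iteration of the substeps. Any substep may request
        // termination by setting the flag: should_stop.store(true);
      }
    }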
@ -14,11 +14,11 @@ message CaffeDatum {
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
  optional bool encoded = 7 [ default = false ];
}

enum LegacyPadding {
  NOTSET = 0;  // Do not use old-stype padding strategies.
  NOTSET = 0; // Do not use old-stype padding strategies.

  // VALID and SAME are two strategies adopted in Google DistBelief: it forces
  // the input shape as follows. For SAME, the output is:

@ -76,8 +76,7 @@ inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) {
  }
}

inline TORCH_API caffe2::DeviceOption DeviceToOption(
    const at::Device& device) {
inline TORCH_API caffe2::DeviceOption DeviceToOption(const at::Device& device) {
  caffe2::DeviceOption option;
  auto type = device.type();
  option.set_device_type(TypeToProto(type));

@ -31,5 +31,6 @@ message PredictorConsts {
  // Shape info blob name
  optional string SHAPE_INFO_BLOB = 13 [ default = "SHAPE_INFO_BLOB" ];
  // Sequential blob reader name
  optional string DEFERRED_BLOB_READER = 14 [ default = "__DEFERRED_BLOB_READER__" ];
  optional string DEFERRED_BLOB_READER = 14
      [ default = "__DEFERRED_BLOB_READER__" ];
}

@ -22,7 +22,7 @@ message TwoNumberStatsProto {
// a node outputs to the blob.
message BlobProfile {
  // Name of the blob (corresponds to OperatorDef.output).
  optional string name = 1;  // required
  optional string name = 1; // required

  // Profiling statistics.
  optional TwoNumberStatsProto bytes_used = 3;

@ -45,7 +45,6 @@ message ProfDAGProto {

  // The extra_info from the operator device option.
  repeated string extra_info = 7;

}

// Operator profiling information.

@ -87,9 +87,7 @@ message LibDef {
  optional RecordRef torchscript_arena = 1;
}

enum ProtoVersion {
  PROTO_VERSION_NEWEST = 0x0000000000000006;
}
enum ProtoVersion { PROTO_VERSION_NEWEST = 0x0000000000000006; }

message ModelDef {
  // numbers of fields that have been removed. Do not reuse them!

@ -26,8 +26,8 @@
#define DLPACK_DLL
#endif

#include <stdint.h>
#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
@ -160,15 +160,15 @@ typedef struct DLManagedTensor {
  /*! \brief the context of the original host framework of DLManagedTensor in
   *  which DLManagedTensor is used in the framework. It can also be NULL.
   */
  void * manager_ctx;
  void* manager_ctx;
  /*! \brief Destructor signature void (*)(void*) - this should be called
   *  to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
   *  if there is no way for the caller to provide a reasonable destructor.
   *  The destructors deletes the argument self as well.
   */
  void (*deleter)(struct DLManagedTensor * self);
  void (*deleter)(struct DLManagedTensor* self);
} DLManagedTensor;
#ifdef __cplusplus
}  // DLPACK_EXTERN_C
} // DLPACK_EXTERN_C
#endif
#endif  // DLPACK_DLPACK_H_
#endif // DLPACK_DLPACK_H_
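The deleter documented above is the consumer side of the DLPack ownership
contract: whoever borrows the tensor calls the deleter exactly once when done,
and the deleter also frees the DLManagedTensor struct itself. A minimal
sketch, assuming this dlpack.h is on the include path:

    #include "dlpack.h"

    void ReleaseDLManagedTensor(DLManagedTensor* t) {
      if (t != nullptr && t->deleter != nullptr) {
        t->deleter(t); // destructs manager_ctx and deletes t as well
      }
    }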
@ -53,7 +53,7 @@ void addObjectMethods(pybind11::module& m);
Workspace* GetCurrentWorkspace();

// Get workspace by name. Returns nullptr if none exists by name.
Workspace* GetWorkspaceByName(const std::string &name);
Workspace* GetWorkspaceByName(const std::string& name);

class C10_EXPORT BlobFetcherBase {
 public:

@ -350,10 +350,10 @@ class PythonOpBase : public Operator<Context> {
      auto kwargs = builder_call[2].cast<py::dict>();
      auto built_func = func(*args, **kwargs);
      CAFFE_ENFORCE(built_func);
      built_func_.reset(
          new Func{built_func,
                   OperatorBase::template GetSingleArgument<bool>(
                       "pass_workspace", false)});
      built_func_.reset(new Func{
          built_func,
          OperatorBase::template GetSingleArgument<bool>(
              "pass_workspace", false)});
    } catch (const py::error_already_set& e) {
      LOG(ERROR) << "Python exception encountered while creating PythonOp: "
                 << e.what();
@ -65,7 +65,7 @@ class DLPackWrapper {
    managed_tensor.dl_tensor = dlTensor;
    // C2 Tensor memory is managed by C2
    managed_tensor.manager_ctx = nullptr;
    managed_tensor.deleter= [](DLManagedTensor*) {};
    managed_tensor.deleter = [](DLManagedTensor*) {};

    return py::reinterpret_steal<py::object>(
        PyCapsule_New(&managed_tensor, "dltensor", nullptr));
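This is the producer side of the same contract: because Caffe2 keeps owning
the tensor memory here, it exports a no-op deleter, whereas a producer that
hands off ownership would free its context inside the deleter. A sketch of the
borrowing case (hypothetical function name, assumes dlpack.h):

    #include "dlpack.h"

    DLManagedTensor MakeBorrowedView(DLTensor view) {
      DLManagedTensor managed;
      managed.dl_tensor = view;
      managed.manager_ctx = nullptr;              // nothing to hand over
      managed.deleter = [](DLManagedTensor*) {};  // producer keeps ownership
      return managed;
    }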
@ -13,11 +13,10 @@
#ifdef CAFFE2_USE_CUDNN
#include "caffe2/core/common_cudnn.h"
#endif // CAFFE2_USE_CUDNN
#include <c10/cuda/CUDAGuard.h>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/operator_fallback_gpu.h"
#include "caffe2/python/pybind_state_registry.h"
#include <c10/cuda/CUDAGuard.h>


#ifdef CAFFE2_USE_TRT
#include "caffe2/contrib/tensorrt/tensorrt_tranformer.h"

@ -27,9 +26,7 @@ namespace caffe2 {
namespace python {

REGISTER_CUDA_OPERATOR(Python, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(
    PythonGradient,
    GPUFallbackOp);
REGISTER_CUDA_OPERATOR(PythonGradient, GPUFallbackOp);

REGISTER_CUDA_OPERATOR(PythonDLPack, GPUFallbackOp);
REGISTER_CUDA_OPERATOR(PythonDLPackGradient, GPUFallbackOp);

@ -43,11 +40,14 @@ void addCUDAGlobalMethods(py::module& m) {
  m.def("get_cuda_version", &CudaVersion);
#ifdef CAFFE2_USE_CUDNN
  m.def("get_cudnn_version", &cudnnCompiledVersion);
  m.attr("cudnn_convolution_fwd_algo_count") = py::int_((int) CUDNN_CONVOLUTION_FWD_ALGO_COUNT);
  m.attr("cudnn_convolution_bwd_data_algo_count") = py::int_((int) CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT);
  m.attr("cudnn_convolution_bwd_filter_algo_count") = py::int_((int) CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT);
  m.attr("cudnn_convolution_fwd_algo_count") =
      py::int_((int)CUDNN_CONVOLUTION_FWD_ALGO_COUNT);
  m.attr("cudnn_convolution_bwd_data_algo_count") =
      py::int_((int)CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT);
  m.attr("cudnn_convolution_bwd_filter_algo_count") =
      py::int_((int)CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT);
#else
  m.def("get_cudnn_version", [](){ return static_cast<size_t>(0);});
  m.def("get_cudnn_version", []() { return static_cast<size_t>(0); });
  m.attr("cudnn_convolution_fwd_algo_count") = py::int_(0);
  m.attr("cudnn_convolution_bwd_data_algo_count") = py::int_(0);
  m.attr("cudnn_convolution_bwd_filter_algo_count") = py::int_(0);

@ -5,19 +5,17 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <c10/hip/HIPGuard.h>
#include "caffe2/core/hip/common_miopen.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/operators/hip/operator_fallback_gpu.h"
#include "caffe2/python/pybind_state_registry.h"
#include <c10/hip/HIPGuard.h>

namespace caffe2 {
namespace python {

REGISTER_HIP_OPERATOR(Python, GPUFallbackOp);
REGISTER_HIP_OPERATOR(
    PythonGradient,
    GPUFallbackOp);
REGISTER_HIP_OPERATOR(PythonGradient, GPUFallbackOp);

REGISTER_HIP_OPERATOR(PythonDLPack, GPUFallbackOp);
REGISTER_HIP_OPERATOR(PythonDLPackGradient, GPUFallbackOp);

@ -9,8 +9,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "caffe2/ideep/operators/operator_fallback_ideep.h"
#include <caffe2/ideep/ideep_utils.h>
#include "caffe2/ideep/operators/operator_fallback_ideep.h"

namespace caffe2 {
namespace python {

@ -26,39 +26,39 @@ REGISTER_BLOB_FETCHER((TypeMeta::Id<itensor>()), IDeepFetcher);
REGISTER_BLOB_FEEDER(IDEEP, IDeepFeeder);

class IDeepFetcher : public BlobFetcherBase {
  TypeMeta type_transform(const itensor &atensor) {
  TypeMeta type_transform(const itensor& atensor) {
    switch (atensor.get_data_type()) {
    case itensor::data_type::f32:
      return TypeMeta::Make<float>();
    case itensor::data_type::s32:
      return TypeMeta::Make<int>();
    case itensor::data_type::s8:
      return TypeMeta::Make<int8_t>();
    case itensor::data_type::u8:
      return TypeMeta::Make<uint8_t>();
    default:
      // Should we throw exception?
      return TypeMeta();
      case itensor::data_type::f32:
        return TypeMeta::Make<float>();
      case itensor::data_type::s32:
        return TypeMeta::Make<int>();
      case itensor::data_type::s8:
        return TypeMeta::Make<int8_t>();
      case itensor::data_type::u8:
        return TypeMeta::Make<uint8_t>();
      default:
        // Should we throw exception?
        return TypeMeta();
    }
  }

public:
  pybind11::object Fetch(const Blob &blob) override {
 public:
  pybind11::object Fetch(const Blob& blob) override {
    try {
      return FetchTensor(blob.Get<itensor>(), true).obj;
    } catch (ideep::error &e) {
    } catch (ideep::error& e) {
      LOG(ERROR) << "IDEEP error: " << e.message;
      throw;
    }
  }

  FetchedBlob FetchTensor(const itensor &atensor, bool force_copy) {
  FetchedBlob FetchTensor(const itensor& atensor, bool force_copy) {
#ifdef USE_NUMPY
    FetchedBlob result;
    CAFFE_ENFORCE((atensor.ndims() != 0) &&
                  (atensor.get_nelems() == 0 ||
                   atensor.get_data_handle() != nullptr),
                  "Trying to fetch uninitialized tensor");
    CAFFE_ENFORCE(
        (atensor.ndims() != 0) &&
            (atensor.get_nelems() == 0 || atensor.get_data_handle() != nullptr),
        "Trying to fetch uninitialized tensor");
    // NOTE: Only support float so far.
    const int numpy_type = NPY_FLOAT;
    CAFFE_ENFORCE(
@ -70,12 +70,12 @@ public:

    result.copied = force_copy || atensor.need_reorder();
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    void *outPtr;
    void* outPtr;
    if (result.copied) {
      result.obj = py::reinterpret_steal<py::object>(
          PyArray_SimpleNew(atensor.ndims(), npy_dims.data(), numpy_type));
      outPtr = static_cast<void *>(
          PyArray_DATA(reinterpret_cast<PyArrayObject *>(result.obj.ptr())));
      outPtr = static_cast<void*>(
          PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.obj.ptr())));
    } else {
      outPtr = atensor.get_data_handle();
      result.obj = py::reinterpret_steal<py::object>(PyArray_SimpleNewFromData(

@ -111,13 +111,13 @@ class IDeepFeeder : public BlobFeederBase {
    return itensor::data_type::undef;
  }

public:
 public:
  void FeedTensor(
      const DeviceOption &option,
      PyArrayObject *original_array,
      itensor *tensor) {
      const DeviceOption& option,
      PyArrayObject* original_array,
      itensor* tensor) {
#ifdef USE_NUMPY
    PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
    PyArrayObject* array = PyArray_GETCONTIGUOUS(original_array);
    auto g = MakeGuard([&]() { Py_XDECREF(array); });
    const auto npy_type = PyArray_TYPE(array);
    const TypeMeta meta = NumpyTypeToCaffe(npy_type);

@ -125,10 +125,11 @@ public:
        meta,
        ScalarType::Undefined,
        "This numpy data type is not supported: ",
        PyArray_TYPE(array), ".");
        PyArray_TYPE(array),
        ".");

    int ndim = PyArray_NDIM(array);
    npy_intp *npy_dims = PyArray_DIMS(array);
    npy_intp* npy_dims = PyArray_DIMS(array);

    itensor::dims adims;
    for (int i = 0; i < ndim; i++) {

@ -145,15 +146,14 @@ public:
    if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
      tensor->resize(adims, type);
    }
    tensor->feed_from(adims, type,
                      static_cast<void *>(PyArray_DATA(array)));
    tensor->feed_from(adims, type, static_cast<void*>(PyArray_DATA(array)));
  }
#else
    CAFFE_THROW("Caffe2 was compiled without NumPy support.");
#endif // USE_NUMPY
  }

  bool ZeroDim(PyArrayObject *array) {
  bool ZeroDim(PyArrayObject* array) {
#ifdef USE_NUMPY
    int ndim = PyArray_NDIM(array);
    return ndim == 0;

@ -169,15 +169,15 @@ public:
      bool in_place) override {
#ifdef USE_NUMPY
    try {
      PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array);
      PyArrayObject* array = PyArray_GETCONTIGUOUS(original_array);
      auto g = MakeGuard([&]() { Py_XDECREF(array); });

      const auto npy_type = PyArray_TYPE(array);
      const TypeMeta meta = NumpyTypeToCaffe(npy_type);

      // TODO: if necessary, use dispatcher.
      if ((in_place && blob->IsType<itensor>())
          || (meta.Match<float>() && !ZeroDim(original_array))) {
      if ((in_place && blob->IsType<itensor>()) ||
          (meta.Match<float>() && !ZeroDim(original_array))) {
        FeedTensor(option, original_array, blob->GetMutable<itensor>());
      } else {
        DeviceOption cpu_option(option);

@ -191,10 +191,10 @@ public:
            true);
      } else {
        blob->Reset<Tensor>(new Tensor(
          cpu_tensor_feeder.FeedTensor(cpu_option, original_array)));
            cpu_tensor_feeder.FeedTensor(cpu_option, original_array)));
      }
    }
    } catch (ideep::error &e) {
    } catch (ideep::error& e) {
      LOG(ERROR) << "IDEEP error: " << e.message;
      throw;
    }
@ -526,12 +526,12 @@ void addNomnigraphMethods(pybind11::module& m) {
          "operator_def",
          [](Caffe2Annotation& annot) {
            auto opDef = py::module::import("caffe2.proto.caffe2_pb2")
                           .attr("OperatorDef");
                             .attr("OperatorDef");
            // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
            auto proto = annot.getOperatorDef();
            std::string serialized_proto;
            proto.SerializeToString(&serialized_proto);
            auto py_op_def= opDef();
            auto py_op_def = opDef();
            py_op_def.attr("ParseFromString")(py::bytes(serialized_proto));
            return py_op_def;
          },

@ -22,7 +22,7 @@ const std::string& GetStringFromBlob(Blob* blob) {
    CAFFE_THROW("Unsupported Blob type");
  }
}
}
} // namespace

class BlobsQueueDBCursor : public Cursor {
 public:

@ -96,4 +96,4 @@ NO_GRADIENT(SafeEnqueueBlobs);
NO_GRADIENT(SafeDequeueBlobs);
NO_GRADIENT(WeightedSampleDequeueBlobs);

}
} // namespace caffe2

@ -49,7 +49,8 @@ class EnqueueBlobsOp final : public Operator<Context> {
    CAFFE_ENFORCE(InputSize() > 1);
    auto queue = Operator<Context>::Inputs()[0]
                     ->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());
    CAFFE_ENFORCE(
        queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());
    return queue->blockingWrite(this->Outputs());
  }

@ -70,7 +71,8 @@ class DequeueBlobsOp final : public Operator<Context> {
    CAFFE_ENFORCE(InputSize() == 1);
    auto queue =
        OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());
    CAFFE_ENFORCE(
        queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());
    return queue->blockingRead(this->Outputs(), timeout_secs_);
  }

@ -141,7 +143,7 @@ class SafeDequeueBlobsOp final : public Operator<Context> {
    if (blobs_.size() != size) {
      blobs_.resize(size);
      blobPtrs_.resize(size);
      for (auto col: c10::irange(size)) {
      for (auto col : c10::irange(size)) {
        blobPtrs_.at(col) = &blobs_.at(col);
      }
    }

@ -152,7 +154,7 @@ class SafeDequeueBlobsOp final : public Operator<Context> {
      // if we read at least one record, status is still true
      return i > 0;
    }
    for (auto col: c10::irange(size)) {
    for (auto col : c10::irange(size)) {
      auto* out = this->Output(col);
      const auto& in = blobPtrs_.at(col)->template Get<Tensor>();
      if (i == 0) {

@ -231,7 +233,7 @@ class WeightedSampleDequeueBlobsOp final : public Operator<Context> {
    float sum = accumulate(weights.begin(), weights.end(), 0.0f);
    CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive");
    cumProbs_.resize(weights.size());
    for (auto i: c10::irange(weights.size())) {
    for (auto i : c10::irange(weights.size())) {
      cumProbs_[i] = weights[i] / sum;
      CAFFE_ENFORCE_GE(
          cumProbs_[i], 0.0f, "Each probability must be non-negative");
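The loop above normalizes the weights so they sum to 1; a weighted draw then
reduces to searching a cumulative distribution with a uniform sample. A
self-contained sketch of that technique (illustrative names, not the op's
actual code path):

    #include <algorithm>
    #include <cstddef>
    #include <random>
    #include <vector>

    size_t WeightedIndex(const std::vector<float>& weights, std::mt19937& rng) {
      float sum = 0.0f;
      for (float w : weights) {
        sum += w;
      }
      std::vector<float> cum;
      cum.reserve(weights.size());
      float running = 0.0f;
      for (float w : weights) {
        running += w / sum; // normalized share, as with cumProbs_ above
        cum.push_back(running);
      }
      std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
      return std::lower_bound(cum.begin(), cum.end(), uniform(rng)) -
          cum.begin();
    }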
@ -1,5 +1,5 @@
#include "caffe2/utils/math.h"
#include "caffe2/queue/queue_ops.h"
#include "caffe2/utils/math.h"

#include "caffe2/core/context_gpu.h"

@ -13,4 +13,4 @@ REGISTER_CUDA_OPERATOR(CloseBlobsQueue, CloseBlobsQueueOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(SafeEnqueueBlobs, SafeEnqueueBlobsOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(SafeDequeueBlobs, SafeDequeueBlobsOp<CUDAContext>);

}
} // namespace caffe2

@ -202,7 +202,7 @@ bool RebatchingQueue::enqueue(

  do {
    queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  } while (canWrite() && idx < splittedInputs.size());
}
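The enqueue loop writes into a fixed-capacity ring buffer: the head index only
grows, and every access reduces it modulo the capacity. A stripped-down sketch
of that indexing scheme (a hypothetical class, not RebatchingQueue itself):

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    template <typename T>
    class Ring {
     public:
      explicit Ring(size_t capacity) : buf_(capacity) {}
      bool CanWrite() const { return head_ - tail_ < buf_.size(); }
      bool CanRead() const { return tail_ < head_; }
      void Push(T v) { buf_[head_++ % buf_.size()] = std::move(v); }
      T Pop() { return std::move(buf_[tail_++ % buf_.size()]); }

     private:
      std::vector<T> buf_;
      uint64_t head_ = 0; // monotonically increasing write index
      uint64_t tail_ = 0; // monotonically increasing read index
    };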
@ -234,4 +234,4 @@ void RebatchingQueue::close() {
  cvEmpty_.notify_all();
  cvOverflow_.notify_all();
}
} // caffe2
} // namespace caffe2

@ -65,4 +65,4 @@ class RebatchingQueue {

  std::vector<std::vector<TensorCPU>> queue_;
};
} // caffe2
} // namespace caffe2

@ -68,5 +68,5 @@ tensor per component.
    .Arg(
        "num_elements",
        "Number of elements to dequeue. By default we dequeue one element.");
}
}
} // namespace
} // namespace caffe2

@ -80,4 +80,4 @@ class CloseRebatchingQueueOp : public Operator<CPUContext> {
    return true;
  }
};
} // caffe2
} // namespace caffe2

@ -167,9 +167,8 @@ void runDepthwise3x3Conv(
      // fast-path, all accesses in-bounds
      if (C10_LIKELY(
              ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
              iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
              2 * otw + 1 < args.out_cols
          )) {
              iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
              2 * otw + 1 < args.out_cols)) {
        float32x4x4_t input_tile;
        for (int row = 0; row < 4; ++row) {
          input_tile.val[row] =

@ -140,8 +140,7 @@ void compare(

          // For small values / small difference, the relative error
          // can be huge but the absolute error will be small
          EXPECT_TRUE(
              relErr <= maxRelErr || absErr <= absErrForRelErrFailure)
          EXPECT_TRUE(relErr <= maxRelErr || absErr <= absErrForRelErrFailure)
              << v1 << " " << v2 << " (rel err " << relErr << ") "
              << "(" << n << " " << c << " " << h << " " << w << ") "
              << "running N " << N << " inputC " << inputC << " H " << H
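The comment in that assertion captures the standard float-comparison
compromise: near zero the relative error blows up, so a value passes if either
the relative or the absolute error is within bounds. As a self-contained
sketch of the same idea:

    #include <algorithm>
    #include <cmath>

    bool NearlyEqual(float a, float b, float maxRelErr, float maxAbsErr) {
      const float absErr = std::fabs(a - b);
      const float relErr = absErr / std::max(std::fabs(a), std::fabs(b));
      // If a == b == 0, relErr is NaN and the absolute test decides.
      return relErr <= maxRelErr || absErr <= maxAbsErr;
    }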
@ -3,7 +3,6 @@

#include "caffe2/core/common.h"


#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

@ -131,10 +130,9 @@ NNPACKConvOp::getConvolutionTransformStrategy() const {
  return nnp_convolution_transform_strategy_compute;
}

nnp_activation
NNPACKConvOp::getActivationType() const {
  auto activation = OperatorBase::GetSingleArgument<std::string>(
      "activation", "identity");
nnp_activation NNPACKConvOp::getActivationType() const {
  auto activation =
      OperatorBase::GetSingleArgument<std::string>("activation", "identity");
  if (activation == "identity") {
    return nnp_activation_identity;
  } else if (activation == "Relu") {

@ -180,19 +178,23 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() {
    biasData = dummyBias_.data();
  }

  const nnp_size input_size = {.width = static_cast<size_t>(X.dim32(3)),
                               .height = static_cast<size_t>(X.dim32(2))};
  const nnp_size input_size = {
      .width = static_cast<size_t>(X.dim32(3)),
      .height = static_cast<size_t>(X.dim32(2))};
  // filter is MCHW
  const nnp_size kernel_size = {.width = static_cast<size_t>(filter.dim32(3)),
                                .height = static_cast<size_t>(filter.dim32(2))};
  const nnp_size kernel_size = {
      .width = static_cast<size_t>(filter.dim32(3)),
      .height = static_cast<size_t>(filter.dim32(2))};
  // pad is tblr
  const nnp_padding padding = {.top = static_cast<size_t>(pad_t()),
                               .right = static_cast<size_t>(pad_r()),
                               .bottom = static_cast<size_t>(pad_b()),
                               .left = static_cast<size_t>(pad_l())};
  const nnp_padding padding = {
      .top = static_cast<size_t>(pad_t()),
      .right = static_cast<size_t>(pad_r()),
      .bottom = static_cast<size_t>(pad_b()),
      .left = static_cast<size_t>(pad_l())};

  const nnp_size output_subsample = {.width = static_cast<size_t>(stride_w()),
                                     .height = static_cast<size_t>(stride_h())};
  const nnp_size output_subsample = {
      .width = static_cast<size_t>(stride_w()),
      .height = static_cast<size_t>(stride_h())};
  initNNPACK();

#if !defined(USE_INTERNAL_PTHREADPOOL_IMPL)
@ -21,10 +21,11 @@ namespace {
uint8_t* GetMutableData(int type_index, TensorCPU* tensor) {
  // see COMP_DATA_TYPE_MAPPER in mutils.py for the mapping
  static const std::map<int, std::function<uint8_t*(TensorCPU * tensor)>>
      gTypeMapper = {REGISTER_TYPE(TensorProto::UINT8, uint8_t),
                     REGISTER_TYPE(TensorProto::UINT16, uint16_t),
                     REGISTER_TYPE(TensorProto::INT32, int32_t),
                     REGISTER_TYPE(TensorProto::FLOAT, float)};
      gTypeMapper = {
          REGISTER_TYPE(TensorProto::UINT8, uint8_t),
          REGISTER_TYPE(TensorProto::UINT16, uint16_t),
          REGISTER_TYPE(TensorProto::INT32, int32_t),
          REGISTER_TYPE(TensorProto::FLOAT, float)};

  CAFFE_ENFORCE_EQ(
      gTypeMapper.count(type_index),

@ -38,7 +38,7 @@ C10_DEFINE_string(
    "gen/",
    "The root of the caffe test folder.");

GTEST_API_ int main(int argc, char **argv) {
GTEST_API_ int main(int argc, char** argv) {
  // std::cout << "Running main() from gtest_main.cc\n";
  testing::InitGoogleTest(&argc, argv);
  caffe2::GlobalInit(&argc, &argv);

@ -115,4 +115,4 @@ TEST(CommonSubexpressionEliminationTest, TestFromExternal) {

} // namespace

} // namespace Caffe2
} // namespace caffe2

@ -46,4 +46,4 @@ TEST(ConvToNNPackTest, TestSimple) {

} // namespace

} // namespace Caffe2
} // namespace caffe2

@ -263,4 +263,4 @@ bool PatternNetTransform::ReplaceRule(
  return true;
}

} // namespace Caffe2
} // namespace caffe2

@ -530,4 +530,4 @@ TEST(PatternNetTransformTest, TestMultiInputOutputTransform) {

} // namespace

} // namespace Caffe2
} // namespace caffe2

@ -1,9 +1,9 @@
#include <caffe2/video/video_decoder.h>
#include <assert.h>
#include <caffe2/core/logging.h>
#include <caffe2/video/video_decoder.h>
#include <array>
#include <mutex>
#include <random>
#include <array>

namespace caffe2 {

@ -606,7 +606,9 @@ bool VideoInputOp<Context>::GetImageAndLabelsFromDBValue(
    img = scaled_img;
  } else {
    cv::cvtColor(
        scaled_img, img, (channels_rgb_ == 1) ? cv::COLOR_BGR2GRAY : cv::COLOR_GRAY2BGR);
        scaled_img,
        img,
        (channels_rgb_ == 1) ? cv::COLOR_BGR2GRAY : cv::COLOR_GRAY2BGR);
  }

  cv::Mat rgb_img;

@ -1,5 +1,5 @@
#include <caffe2/video/video_io.h>
#include <caffe2/core/logging.h>
#include <caffe2/video/video_io.h>
#include <algorithm>
#include <random>
#include <string>