mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Replace all CHECK_ and DCHECK_ with TORCH_* macros (#82032)
Avoid exposing defines that conflict with google logging, since this blocks external usage of libtorch in certain cases. All the 'interesting' changes should be in these two files, and the rest should just be mechanical changes via sed. c10/util/logging_is_not_google_glog.h c10/util/logging_is_google_glog.h Fixes https://github.com/pytorch/pytorch/issues/81415 cc @miladm @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/82032 Approved by: https://github.com/soumith, https://github.com/miladm
This commit is contained in:
parent
cab819222a
commit
4f34cd6d1e
|
|
@ -35,7 +35,7 @@ void LayerNormKernelImplInternal(
|
|||
Tensor* rstd) {
|
||||
using T_ACC = vec::vec_scalar_t<T>;
|
||||
using Vec = vec::Vectorized<T_ACC>;
|
||||
DCHECK_EQ(X.numel(), M * N);
|
||||
TORCH_DCHECK_EQ(X.numel(), M * N);
|
||||
DCHECK(!gamma.defined() || gamma.numel() == N);
|
||||
DCHECK(!beta.defined() || beta.numel() == N);
|
||||
const T* X_data = X.data_ptr<T>();
|
||||
|
|
@ -117,10 +117,10 @@ void LayerNormBackwardKernelImplInternal(
|
|||
Tensor* dbeta) {
|
||||
using T_ACC = vec::vec_scalar_t<T>;
|
||||
using Vec = vec::Vectorized<T_ACC>;
|
||||
DCHECK_EQ(dY.numel(), M * N);
|
||||
DCHECK_EQ(X.numel(), M * N);
|
||||
DCHECK_EQ(mean.numel(), M);
|
||||
DCHECK_EQ(rstd.numel(), M);
|
||||
TORCH_DCHECK_EQ(dY.numel(), M * N);
|
||||
TORCH_DCHECK_EQ(X.numel(), M * N);
|
||||
TORCH_DCHECK_EQ(mean.numel(), M);
|
||||
TORCH_DCHECK_EQ(rstd.numel(), M);
|
||||
DCHECK(!gamma.defined() || gamma.numel() == N);
|
||||
const T* dY_data = dY.template data_ptr<T>();
|
||||
const T* X_data = X.template data_ptr<T>();
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) {
|
|||
TORCH_CHECK(
|
||||
status == pytorch_qnnp_status_success,
|
||||
"failed to create QNNPACK Softmax operator");
|
||||
CHECK_NOTNULL(softargmax);
|
||||
TORCH_CHECK_NOTNULL(softargmax);
|
||||
|
||||
status = pytorch_qnnp_setup_softargmax_nc_q8(
|
||||
softargmax, batch_size, input, input_stride, output, output_stride);
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ static void BM_deep_wide_jit_graph_executor(benchmark::State& state) {
|
|||
|
||||
std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
|
||||
|
||||
CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);
|
||||
TORCH_CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);
|
||||
|
||||
mod.forward(inputs);
|
||||
for (auto _ : state) {
|
||||
|
|
@ -65,7 +65,7 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
|
|||
|
||||
std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
|
||||
|
||||
CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);
|
||||
TORCH_CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);
|
||||
|
||||
mod.forward(inputs);
|
||||
for (auto _ : state) {
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ int loadInput(
|
|||
LOG(INFO) << "Running on GPU.";
|
||||
#ifdef __CUDA_ARCH__
|
||||
caffe2::TensorCUDA* tensor = blob->GetMutable<caffe2::TensorCUDA>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
tensor->mutable_data<uint8_t>();
|
||||
|
|
@ -189,17 +189,17 @@ int loadInput(
|
|||
if (input_type_list[i] == "uint8_t") {
|
||||
caffe2::int8::Int8TensorCPU* tensor =
|
||||
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->t.Resize(input_dims);
|
||||
tensor->t.mutable_data<uint8_t>();
|
||||
} else if (input_type_list[i] == "float") {
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<float>();
|
||||
} else if (input_type_list[i] == "int") {
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<int>();
|
||||
} else {
|
||||
|
|
@ -495,7 +495,7 @@ int benchmark(
|
|||
net_def.set_name("benchmark");
|
||||
}
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
TORCH_CHECK_NOTNULL(net);
|
||||
runNetwork(
|
||||
workspace,
|
||||
net,
|
||||
|
|
|
|||
|
|
@ -591,7 +591,7 @@ void runNetwork(
|
|||
}
|
||||
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
TORCH_CHECK_NOTNULL(net);
|
||||
|
||||
LOG(INFO) << "Starting benchmark.";
|
||||
caffe2::ObserverConfig::initSampleRate(1, 1, 1, run_individual, warmup);
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ void ConvertImageDataset(
|
|||
// Synthesize key for this entry
|
||||
auto key_len = snprintf(
|
||||
key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str());
|
||||
DCHECK_LE(key_len, sizeof(key_cstr));
|
||||
TORCH_DCHECK_LE(key_len, sizeof(key_cstr));
|
||||
|
||||
// Put in db
|
||||
transaction->Put(string(key_cstr), std::move(value));
|
||||
|
|
|
|||
|
|
@ -136,12 +136,12 @@ int main(int argc, char** argv) {
|
|||
if (input_type_list[i] == "uint8_t") {
|
||||
caffe2::int8::Int8TensorCPU* tensor =
|
||||
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->t.Resize(input_dims);
|
||||
tensor->t.mutable_data<uint8_t>();
|
||||
} else if (input_type_list[i] == "float") {
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
TORCH_CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<float>();
|
||||
} else {
|
||||
|
|
@ -184,7 +184,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
caffe2::NetBase* net = workspace->CreateNet(net_def);
|
||||
CHECK_NOTNULL(net);
|
||||
TORCH_CHECK_NOTNULL(net);
|
||||
CAFFE_ENFORCE(net->Run());
|
||||
net->TEST_Benchmark(FLAGS_warmup, FLAGS_iter, FLAGS_run_individual);
|
||||
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ TEST(LoggingTest, Join) {
|
|||
|
||||
TEST(LoggingTest, TestDanglingElse) {
|
||||
if (true)
|
||||
DCHECK_EQ(1, 1);
|
||||
TORCH_DCHECK_EQ(1, 1);
|
||||
else
|
||||
GTEST_FAIL();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ using EnforceNotMet = ::c10::Error;
|
|||
* With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))`
|
||||
*
|
||||
* Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided
|
||||
* too. Please use them instead of CHECK_EQ and friends for failures in
|
||||
* too. Please use them instead of TORCH_CHECK_EQ and friends for failures in
|
||||
* user-provided input.
|
||||
*/
|
||||
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ class Registry {
|
|||
const RegistryPriority priority = REGISTRY_DEFAULT) {
|
||||
std::lock_guard<std::mutex> lock(register_mutex_);
|
||||
// The if statement below is essentially the same as the following line:
|
||||
// CHECK_EQ(registry_.count(key), 0) << "Key " << key
|
||||
// TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
|
||||
// << " registered twice.";
|
||||
// However, CHECK_EQ depends on google logging, and since registration is
|
||||
// carried out at static initialization time, we do not want to have an
|
||||
// However, TORCH_CHECK_EQ depends on google logging, and since registration
|
||||
// is carried out at static initialization time, we do not want to have an
|
||||
// explicit dependency on glog's initialization function.
|
||||
if (registry_.count(key) != 0) {
|
||||
auto cur_priority = priority_[key];
|
||||
|
|
|
|||
|
|
@ -50,6 +50,71 @@ INSTANTIATE_FOR_CONTAINER(set)
|
|||
#include <glog/logging.h>
|
||||
|
||||
// Additional macros on top of glog
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
|
||||
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
|
||||
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
|
||||
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
|
||||
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
|
||||
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
|
||||
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_CHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_EQ(val1, val2)
|
||||
#define TORCH_CHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_NE(val1, val2)
|
||||
#define TORCH_CHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_LE(val1, val2)
|
||||
#define TORCH_CHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_LT(val1, val2)
|
||||
#define TORCH_CHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_GE(val1, val2)
|
||||
#define TORCH_CHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_GT(val1, val2)
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GT(val1, val2)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
DCHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Log with source location information override (to be used in generic
|
||||
// warning/error handlers implemented as functions, not macros)
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ void LogMessageFatal(const char* file, int line, const T& message) {
|
|||
MessageLogger(file, line, GLOG_FATAL).stream() << message;
|
||||
}
|
||||
|
||||
// Helpers for CHECK_NOTNULL(). Two are necessary to support both raw pointers
|
||||
// and smart pointers.
|
||||
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
|
||||
// pointers and smart pointers.
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
|
||||
if (t == nullptr) {
|
||||
|
|
@ -136,63 +136,63 @@ static_assert(
|
|||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
|
||||
#endif // NDEBUG
|
||||
|
||||
#define CHECK_OP(val1, val2, op) \
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << ") "
|
||||
|
||||
// Check_op macro definitions
|
||||
#define CHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==)
|
||||
#define CHECK_NE(val1, val2) CHECK_OP(val1, val2, !=)
|
||||
#define CHECK_LE(val1, val2) CHECK_OP(val1, val2, <=)
|
||||
#define CHECK_LT(val1, val2) CHECK_OP(val1, val2, <)
|
||||
#define CHECK_GE(val1, val2) CHECK_OP(val1, val2, >=)
|
||||
#define CHECK_GT(val1, val2) CHECK_OP(val1, val2, >)
|
||||
// TORCH_CHECK_OP macro definitions
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only versions of CHECK_OP macros.
|
||||
#define DCHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==)
|
||||
#define DCHECK_NE(val1, val2) CHECK_OP(val1, val2, !=)
|
||||
#define DCHECK_LE(val1, val2) CHECK_OP(val1, val2, <=)
|
||||
#define DCHECK_LT(val1, val2) CHECK_OP(val1, val2, <)
|
||||
#define DCHECK_GE(val1, val2) CHECK_OP(val1, val2, >=)
|
||||
#define DCHECK_GT(val1, val2) CHECK_OP(val1, val2, >)
|
||||
// Debug only versions of TORCH_CHECK_OP macros.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define DCHECK_EQ(val1, val2) \
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, ==)
|
||||
#define DCHECK_NE(val1, val2) \
|
||||
TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, !=)
|
||||
#define DCHECK_LE(val1, val2) \
|
||||
TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, <=)
|
||||
#define DCHECK_LT(val1, val2) \
|
||||
TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, <)
|
||||
#define DCHECK_GE(val1, val2) \
|
||||
TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, >=)
|
||||
#define DCHECK_GT(val1, val2) \
|
||||
TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
CHECK_OP(val1, val2, >)
|
||||
TORCH_CHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define CHECK_NOTNULL(val) \
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of CHECK_NOTNULL
|
||||
#define DCHECK_NOTNULL(val) \
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define DCHECK_NOTNULL(val) \
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
CHECK_NOTNULL(val)
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ---------------------- Support for std objects --------------------------
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class Fp16FCAccOp final : public Operator<Context> {
|
|||
|
||||
Y_shape_cache_ = X.sizes().vec();
|
||||
// This is an invariant of canonical_axis, so we can DCHECK.
|
||||
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
|
||||
TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
|
||||
Y_shape_cache_.resize(canonical_axis + 1);
|
||||
Y_shape_cache_[canonical_axis] = N;
|
||||
Y->Resize(Y_shape_cache_);
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ NCCLContext* getNCCLContext(const NCCLExecution& ex) {
|
|||
LOG(INFO) << "Creating NCCLContext for key: " << key;
|
||||
contexts[key].reset(new NCCLContext(ex));
|
||||
}
|
||||
return CHECK_NOTNULL(contexts[key].get());
|
||||
return TORCH_CHECK_NOTNULL(contexts[key].get());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -153,7 +153,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
|
|||
auto& comm = comms[i];
|
||||
auto& stream = streams[i];
|
||||
|
||||
DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
|
||||
TORCH_DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
|
||||
CUDA_ENFORCE(cudaStreamWaitEvent(stream, context->master_event_, 0));
|
||||
f(ctx, comm, stream);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class OpenCLContext final {
|
|||
public:
|
||||
explicit OpenCLContext();
|
||||
explicit OpenCLContext(const DeviceOption& option) {
|
||||
DCHECK_EQ(option.device_type(), PROTO_OPENCL);
|
||||
TORCH_DCHECK_EQ(option.device_type(), PROTO_OPENCL);
|
||||
OpenCLContext();
|
||||
}
|
||||
~OpenCLContext() {}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() {
|
|||
// New it (never delete) to avoid calling the destructors on process
|
||||
// exit and racing against the CUDA shutdown sequence.
|
||||
static auto* p = new CuDNNWrapper::PerGPUCuDNNStates();
|
||||
CHECK_NOTNULL(p);
|
||||
TORCH_CHECK_NOTNULL(p);
|
||||
return *p;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -437,7 +437,7 @@ CUDAContext::CUDAContext(const DeviceOption& option)
|
|||
option.has_random_seed() ? option.random_seed()
|
||||
: RandomNumberSeed()) {
|
||||
static Caffe2CudaInitializerHelper g_cuda_initializer_;
|
||||
DCHECK_EQ(option.device_type(), PROTO_CUDA);
|
||||
TORCH_DCHECK_EQ(option.device_type(), PROTO_CUDA);
|
||||
}
|
||||
|
||||
CUDAContext::~CUDAContext() {
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
|
|||
curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
|
||||
CURAND_ENFORCE(
|
||||
curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_));
|
||||
CHECK_NOTNULL(curand_generator_);
|
||||
TORCH_CHECK_NOTNULL(curand_generator_);
|
||||
}
|
||||
CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream()));
|
||||
return curand_generator_;
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) {
|
|||
cudaStream_t getStreamForHandle(cublasHandle_t handle) {
|
||||
cudaStream_t stream = nullptr;
|
||||
CUBLAS_ENFORCE(cublasGetStream(handle, &stream));
|
||||
CHECK_NOTNULL(stream);
|
||||
TORCH_CHECK_NOTNULL(stream);
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class CuDNNWrapper {
|
|||
if (!sync_state.state.get()) {
|
||||
sync_state.state.reset(new CuDNNState(context_->device_id()));
|
||||
}
|
||||
CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
|
||||
TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ MIOPENWrapper::PerGPUMIOPENStates& MIOPENWrapper::miopen_states()
|
|||
// New it (never delete) to avoid calling the destructors on process
|
||||
// exit and racing against the CUDA shutdown sequence.
|
||||
static auto* p = new MIOPENWrapper::PerGPUMIOPENStates();
|
||||
CHECK_NOTNULL(p);
|
||||
TORCH_CHECK_NOTNULL(p);
|
||||
return *p;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class MIOPENWrapper
|
|||
{
|
||||
sync_state.state.reset(new MIOPENState(context_->device_id()));
|
||||
}
|
||||
CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f);
|
||||
TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ void checkChainingAndRun(
|
|||
net_def.set_num_workers(4);
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
|
||||
CHECK_NOTNULL(dag);
|
||||
TORCH_CHECK_NOTNULL(dag);
|
||||
const auto& chains = dag->TEST_execution_chains();
|
||||
EXPECT_EQ(chains, expected);
|
||||
testExecution(net, net_def.op().size());
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ void checkChainingAndRun(
|
|||
net_def.set_num_workers(4);
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
|
||||
CHECK_NOTNULL(dag);
|
||||
TORCH_CHECK_NOTNULL(dag);
|
||||
const auto& chains = dag->TEST_execution_chains();
|
||||
EXPECT_TRUE(chains == expected);
|
||||
testExecution(net, net_def.op().size());
|
||||
|
|
@ -175,7 +175,7 @@ void checkNumChainsAndRun(const char* spec, const int expected_num_chains) {
|
|||
{
|
||||
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
|
||||
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
|
||||
CHECK_NOTNULL(dag);
|
||||
TORCH_CHECK_NOTNULL(dag);
|
||||
const auto& chains = dag->TEST_execution_chains();
|
||||
EXPECT_EQ(expected_num_chains, chains.size());
|
||||
testExecution(net, net_def.op().size());
|
||||
|
|
@ -1108,7 +1108,7 @@ void testProfDAGNetErrorCase(bool test_error) {
|
|||
// with failing op - prof_dag handles invalid runs and returns empty stats,
|
||||
// without - returns stats for each op
|
||||
auto* prof_dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
|
||||
CHECK_NOTNULL(prof_dag);
|
||||
TORCH_CHECK_NOTNULL(prof_dag);
|
||||
auto stats_proto = prof_dag->GetPerOperatorCost();
|
||||
ASSERT_EQ(
|
||||
stats_proto.stats_size(), test_error ? 0 : net->GetOperators().size());
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
|
|||
|
||||
outputs_.reserve(operator_def.output_size());
|
||||
for (const string& output_str : operator_def.output()) {
|
||||
outputs_.push_back(CHECK_NOTNULL(ws->CreateBlob(output_str)));
|
||||
outputs_.push_back(TORCH_CHECK_NOTNULL(ws->CreateBlob(output_str)));
|
||||
}
|
||||
|
||||
type_ = operator_def.type();
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
|||
!std::is_same<T, Tensor>::value,
|
||||
"You should use Input<Tensor>(int, DeviceType) for "
|
||||
"Tensor.");
|
||||
DCHECK_LT((size_t)idx, inputs_.size());
|
||||
TORCH_DCHECK_LT((size_t)idx, inputs_.size());
|
||||
try {
|
||||
return inputs_.at(idx)->template Get<T>();
|
||||
} catch (::caffe2::EnforceNotMet& enf) {
|
||||
|
|
@ -178,7 +178,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
|||
static_assert(
|
||||
std::is_same<T, Tensor>::value,
|
||||
"Input(int, DeviceType) is only available for Tensor");
|
||||
DCHECK_LT((size_t)idx, inputs_.size());
|
||||
TORCH_DCHECK_LT((size_t)idx, inputs_.size());
|
||||
try {
|
||||
// TODO(jerryzh): We'll need to check device type in Get<T>() later
|
||||
// Get<T>() -> Get<T>(type)
|
||||
|
|
@ -193,7 +193,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
|||
}
|
||||
#if defined(EXPOSE_C2_OPS) || \
|
||||
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
|
||||
DCHECK_LT(0U, newstyle_inputs_.size());
|
||||
TORCH_DCHECK_LT(0U, newstyle_inputs_.size());
|
||||
IValue ival;
|
||||
if (newstyle_inputs_[0].isTensorList()) {
|
||||
// if the first input is a tensor list, we get input tensors by indexing
|
||||
|
|
@ -201,12 +201,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
|
|||
// are accessible as inputs. any hypothetical input tensors that come
|
||||
// after the list are not accessible.
|
||||
auto tensorList = newstyle_inputs_[0].toTensorVector();
|
||||
DCHECK_LT((size_t)idx, tensorList.size());
|
||||
TORCH_DCHECK_LT((size_t)idx, tensorList.size());
|
||||
ival = tensorList[idx];
|
||||
} else {
|
||||
// if the first input is not a tensor list, we get input tensors by
|
||||
// indexing into the inputs.
|
||||
DCHECK_LT((size_t)idx, newstyle_inputs_.size());
|
||||
TORCH_DCHECK_LT((size_t)idx, newstyle_inputs_.size());
|
||||
ival = newstyle_inputs_[idx];
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ class SleepOp final : public Operator<CPUContext> {
|
|||
SleepOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
ms_(OperatorBase::GetSingleArgument<int>("ms", 1000)) {
|
||||
DCHECK_GT(ms_, 0);
|
||||
DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?";
|
||||
TORCH_DCHECK_GT(ms_, 0);
|
||||
TORCH_DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?";
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -187,8 +187,8 @@ class C10_EXPORT QTensor {
|
|||
* Returns the i-th dimension of the qtensor in int.
|
||||
*/
|
||||
inline int dim32(const int i) const {
|
||||
DCHECK_LT(i, static_cast<int>(dims_.size())) << "Exceeding ndim limit " << dims_.size();
|
||||
DCHECK_GE(i, 0) << "Cannot have negative index";
|
||||
TORCH_DCHECK_LT(i, static_cast<int>(dims_.size())) << "Exceeding ndim limit " << dims_.size();
|
||||
TORCH_DCHECK_GE(i, 0) << "Cannot have negative index";
|
||||
CAFFE_ENFORCE_LT(dims_[i], std::numeric_limits<int>::max());
|
||||
return static_cast<int>(dims_[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -163,8 +163,8 @@ string MaxPoolRTCFunction::GetSource(
|
|||
stride_w,
|
||||
pad_t,
|
||||
pad_l);
|
||||
DCHECK_GE(nbytes, 0);
|
||||
DCHECK_LT(nbytes, 65536);
|
||||
TORCH_DCHECK_GE(nbytes, 0);
|
||||
TORCH_DCHECK_LT(nbytes, 65536);
|
||||
return string(buffer);
|
||||
}
|
||||
|
||||
|
|
@ -202,8 +202,8 @@ string MaxPoolGradientRTCFunction::GetSource(
|
|||
stride_w,
|
||||
pad_t,
|
||||
pad_l);
|
||||
DCHECK_GE(nbytes, 0);
|
||||
DCHECK_LT(nbytes, 65536);
|
||||
TORCH_DCHECK_GE(nbytes, 0);
|
||||
TORCH_DCHECK_LT(nbytes, 65536);
|
||||
return string(buffer);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ FileStoreHandler::FileStoreHandler(
|
|||
auto ret = mkdir(basePath_.c_str(), 0777);
|
||||
#endif // defined(_MSC_VER)
|
||||
if (ret == -1) {
|
||||
CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno);
|
||||
TORCH_CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ std::string FileStoreHandler::realPath(const std::string& path) {
|
|||
std::array<char, PATH_MAX> buf;
|
||||
auto ret = realpath(path.c_str(), buf.data());
|
||||
#endif
|
||||
CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno);
|
||||
TORCH_CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno);
|
||||
return std::string(buf.data());
|
||||
}
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ bool FileStoreHandler::check(const std::vector<std::string>& names) {
|
|||
if (fd == -1) {
|
||||
// Only deal with files that don't exist.
|
||||
// Anything else is a problem.
|
||||
CHECK_EQ(errno, ENOENT);
|
||||
TORCH_CHECK_EQ(errno, ENOENT);
|
||||
|
||||
// One of the paths doesn't exist; return early
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -145,10 +145,10 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
|
|||
const auto& U = Input(1);
|
||||
const auto& V = Input(2);
|
||||
const auto& dY = Input(3);
|
||||
DCHECK_GE(X.dim(), 1);
|
||||
DCHECK_GE(U.dim(), 2);
|
||||
DCHECK_GE(V.dim(), 2);
|
||||
DCHECK_LE(dY.dim(), 2);
|
||||
TORCH_DCHECK_GE(X.dim(), 1);
|
||||
TORCH_DCHECK_GE(U.dim(), 2);
|
||||
TORCH_DCHECK_GE(V.dim(), 2);
|
||||
TORCH_DCHECK_LE(dY.dim(), 2);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
|
|
@ -156,13 +156,13 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
|
|||
// number of outputs.
|
||||
int N = U.dim32(0);
|
||||
int middle = U.dim32(1);
|
||||
DCHECK_EQ(K, V.dim32(0));
|
||||
TORCH_DCHECK_EQ(K, V.dim32(0));
|
||||
if (dY.dim() > 1) {
|
||||
DCHECK_EQ(M, dY.dim32(0));
|
||||
DCHECK_EQ(N, dY.dim32(1));
|
||||
TORCH_DCHECK_EQ(M, dY.dim32(0));
|
||||
TORCH_DCHECK_EQ(N, dY.dim32(1));
|
||||
} else {
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
TORCH_DCHECK_EQ(X.dim(), 1);
|
||||
TORCH_DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
|
||||
auto* dU = Output(0, U.sizes(), at::dtype<T>());
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
|
||||
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
|
@ -249,9 +250,9 @@ class FullyConnectedPruneGradientOp : public Operator<Context> {
|
|||
auto& thres = Input(6);
|
||||
// TODO(wyiming): check comp_lb is a float
|
||||
auto& comp_lb = Input(7);
|
||||
DCHECK_GE(X.dim(), 1);
|
||||
DCHECK_GE(W.dim(), 2);
|
||||
DCHECK_LE(dY.dim(), 2);
|
||||
TORCH_DCHECK_GE(X.dim(), 1);
|
||||
TORCH_DCHECK_GE(W.dim(), 2);
|
||||
TORCH_DCHECK_LE(dY.dim(), 2);
|
||||
// batch size
|
||||
int M = X.dim() > 1 ? X.dim32(0) : 1;
|
||||
// Feature dimension
|
||||
|
|
@ -263,17 +264,17 @@ class FullyConnectedPruneGradientOp : public Operator<Context> {
|
|||
// TODO(wyiming): this threshold should be
|
||||
// based on distribution of the layer weight
|
||||
float thr = 0.01;
|
||||
DCHECK_EQ(Mask.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Mask.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
|
||||
DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
|
||||
DCHECK_EQ(K, W.numel() / W.dim32(0));
|
||||
TORCH_DCHECK_EQ(Mask.dim32(0), W.dim32(0));
|
||||
TORCH_DCHECK_EQ(Mask.dim32(1), W.dim32(1));
|
||||
TORCH_DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
|
||||
TORCH_DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
|
||||
TORCH_DCHECK_EQ(K, W.numel() / W.dim32(0));
|
||||
if (dY.dim() > 1) {
|
||||
DCHECK_EQ(M, dY.dim32(0));
|
||||
DCHECK_EQ(N, dY.dim32(1));
|
||||
TORCH_DCHECK_EQ(M, dY.dim32(0));
|
||||
TORCH_DCHECK_EQ(N, dY.dim32(1));
|
||||
} else {
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
TORCH_DCHECK_EQ(X.dim(), 1);
|
||||
TORCH_DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
|
||||
auto* dW = Output(0, W.sizes(), at::dtype<T>());
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ class IDEEPLRNOp final : public IDEEPOperator {
|
|||
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
|
||||
beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
|
||||
bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
|
||||
DCHECK_GT(size_, 0);
|
||||
DCHECK_EQ(size_ % 2, 1);
|
||||
DCHECK_GT(alpha_, 0);
|
||||
DCHECK_GT(beta_, 0);
|
||||
TORCH_DCHECK_GT(size_, 0);
|
||||
TORCH_DCHECK_EQ(size_ % 2, 1);
|
||||
TORCH_DCHECK_GT(alpha_, 0);
|
||||
TORCH_DCHECK_GT(beta_, 0);
|
||||
}
|
||||
~IDEEPLRNOp() override = default;
|
||||
|
||||
|
|
@ -52,10 +52,10 @@ class IDEEPLRNGradientOp final : public IDEEPOperator {
|
|||
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
|
||||
beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
|
||||
bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
|
||||
DCHECK_GT(size_, 0);
|
||||
DCHECK_EQ(size_ % 2, 1);
|
||||
DCHECK_GT(alpha_, 0);
|
||||
DCHECK_GT(beta_, 0);
|
||||
TORCH_DCHECK_GT(size_, 0);
|
||||
TORCH_DCHECK_EQ(size_ % 2, 1);
|
||||
TORCH_DCHECK_GT(alpha_, 0);
|
||||
TORCH_DCHECK_GT(beta_, 0);
|
||||
}
|
||||
~IDEEPLRNGradientOp() override = default;
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
|||
parent_name += "_cpu_output_blob_" + base_def_.type();
|
||||
}
|
||||
local_output_blobs_.push_back(ws->CreateBlob(parent_name));
|
||||
CHECK_NOTNULL(local_output_blobs_.back());
|
||||
TORCH_CHECK_NOTNULL(local_output_blobs_.back());
|
||||
forwarded_output_blobs[base_def_.output(i)] = parent_name;
|
||||
output_inplace_.push_back(false);
|
||||
for (const string &input_name : base_def_.input()) {
|
||||
|
|
@ -74,7 +74,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
|||
// Set up the symbols for the local workspace.
|
||||
for (const string& name : base_def_.input()) {
|
||||
local_input_blobs_.push_back(local_ws_->CreateBlob(name));
|
||||
CHECK_NOTNULL(local_input_blobs_.back());
|
||||
TORCH_CHECK_NOTNULL(local_input_blobs_.back());
|
||||
}
|
||||
input_share_.resize(local_input_blobs_.size(), false);
|
||||
base_op_.reset(new CPUOp(base_def_, local_ws_.get()));
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator {
|
|||
auto data_type = zero_point_ == 0 ? idtype::u8 : idtype::s8;
|
||||
|
||||
output->init({shape_, data_type});
|
||||
DCHECK_EQ(output->get_nelems(), values_.numel())
|
||||
TORCH_DCHECK_EQ(output->get_nelems(), values_.numel())
|
||||
<< "output size: " << output->get_nelems()
|
||||
<< " given size: " << values_.numel();
|
||||
|
||||
|
|
@ -121,7 +121,7 @@ class IDEEPInt8GivenIntTensorFillOp final : public IDEEPOperator {
|
|||
auto* output = Output(OUTPUT);
|
||||
output->init({shape_, idtype::s32});
|
||||
output->set_scale(ConvertScales(scales_));
|
||||
DCHECK_EQ(output->get_nelems(), values_.numel())
|
||||
TORCH_DCHECK_EQ(output->get_nelems(), values_.numel())
|
||||
<< "output size: " << output->get_nelems()
|
||||
<< " given size: " << values_.numel();
|
||||
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ class IDEEPSpatialBNOp final : public IDEEPOperator {
|
|||
const auto& bias = Input(BIAS);
|
||||
auto* Y = Output(OUTPUT);
|
||||
|
||||
DCHECK_EQ(scale.ndims(), 1);
|
||||
DCHECK_EQ(bias.ndims(), 1);
|
||||
DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
|
||||
DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
|
||||
TORCH_DCHECK_EQ(scale.ndims(), 1);
|
||||
TORCH_DCHECK_EQ(bias.ndims(), 1);
|
||||
TORCH_DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
|
||||
TORCH_DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
|
||||
|
||||
if (is_test_) {
|
||||
const auto& est_mean = Input(EST_MEAN);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ OPERATOR_SCHEMA(ImageInput)
|
|||
int batch_size = helper.GetSingleArgument<int>("batch_size", 0);
|
||||
int crop = helper.GetSingleArgument<int>("crop", -1);
|
||||
int color = helper.GetSingleArgument<int>("color", 1);
|
||||
CHECK_GT(crop, 0);
|
||||
TORCH_CHECK_GT(crop, 0);
|
||||
out[0] = CreateTensorShape(
|
||||
vector<int>{batch_size, crop, crop, color ? 3 : 1},
|
||||
TensorProto::FLOAT);
|
||||
|
|
|
|||
|
|
@ -530,8 +530,8 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
|
|||
if (protos.protos_size() == end + 1) {
|
||||
// We have bounding box information
|
||||
const TensorProto& bounding_proto = protos.protos(end);
|
||||
DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
|
||||
DCHECK_EQ(bounding_proto.int32_data_size(), 4);
|
||||
TORCH_DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
|
||||
TORCH_DCHECK_EQ(bounding_proto.int32_data_size(), 4);
|
||||
info.bounding_params.valid = true;
|
||||
info.bounding_params.ymin = bounding_proto.int32_data(0);
|
||||
info.bounding_params.xmin = bounding_proto.int32_data(1);
|
||||
|
|
@ -541,7 +541,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
|
|||
|
||||
if (image_proto.data_type() == TensorProto::STRING) {
|
||||
// encoded image string.
|
||||
DCHECK_EQ(image_proto.string_data_size(), 1);
|
||||
TORCH_DCHECK_EQ(image_proto.string_data_size(), 1);
|
||||
const string& encoded_image_str = image_proto.string_data(0);
|
||||
int encoded_size = encoded_image_str.size();
|
||||
// We use a cv::Mat to wrap the encoded str so we do not need a copy.
|
||||
|
|
@ -582,7 +582,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
|
|||
// TODO: if image decoding was unsuccessful, set label to 0
|
||||
if (label_proto.data_type() == TensorProto::FLOAT) {
|
||||
if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
|
||||
DCHECK_EQ(label_proto.float_data_size(), 1);
|
||||
TORCH_DCHECK_EQ(label_proto.float_data_size(), 1);
|
||||
prefetched_label_.mutable_data<float>()[item_id] =
|
||||
label_proto.float_data(0);
|
||||
} else if (label_type_ == MULTI_LABEL_SPARSE) {
|
||||
|
|
@ -614,7 +614,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
|
|||
}
|
||||
} else if (label_proto.data_type() == TensorProto::INT32) {
|
||||
if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
|
||||
DCHECK_EQ(label_proto.int32_data_size(), 1);
|
||||
TORCH_DCHECK_EQ(label_proto.int32_data_size(), 1);
|
||||
prefetched_label_.mutable_data<int>()[item_id] =
|
||||
label_proto.int32_data(0);
|
||||
} else if (label_type_ == MULTI_LABEL_SPARSE) {
|
||||
|
|
|
|||
|
|
@ -284,7 +284,7 @@ constexpr int computeMPSAlignOffset(int kernel, int pad) {
|
|||
size_t ComputeStartIndex(
|
||||
const TensorCPU& tensor,
|
||||
const std::vector<int>& index) {
|
||||
DCHECK_EQ(index.size(), tensor.dim());
|
||||
TORCH_DCHECK_EQ(index.size(), tensor.dim());
|
||||
|
||||
size_t ret = 0;
|
||||
for (int i = 0; i < index.size(); i++) {
|
||||
|
|
@ -299,7 +299,7 @@ template <class T>
|
|||
utils::ConstTensorView<T> GetSubTensorView(
|
||||
const TensorCPU& tensor,
|
||||
int dim0_start_index) {
|
||||
DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));
|
||||
TORCH_DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));
|
||||
|
||||
if (tensor.size() == 0) {
|
||||
return utils::ConstTensorView<T>(nullptr, {});
|
||||
|
|
@ -1490,7 +1490,7 @@ class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
|
|||
caffe2::Timer consT;
|
||||
std::vector<float> refilter(kH * kW * output_channels * input_channels);
|
||||
refilter.assign(kH * kW * output_channels * input_channels, 0.0f);
|
||||
DCHECK_EQ(refilter.size(), filter.size());
|
||||
TORCH_DCHECK_EQ(refilter.size(), filter.size());
|
||||
auto* filter_ = filter.template data<float>();
|
||||
// For iOS11+ Reformat weights from WT[IC][OC][kH][kW] to
|
||||
// W[OC][kH][kW][IC]; For previous versions, reformat weights
|
||||
|
|
@ -1512,14 +1512,14 @@ class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
|
|||
kw * output_channels * input_channels +
|
||||
oc * input_channels + ic;
|
||||
}
|
||||
DCHECK_LT(inputIdx, filter.size());
|
||||
DCHECK_LT(outputIdx, filter.size());
|
||||
TORCH_DCHECK_LT(inputIdx, filter.size());
|
||||
TORCH_DCHECK_LT(outputIdx, filter.size());
|
||||
refilter[outputIdx] = filter_[inputIdx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
|
||||
TORCH_DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
|
||||
// initialize data structures
|
||||
if (runtimeAtLeastIOS11) {
|
||||
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
|
||||
|
|
@ -2225,7 +2225,7 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
|
|||
auto keep =
|
||||
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
|
||||
|
||||
DCHECK_LE(keep.size(), scores.size());
|
||||
TORCH_DCHECK_LE(keep.size(), scores.size());
|
||||
|
||||
// 4. sort all (proposal, score) pairs by score from highest to lowest
|
||||
// 5. take top pre_nms_topN (e.g. 6000)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ void testMPSCNN() {
|
|||
CAFFE_ENFORCE_EQ(t1.sizes(), t2.sizes());
|
||||
for (auto i = 0; i < t1.size(); ++i) {
|
||||
// FP16 <-> FP32 round trip.
|
||||
CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
|
||||
TORCH_CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -197,7 +197,7 @@ void testMPSCNN() {
|
|||
CAFFE_ENFORCE_EQ(t1.size(), t2.size());
|
||||
for (auto i = 0; i < t1.size(); ++i) {
|
||||
// FP16 <-> FP32 round trip.
|
||||
CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
|
||||
TORCH_CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -274,7 +274,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -467,7 +467,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -560,7 +560,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -651,7 +651,7 @@ void testMPSCNN() {
|
|||
const float t2_i = t2.data<float>()[i];
|
||||
// LOG(INFO) << "i: " << i << ", cpu: " << t1_i << ", mtl: " <<
|
||||
// t2_i;
|
||||
CHECK_NEAR(t1_i, t2_i, 0.7);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.7);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -763,7 +763,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -851,7 +851,7 @@ void testMPSCNN() {
|
|||
const float t2_i = t2.data<float>()[i];
|
||||
// LOG(INFO) << "i: " << i << ", " << "CPU: " << t1_i << ", MTL: " <<
|
||||
// t2_i;
|
||||
CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -932,7 +932,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -991,7 +991,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<uint8_t>()[i];
|
||||
const float t2_i = t2.data<uint8_t>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1050,7 +1050,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<uint8_t>()[i];
|
||||
const float t2_i = t2.data<uint8_t>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1166,7 +1166,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.2);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1264,7 +1264,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.3);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1378,7 +1378,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1481,7 +1481,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1589,7 +1589,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1713,7 +1713,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1784,7 +1784,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.02);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.02);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1849,7 +1849,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1914,7 +1914,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2003,7 +2003,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.05);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.05);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2057,7 +2057,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.02);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.02);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2123,7 +2123,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2237,7 +2237,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2349,7 +2349,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2428,7 +2428,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2570,7 +2570,7 @@ void testMPSCNN() {
|
|||
const float t3_i = t3.data<float>()[i / 5];
|
||||
if (t3_i - HALF_MIN_VAL * 2 > 0) {
|
||||
LOG(INFO) << i << " " << t1_i << " " << t2_i << " " << t3_i;
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2579,7 +2579,7 @@ void testMPSCNN() {
|
|||
const float t3_i = t3.data<float>()[i];
|
||||
const float t4_i = t4.data<float>()[i];
|
||||
LOG(INFO) << i << " " << t3_i;
|
||||
CHECK_NEAR(t3_i, t4_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t3_i, t4_i, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2634,7 +2634,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2875,7 +2875,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2943,7 +2943,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3030,7 +3030,7 @@ void testMPSCNN() {
|
|||
// FP16 <-> FP32 round trip, accumulation, etc.
|
||||
const float t1_i = t1.data<float>()[i];
|
||||
const float t2_i = t2.data<float>()[i];
|
||||
CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3072,10 +3072,10 @@ void testMPSCNN() {
|
|||
}
|
||||
return arg->i();
|
||||
};
|
||||
CHECK_EQ(rc(0), 1);
|
||||
CHECK_EQ(rc(1), 2);
|
||||
CHECK_EQ(rc(2), 1);
|
||||
CHECK_EQ(rc(3), 1);
|
||||
TORCH_CHECK_EQ(rc(0), 1);
|
||||
TORCH_CHECK_EQ(rc(1), 2);
|
||||
TORCH_CHECK_EQ(rc(2), 1);
|
||||
TORCH_CHECK_EQ(rc(3), 1);
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3117,18 +3117,18 @@ void testMPSCNN() {
|
|||
auto ty = [&](size_t i) { return netdef.op(i).type(); };
|
||||
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
|
||||
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
|
||||
CHECK_EQ(netdef.op_size(), 4);
|
||||
CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu"));
|
||||
CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu"));
|
||||
CHECK_EQ(ty(3), "CopyFromMPSCNN");
|
||||
CHECK_EQ(i0(0), "X");
|
||||
CHECK_EQ(i0(1), o0(0));
|
||||
CHECK_EQ(i0(2), "X2");
|
||||
CHECK_EQ(o0(2), i0(3));
|
||||
CHECK_EQ(o0(3), "Y");
|
||||
CHECK_EQ(netdef.external_input(0), "X");
|
||||
CHECK_EQ(netdef.external_output(0), "Y");
|
||||
TORCH_CHECK_EQ(netdef.op_size(), 4);
|
||||
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
TORCH_CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu"));
|
||||
TORCH_CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu"));
|
||||
TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN");
|
||||
TORCH_CHECK_EQ(i0(0), "X");
|
||||
TORCH_CHECK_EQ(i0(1), o0(0));
|
||||
TORCH_CHECK_EQ(i0(2), "X2");
|
||||
TORCH_CHECK_EQ(o0(2), i0(3));
|
||||
TORCH_CHECK_EQ(o0(3), "Y");
|
||||
TORCH_CHECK_EQ(netdef.external_input(0), "X");
|
||||
TORCH_CHECK_EQ(netdef.external_output(0), "Y");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3195,18 +3195,18 @@ void testMPSCNN() {
|
|||
op.add_output("Z");
|
||||
}
|
||||
netdef = rewriteForMetal(netdef);
|
||||
CHECK_EQ(netdef.op_size(), 4);
|
||||
TORCH_CHECK_EQ(netdef.op_size(), 4);
|
||||
auto ty = [&](size_t i) { return netdef.op(i).type(); };
|
||||
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
|
||||
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
|
||||
CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
CHECK_EQ(ty(1), "MPSCNNConvRelu");
|
||||
CHECK_EQ(ty(2), "MPSCNNRelu");
|
||||
CHECK_EQ(ty(3), "CopyFromMPSCNN");
|
||||
CHECK_EQ(i0(1), o0(0));
|
||||
CHECK_EQ(o0(1), "Z");
|
||||
CHECK_EQ(i0(2), "Z");
|
||||
CHECK_EQ(o0(2), i0(3));
|
||||
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
TORCH_CHECK_EQ(ty(1), "MPSCNNConvRelu");
|
||||
TORCH_CHECK_EQ(ty(2), "MPSCNNRelu");
|
||||
TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN");
|
||||
TORCH_CHECK_EQ(i0(1), o0(0));
|
||||
TORCH_CHECK_EQ(o0(1), "Z");
|
||||
TORCH_CHECK_EQ(i0(2), "Z");
|
||||
TORCH_CHECK_EQ(o0(2), i0(3));
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3235,21 +3235,21 @@ void testMPSCNN() {
|
|||
op.add_output("Z");
|
||||
}
|
||||
netdef = rewriteForMetal(netdef);
|
||||
CHECK_EQ(netdef.op_size(), 5);
|
||||
TORCH_CHECK_EQ(netdef.op_size(), 5);
|
||||
auto ty = [&](size_t i) { return netdef.op(i).type(); };
|
||||
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
|
||||
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
|
||||
CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
CHECK_EQ(ty(1), "MPSCNNConv");
|
||||
CHECK_EQ(ty(2), "MPSCNNRelu");
|
||||
CHECK_EQ(ty(3), "MPSCNNRelu");
|
||||
CHECK_EQ(ty(4), "CopyFromMPSCNN");
|
||||
CHECK_EQ(i0(1), o0(0));
|
||||
CHECK_EQ(o0(1), "Y");
|
||||
CHECK_EQ(i0(2), o0(1));
|
||||
CHECK_EQ(o0(2), "Z");
|
||||
CHECK_EQ(i0(3), o0(1));
|
||||
CHECK_EQ(o0(3), i0(4));
|
||||
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
|
||||
TORCH_CHECK_EQ(ty(1), "MPSCNNConv");
|
||||
TORCH_CHECK_EQ(ty(2), "MPSCNNRelu");
|
||||
TORCH_CHECK_EQ(ty(3), "MPSCNNRelu");
|
||||
TORCH_CHECK_EQ(ty(4), "CopyFromMPSCNN");
|
||||
TORCH_CHECK_EQ(i0(1), o0(0));
|
||||
TORCH_CHECK_EQ(o0(1), "Y");
|
||||
TORCH_CHECK_EQ(i0(2), o0(1));
|
||||
TORCH_CHECK_EQ(o0(2), "Z");
|
||||
TORCH_CHECK_EQ(i0(3), o0(1));
|
||||
TORCH_CHECK_EQ(o0(3), i0(4));
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -3277,14 +3277,14 @@ void testMPSCNN() {
|
|||
auto ty = [&](size_t i) { return netdef.op(i).type(); };
|
||||
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
|
||||
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
|
||||
CHECK_EQ(netdef.op_size(), 3);
|
||||
CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess");
|
||||
CHECK_EQ(ty(1), "MPSCNNRelu");
|
||||
CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess");
|
||||
CHECK_EQ(i0(0), "X");
|
||||
CHECK_EQ(i0(1), o0(0));
|
||||
CHECK_EQ(i0(2), o0(1));
|
||||
CHECK_EQ(o0(2), "Z");
|
||||
TORCH_CHECK_EQ(netdef.op_size(), 3);
|
||||
TORCH_CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess");
|
||||
TORCH_CHECK_EQ(ty(1), "MPSCNNRelu");
|
||||
TORCH_CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess");
|
||||
TORCH_CHECK_EQ(i0(0), "X");
|
||||
TORCH_CHECK_EQ(i0(1), o0(0));
|
||||
TORCH_CHECK_EQ(i0(2), o0(1));
|
||||
TORCH_CHECK_EQ(o0(2), "Z");
|
||||
}
|
||||
LOG(INFO) << "All MPSCNN tests passed.";
|
||||
}
|
||||
|
|
@ -3296,12 +3296,12 @@ NetDef truncateAfter(NetDef def, size_t idx) {
|
|||
for (auto i = 0; i < toRemove; ++i) {
|
||||
def.mutable_op()->RemoveLast();
|
||||
}
|
||||
CHECK_EQ(def.op_size(), idx + 1);
|
||||
TORCH_CHECK_EQ(def.op_size(), idx + 1);
|
||||
return def;
|
||||
}
|
||||
|
||||
NetDef addMPSCNNCopyFinalizer(NetDef def) {
|
||||
CHECK_GE(def.op_size(), 1);
|
||||
TORCH_CHECK_GE(def.op_size(), 1);
|
||||
const auto name = def.mutable_op(def.op_size() - 1)->output(0);
|
||||
def.mutable_op(def.op_size() - 1)->set_output(0, "METAL_COPIER");
|
||||
{
|
||||
|
|
@ -3315,7 +3315,7 @@ NetDef addMPSCNNCopyFinalizer(NetDef def) {
|
|||
|
||||
void compareModels(const NetDef& initNet, NetDef predictNet) {
|
||||
auto* arg = predictNet.mutable_op(0)->mutable_arg(0);
|
||||
CHECK_EQ(arg->name(), "noise_std");
|
||||
TORCH_CHECK_EQ(arg->name(), "noise_std");
|
||||
arg->set_f(0.000001);
|
||||
|
||||
NetDef metalPredictNet;
|
||||
|
|
@ -3365,7 +3365,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
|
|||
{
|
||||
const auto& mt = mws.GetBlob(name)->Get<TensorCPU>();
|
||||
const auto& ct = cws.GetBlob(name)->Get<TensorCPU>();
|
||||
CHECK_EQ(mt.sizes(), ct.sizes());
|
||||
TORCH_CHECK_EQ(mt.sizes(), ct.sizes());
|
||||
for (auto j = 0; j < mt.size(); ++j) {
|
||||
if (mt.IsType<float>()) {
|
||||
if (j < 10) {
|
||||
|
|
@ -3373,7 +3373,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
|
|||
<< ", CPU: " << ct.data<float>()[j]
|
||||
<< ", MTL: " << mt.data<float>()[j];
|
||||
}
|
||||
CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
|
||||
TORCH_CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
|
||||
} else {
|
||||
CHECK(mt.IsType<uint8_t>());
|
||||
if (j < 10) {
|
||||
|
|
@ -3381,7 +3381,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
|
|||
<< ", CPU: " << ct.data<uint8_t>()[j]
|
||||
<< ", MTL: " << mt.data<uint8_t>()[j];
|
||||
}
|
||||
CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
|
||||
TORCH_CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3428,7 +3428,7 @@ void verifyRewrite(
|
|||
LOG(INFO) << "One of the operator failed.";
|
||||
return;
|
||||
}
|
||||
// CHECK_EQ(mt.sizes(), ct.sizes());
|
||||
// TORCH_CHECK_EQ(mt.sizes(), ct.sizes());
|
||||
for (auto j = 0; j < fmin(mt.size(), ct.size()); ++j) {
|
||||
if (mt.IsType<float>()) {
|
||||
if (j < 10) {
|
||||
|
|
@ -3437,7 +3437,7 @@ void verifyRewrite(
|
|||
<< ", MTL: " << mt.data<float>()[j];
|
||||
}
|
||||
// Disabling check for now because of precision issues
|
||||
// CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
|
||||
// TORCH_CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
|
||||
} else {
|
||||
LOG(INFO) << "Type uint8_t";
|
||||
CHECK(mt.IsType<uint8_t>());
|
||||
|
|
@ -3447,7 +3447,7 @@ void verifyRewrite(
|
|||
<< ", MTL: " << mt.data<uint8_t>()[j];
|
||||
}
|
||||
// Disabling check for now.
|
||||
// CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
|
||||
// TORCH_CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ void filterNormalization11(const TensorCPU& WQ, TensorCPU* WQN) {
|
|||
for (auto j = 0; j < WQs; ++j) {
|
||||
bitSum += __builtin_popcount(WQdata[f * WQs + j]);
|
||||
}
|
||||
DCHECK_LE(bitSum, WQbits);
|
||||
TORCH_DCHECK_LE(bitSum, WQbits);
|
||||
WQNdata[f] = 2 * bitSum - WQbits;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ inline void quantize2bNeon(size_t QC,
|
|||
float offset,
|
||||
float inter_center_distance,
|
||||
std::array<uint8_t*, k2b1bXBits> XQdata) {
|
||||
DCHECK_EQ(QC % 8, 0);
|
||||
TORCH_DCHECK_EQ(QC % 8, 0);
|
||||
const auto offset_plus_2_inter_center_distance = vdupq_n_f32(offset + 2 * inter_center_distance);
|
||||
const auto offset_plus_inter_center_distance = vdupq_n_f32(offset + inter_center_distance);
|
||||
const auto offset_ = vdupq_n_f32(offset);
|
||||
|
|
@ -291,7 +291,7 @@ void qgess_packed(const uint8_t* __restrict__ Ablock,
|
|||
F&& f) {
|
||||
static_assert(kUnrollN % 8 == 0, "");
|
||||
static_assert(TileDepthBytes == 16, "");
|
||||
DCHECK_EQ(QK % 16, 0);
|
||||
TORCH_DCHECK_EQ(QK % 16, 0);
|
||||
uint16x8_t acc[kUnrollM][kUnrollN / 8];
|
||||
for (size_t mm = 0; mm < kUnrollM; ++mm) {
|
||||
for (size_t nn = 0; nn < kUnrollN / 8; ++nn) {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ void conv(const ConvArgs& args,
|
|||
(X.dim32(1) - KH + args.pad_t + args.pad_b) / args.stride_h + 1,
|
||||
(X.dim32(2) - KW + args.pad_l + args.pad_r) / args.stride_w + 1,
|
||||
W.dim32(0));
|
||||
CHECK_EQ(W.dim32(3), X.dim32(3));
|
||||
TORCH_CHECK_EQ(W.dim32(3), X.dim32(3));
|
||||
const auto OH = Y->dim32(1);
|
||||
const auto OW = Y->dim32(2);
|
||||
const auto OC = Y->dim32(3);
|
||||
|
|
@ -155,7 +155,7 @@ inline void gemmNT(int M, int N, int K, const float* A, const float* B, float* C
|
|||
}
|
||||
|
||||
inline void qgemmNT(int M, int N, int K, const uint8_t* A, const uint8_t* B, float* C) {
|
||||
CHECK_EQ(K % 8, 0);
|
||||
TORCH_CHECK_EQ(K % 8, 0);
|
||||
const int QK = K / 8;
|
||||
for (auto m = 0; m < M; ++m) {
|
||||
for (auto n = 0; n < N; ++n) {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ void APMeterOp<float, CPUContext>::BufferPredictions(
|
|||
// Initialize the buffer
|
||||
buffers_.resize(D, std::vector<BufferDataType>(buffer_size_));
|
||||
}
|
||||
DCHECK_EQ(buffers_.size(), D);
|
||||
TORCH_DCHECK_EQ(buffers_.size(), D);
|
||||
|
||||
// Fill atmose buffer_size_ data at a time, so truncate the input if needed
|
||||
if (N > buffer_size_) {
|
||||
|
|
@ -48,12 +48,12 @@ bool APMeterOp<float, CPUContext>::RunOnDevice() {
|
|||
auto& label = Input(LABEL);
|
||||
|
||||
// Check dimensions
|
||||
DCHECK_EQ(X.dim(), 2);
|
||||
TORCH_DCHECK_EQ(X.dim(), 2);
|
||||
int N = X.dim32(0);
|
||||
int D = X.dim32(1);
|
||||
DCHECK_EQ(label.dim(), 2);
|
||||
DCHECK_EQ(label.dim32(0), N);
|
||||
DCHECK_EQ(label.dim32(1), D);
|
||||
TORCH_DCHECK_EQ(label.dim(), 2);
|
||||
TORCH_DCHECK_EQ(label.dim32(0), N);
|
||||
TORCH_DCHECK_EQ(label.dim32(1), D);
|
||||
auto* Y = Output(0, {D}, at::dtype<float>());
|
||||
|
||||
const auto* Xdata = X.data<float>();
|
||||
|
|
|
|||
|
|
@ -124,8 +124,8 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
|
|||
if (K > 1) {
|
||||
TileArrayIntoVector(lambda1_ptr, D, K, &b.lambda1_);
|
||||
TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_);
|
||||
DCHECK_EQ(K * D, b.lambda1_.size());
|
||||
DCHECK_EQ(K * D, b.lambda2_.size());
|
||||
TORCH_DCHECK_EQ(K * D, b.lambda1_.size());
|
||||
TORCH_DCHECK_EQ(K * D, b.lambda2_.size());
|
||||
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
|
||||
BoxCoxNonzeroLambda(
|
||||
K * D,
|
||||
|
|
@ -144,7 +144,7 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
|
|||
int64_t i = 0;
|
||||
if (K > 1) {
|
||||
TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_z_);
|
||||
DCHECK_EQ(K * D, b.lambda2_z_.size());
|
||||
TORCH_DCHECK_EQ(K * D, b.lambda2_z_.size());
|
||||
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
|
||||
BoxCoxZeroLambda(
|
||||
K * D, data_ptr, b.lambda2_z_.data(), k_eps, output_ptr);
|
||||
|
|
@ -176,9 +176,9 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
|
|||
zeros_.resize(D - n);
|
||||
TileIndicesInPlace(&nonzeros_, D, K);
|
||||
TileIndicesInPlace(&zeros_, D, K);
|
||||
DCHECK_EQ(nonzeros_.size(), b.lambda1_.size());
|
||||
DCHECK_EQ(nonzeros_.size(), b.lambda2_.size());
|
||||
DCHECK_EQ(zeros_.size(), b.lambda2_z_.size());
|
||||
TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda1_.size());
|
||||
TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda2_.size());
|
||||
TORCH_DCHECK_EQ(zeros_.size(), b.lambda2_z_.size());
|
||||
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
|
||||
BoxCoxMixedLambda(
|
||||
data_ptr,
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ bool BBoxTransformOp<float, CPUContext>::RunOnDevice() {
|
|||
CAFFE_ENFORCE_EQ(iminfo_in.dim32(1), 3);
|
||||
const int batch_size = iminfo_in.dim32(0);
|
||||
|
||||
DCHECK_EQ(weights_.size(), 4);
|
||||
TORCH_DCHECK_EQ(weights_.size(), 4);
|
||||
|
||||
Eigen::Map<const ERArrXXf> boxes0(
|
||||
roi_in.data<float>(), roi_in.dim32(0), roi_in.dim32(1));
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ TEST(BooleanUnmaskTest, Test) {
|
|||
auto& unmasked_data = unmasked_data_blob->Get<TensorCPU>();
|
||||
EXPECT_EQ(unmasked_data.numel(), 1);
|
||||
|
||||
CHECK_EQ(unmasked_data.data<float>()[0], 1.0f);
|
||||
TORCH_CHECK_EQ(unmasked_data.data<float>()[0], 1.0f);
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
|
|
@ -161,14 +161,14 @@ const auto& tscores = Input(0);
|
|||
|
||||
// Pick the first `detections_per_im_` boxes with highest scores
|
||||
auto all_scores_sorted = get_all_scores_sorted();
|
||||
DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
|
||||
TORCH_DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
|
||||
|
||||
// Reconstruct keeps from `all_scores_sorted`
|
||||
for (auto& cur_keep : keeps) {
|
||||
cur_keep.clear();
|
||||
}
|
||||
for (int i = 0; i < detections_per_im_; i++) {
|
||||
DCHECK_GT(all_scores_sorted.size(), i);
|
||||
TORCH_DCHECK_GT(all_scores_sorted.size(), i);
|
||||
auto& cur = all_scores_sorted[i];
|
||||
keeps[cur.first].push_back(cur.second);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -272,8 +272,8 @@ bool MakeTwoClassOp<float, CPUContext>::RunOnDevice() {
|
|||
const auto* Xdata = X.data<float>();
|
||||
auto* Ydata = Y->template mutable_data<float>();
|
||||
for (int64_t i = 0; i < N; ++i) {
|
||||
DCHECK_GE(Xdata[i], 0.0);
|
||||
DCHECK_LE(Xdata[i], 1.0);
|
||||
TORCH_DCHECK_GE(Xdata[i], 0.0);
|
||||
TORCH_DCHECK_LE(Xdata[i], 1.0);
|
||||
Ydata[i * 2] = 1.0 - Xdata[i];
|
||||
Ydata[i * 2 + 1] = Xdata[i];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -308,7 +308,7 @@ void DeformConvOpBase<DType, Context>::DeformableIm2col(
|
|||
at::IntArrayRef im_shape,
|
||||
at::IntArrayRef col_shape,
|
||||
DType* data_col) {
|
||||
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
CAFFE_ENFORCE_EQ(pad_t(), pad_b());
|
||||
CAFFE_ENFORCE_EQ(pad_l(), pad_r());
|
||||
const int pad_h = pad_t();
|
||||
|
|
@ -444,7 +444,7 @@ void DeformConvOpBase<DType, Context>::DeformableCol2im(
|
|||
index_t channel_per_deformable_group = im_shape[1] / deformable_group_;
|
||||
index_t num_kernels = size_from_dim_(0, col_shape);
|
||||
// num_axes should be smaller than block size
|
||||
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
// To avoid involving atomic operations, we will launch one kernel per
|
||||
// bottom dimension, and then in the kernel add up the top dimensions.
|
||||
// NOLINT_NEXT_LINE(whitespace/operators)
|
||||
|
|
@ -592,7 +592,7 @@ void DeformConvOpBase<DType, Context>::DeformableCol2imCoord(
|
|||
kernel_w() * deformable_group_;
|
||||
index_t channel_per_deformable_group = col_shape[0] / deformable_group_;
|
||||
// num_axes should be smaller than block size
|
||||
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
|
||||
// To avoid involving atomic operations, we will launch one kernel per
|
||||
// bottom dimension, and then in the kernel add up the top dimensions.
|
||||
// NOLINT_NEXT_LINE(whitespace/operators)
|
||||
|
|
|
|||
|
|
@ -444,7 +444,7 @@ class GaussianFillOp final : public FillerOp<Context> {
|
|||
: FillerOp<Context>(std::forward<Args>(args)...),
|
||||
mean_(this->template GetSingleArgument<float>("mean", 0)),
|
||||
std_(this->template GetSingleArgument<float>("std", 1)) {
|
||||
DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
|
||||
TORCH_DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
|
||||
}
|
||||
|
||||
bool Fill(Tensor* output) override {
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class FullyConnectedOp final : public Operator<Context> {
|
|||
|
||||
Y_shape_cache_ = X.sizes().vec();
|
||||
// This is an invariant of canonical_axis, so we can DCHECK.
|
||||
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
|
||||
TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
|
||||
Y_shape_cache_.resize(canonical_axis + 1);
|
||||
Y_shape_cache_[canonical_axis] = N;
|
||||
auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ namespace {
|
|||
size_t ComputeStartIndex(
|
||||
const TensorCPU& tensor,
|
||||
const std::vector<int>& index) {
|
||||
DCHECK_EQ(index.size(), tensor.dim());
|
||||
TORCH_DCHECK_EQ(index.size(), tensor.dim());
|
||||
|
||||
size_t ret = 0;
|
||||
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
|
||||
|
|
@ -27,7 +27,7 @@ template <class T>
|
|||
utils::ConstTensorView<T> GetSubTensorView(
|
||||
const TensorCPU& tensor,
|
||||
int dim0_start_index) {
|
||||
DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T));
|
||||
TORCH_DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T));
|
||||
|
||||
if (tensor.numel() == 0) {
|
||||
return utils::ConstTensorView<T>(nullptr, {});
|
||||
|
|
@ -244,7 +244,7 @@ void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
|
|||
// 3. remove predicted boxes with either height or width < min_size
|
||||
auto keep =
|
||||
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
|
||||
DCHECK_LE(keep.size(), scores_sorted.size());
|
||||
TORCH_DCHECK_LE(keep.size(), scores_sorted.size());
|
||||
|
||||
// 6. apply loose nms (e.g. threshold = 0.7)
|
||||
// 7. take after_nms_topN (e.g. 300)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class ConstTensorView {
|
|||
return dims_;
|
||||
}
|
||||
int dim(int i) const {
|
||||
DCHECK_LE(i, dims_.size());
|
||||
TORCH_DCHECK_LE(i, dims_.size());
|
||||
return dims_[i];
|
||||
}
|
||||
const T* data() const {
|
||||
|
|
|
|||
|
|
@ -316,7 +316,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) {
|
|||
|
||||
// Add angle in bbox deltas
|
||||
int num_boxes = scores.size();
|
||||
CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
vector<float> bbx_with_angle(num_boxes * box_dim);
|
||||
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
|
||||
// at each spatial location for each anchor.
|
||||
|
|
@ -516,7 +516,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) {
|
|||
|
||||
// Add angle in bbox deltas
|
||||
int num_boxes = scores.size();
|
||||
CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
vector<float> bbx_with_angle(num_boxes * box_dim);
|
||||
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
|
||||
// at each spatial location for each anchor.
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
|
|||
|
||||
// Add angle in bbox deltas
|
||||
auto num_boxes = scores.size();
|
||||
CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
vector<float> bbx_with_angle(num_boxes * box_dim);
|
||||
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
|
||||
// at each spatial location for each anchor.
|
||||
|
|
@ -667,7 +667,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
|
|||
|
||||
// Add angle in bbox deltas
|
||||
auto num_boxes = scores.size();
|
||||
CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
|
||||
vector<float> bbx_with_angle(num_boxes * box_dim);
|
||||
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
|
||||
// at each spatial location for each anchor.
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
|
|||
}
|
||||
|
||||
bool Fill(Tensor* output) override {
|
||||
DCHECK_EQ(output->numel(), values_.numel())
|
||||
TORCH_DCHECK_EQ(output->numel(), values_.numel())
|
||||
<< "output size: " << output->numel()
|
||||
<< " given size: " << values_.numel();
|
||||
auto* data = output->template mutable_data<uint8_t>();
|
||||
|
|
@ -51,7 +51,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
|
|||
private:
|
||||
void Extract() {
|
||||
auto source_values = this->template GetRepeatedArgument<string>("values");
|
||||
DCHECK_EQ(source_values.size(), 1)
|
||||
TORCH_DCHECK_EQ(source_values.size(), 1)
|
||||
<< "expected size: 1 "
|
||||
<< " given size: " << source_values.size();
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ bool LRNOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
|
|||
// Note(Yangqing): this one is copied from my Caffe implementation.
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int C = X.dim32(1);
|
||||
const int H = X.dim32(2);
|
||||
|
|
@ -81,7 +81,7 @@ bool LRNOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
|
|||
// variants have I written...?
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int H = X.dim32(1);
|
||||
const int W = X.dim32(2);
|
||||
|
|
@ -135,7 +135,7 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
|
|||
auto& Y = Input(1);
|
||||
auto& dY = Input(2);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int C = X.dim32(1);
|
||||
const int H = X.dim32(2);
|
||||
|
|
@ -143,8 +143,8 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
|
|||
const int image_size = C * H * W;
|
||||
// Loosely checking the size, assuming that the shapes will be the same as
|
||||
// long as the sizes check out.
|
||||
DCHECK_EQ(X.numel(), Y.numel());
|
||||
DCHECK_EQ(X.numel(), dY.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), dY.numel());
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
|
||||
const float* Xdata = X.data<float>();
|
||||
|
|
@ -248,7 +248,7 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
|
|||
auto& Y = Input(1);
|
||||
auto& dY = Input(2);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int H = X.dim32(1);
|
||||
const int W = X.dim32(2);
|
||||
|
|
@ -257,8 +257,8 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
|
|||
const float* Xdata = X.data<float>();
|
||||
// Loosely checking the size, assuming that the shapes will be the same as
|
||||
// long as the sizes check out.
|
||||
DCHECK_EQ(X.numel(), Y.numel());
|
||||
DCHECK_EQ(X.numel(), dY.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), dY.numel());
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
if (!scale_) {
|
||||
scale_ = &local_scale_tensor_;
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ template<>
|
|||
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int C = X.dim32(1);
|
||||
const int H = X.dim32(2);
|
||||
|
|
@ -214,7 +214,7 @@ template<>
|
|||
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int H = X.dim32(1);
|
||||
const int W = X.dim32(2);
|
||||
|
|
@ -252,15 +252,15 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
|
|||
auto& Y = Input(1);
|
||||
auto& dY = Input(2);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int C = X.dim32(1);
|
||||
const int H = X.dim32(2);
|
||||
const int W = X.dim32(3);
|
||||
// Loosely checking the size, assuming that the shapes will be the same as
|
||||
// long as the sizes check out.
|
||||
DCHECK_EQ(X.numel(), Y.numel());
|
||||
DCHECK_EQ(X.numel(), dY.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), dY.numel());
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
|
||||
const float* Xdata = X.data<float>();
|
||||
|
|
@ -295,7 +295,7 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
|
|||
auto& Y = Input(1);
|
||||
auto& dY = Input(2);
|
||||
|
||||
DCHECK_EQ(X.dim(), 4);
|
||||
TORCH_DCHECK_EQ(X.dim(), 4);
|
||||
const int N = X.dim32(0);
|
||||
const int H = X.dim32(1);
|
||||
const int W = X.dim32(2);
|
||||
|
|
@ -303,8 +303,8 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
|
|||
const float* Xdata = X.data<float>();
|
||||
// Loosely checking the size, assuming that the shapes will be the same as
|
||||
// long as the sizes check out.
|
||||
DCHECK_EQ(X.numel(), Y.numel());
|
||||
DCHECK_EQ(X.numel(), dY.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(X.numel(), dY.numel());
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
if (!scale_) {
|
||||
scale_ = &local_scale_tensor_;
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ class LRNOpBase : public Operator<Context> {
|
|||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
pre_pad_((size_ - 1) / 2) {
|
||||
DCHECK_GT(size_, 0);
|
||||
DCHECK_EQ(size_ % 2, 1);
|
||||
DCHECK_GT(alpha_, 0);
|
||||
DCHECK_GT(beta_, 0);
|
||||
TORCH_DCHECK_GT(size_, 0);
|
||||
TORCH_DCHECK_EQ(size_ % 2, 1);
|
||||
TORCH_DCHECK_GT(alpha_, 0);
|
||||
TORCH_DCHECK_GT(beta_, 0);
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ class GetGPUMemoryUsageOp final : public Operator<CUDAContext> {
|
|||
~GetGPUMemoryUsageOp() override {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
CHECK_EQ(InputSize(), 0);
|
||||
CHECK_EQ(OutputSize(), 1);
|
||||
TORCH_CHECK_EQ(InputSize(), 0);
|
||||
TORCH_CHECK_EQ(OutputSize(), 1);
|
||||
std::vector<long> total_by_gpu = CUDAContext::TotalMemoryByGpu();
|
||||
std::vector<long> max_by_gpu = CUDAContext::MaxMemoryByGpu();
|
||||
CHECK_EQ(total_by_gpu.size(), max_by_gpu.size());
|
||||
TORCH_CHECK_EQ(total_by_gpu.size(), max_by_gpu.size());
|
||||
|
||||
|
||||
auto* stats = Output(0, {2, static_cast<int64_t>(total_by_gpu.size())}, at::dtype<long>());
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
|
|||
auto& X = Input(PREDICTION);
|
||||
auto& label = Input(LABEL);
|
||||
|
||||
DCHECK_EQ(X.dim(), 2);
|
||||
TORCH_DCHECK_EQ(X.dim(), 2);
|
||||
// amount, number of instances
|
||||
int N = X.dim32(0);
|
||||
// dimension, number of classes
|
||||
int D = X.dim32(1);
|
||||
DCHECK_EQ(label.dim(), 1);
|
||||
DCHECK_EQ(label.dim32(0), N);
|
||||
TORCH_DCHECK_EQ(label.dim(), 1);
|
||||
TORCH_DCHECK_EQ(label.dim32(0), N);
|
||||
auto* Y0 = Output(0, {D}, at::dtype<float>());
|
||||
auto* Y1 = Output(1, {D}, at::dtype<int>());
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
|
|||
}
|
||||
}
|
||||
int labelid = labeldata[i];
|
||||
DCHECK_LT(labelid, D);
|
||||
TORCH_DCHECK_LT(labelid, D);
|
||||
if (maxid == labelid) {
|
||||
accuracies[labelid]++;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,13 +40,13 @@ bool MultiClassAccuracyOp<float, CUDAContext>::RunOnDevice() {
|
|||
auto& label = Input(LABEL);
|
||||
|
||||
|
||||
DCHECK_EQ(X.dim(), 2);
|
||||
TORCH_DCHECK_EQ(X.dim(), 2);
|
||||
// amount, number of instances
|
||||
int N = X.dim32(0);
|
||||
// dimension, number of classes
|
||||
int D = X.dim32(1);
|
||||
DCHECK_EQ(label.dim(), 1);
|
||||
DCHECK_EQ(label.dim32(0), N);
|
||||
TORCH_DCHECK_EQ(label.dim(), 1);
|
||||
TORCH_DCHECK_EQ(label.dim32(0), N);
|
||||
auto* Y0 = Output(0, {D}, at::dtype<float>());
|
||||
auto* Y1 = Output(1, {D}, at::dtype<int>());
|
||||
|
||||
|
|
|
|||
|
|
@ -52,12 +52,12 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
|
|||
// Set up the symbols for the local workspace.
|
||||
for (const string& name : def.input()) {
|
||||
local_input_blobs_.push_back(local_ws_.CreateBlob(name));
|
||||
CHECK_NOTNULL(local_input_blobs_.back());
|
||||
TORCH_CHECK_NOTNULL(local_input_blobs_.back());
|
||||
}
|
||||
base_op_ = CreateOperator(base_def_, &local_ws_);
|
||||
for (const string& name : def.output()) {
|
||||
local_output_blobs_.push_back(local_ws_.GetBlob(name));
|
||||
CHECK_NOTNULL(local_output_blobs_.back());
|
||||
TORCH_CHECK_NOTNULL(local_output_blobs_.back());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ template <>
|
|||
bool PerplexityOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
TORCH_DCHECK_EQ(X.dim(), 1);
|
||||
int N = X.dim32(0);
|
||||
|
||||
auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ template <>
|
|||
bool PerplexityOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_EQ(X.dim(), 1);
|
||||
TORCH_DCHECK_EQ(X.dim(), 1);
|
||||
int N = X.dim32(0);
|
||||
|
||||
auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ bool PReluGradientOp<float, CPUContext>::RunOnDevice() {
|
|||
|
||||
CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU");
|
||||
|
||||
DCHECK_EQ(dY.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
|
||||
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
|
||||
auto* dW = Output(1, W.sizes(), at::dtype<float>());
|
||||
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ bool PReluGradientOp<float, CUDAContext>::RunOnDevice() {
|
|||
|
||||
CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU");
|
||||
|
||||
DCHECK_EQ(dY.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
|
||||
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
|
||||
auto* dW = Output(1, W.sizes(), at::dtype<float>());
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ void Decode(
|
|||
|
||||
int sz = output->numel();
|
||||
for (C10_UNUSED const auto i : c10::irange(sz)) {
|
||||
DCHECK_LE(*code_ptr, cb_size);
|
||||
TORCH_DCHECK_LE(*code_ptr, cb_size);
|
||||
*out_ptr++ = cb_ptr[*code_ptr++];
|
||||
}
|
||||
} else {
|
||||
|
|
@ -49,7 +49,7 @@ void Decode(
|
|||
CAFFE_ENFORCE_EQ(cb_size, output->numel());
|
||||
auto* out_ptr = output->template mutable_data<CodebookT>();
|
||||
while (gradient_ptr < gradient_end) {
|
||||
DCHECK_LE(*code_ptr, cb_size);
|
||||
TORCH_DCHECK_LE(*code_ptr, cb_size);
|
||||
out_ptr[*code_ptr++] += *gradient_ptr++;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class Int8AveragePoolOp final : public ConvPoolOpBase<CPUContext> {
|
|||
Y->scale = Y_scale;
|
||||
Y->zero_point = Y_zero_point;
|
||||
|
||||
CHECK_EQ(X.t.dim(), 4);
|
||||
TORCH_CHECK_EQ(X.t.dim(), 4);
|
||||
const int channels = X.t.dim32(3);
|
||||
ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
|
||||
|
||||
|
|
|
|||
|
|
@ -42,10 +42,10 @@ class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
|
|||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale =
|
||||
this->template GetSingleArgument<float>("Y_scale", 1.0f);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
TORCH_CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
|
||||
const auto C = X.t.dim32(3);
|
||||
const auto G = this->group_;
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ class Int8ConcatOp final : public Operator<CPUContext> {
|
|||
if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
|
||||
// Default to C axis
|
||||
axis_ = this->template GetSingleArgument<int>("axis", 3);
|
||||
CHECK_GE(axis_, 0);
|
||||
CHECK_LT(axis_, 4);
|
||||
TORCH_CHECK_GE(axis_, 0);
|
||||
TORCH_CHECK_LT(axis_, 4);
|
||||
} else if (
|
||||
this->template GetSingleArgument<string>("order", "") == "NCHW") {
|
||||
axis_ = this->template GetSingleArgument<int>("axis", 1);
|
||||
CHECK_GE(axis_, 0);
|
||||
CHECK_LT(axis_, 4);
|
||||
TORCH_CHECK_GE(axis_, 0);
|
||||
TORCH_CHECK_LT(axis_, 4);
|
||||
} else {
|
||||
axis_ = this->template GetSingleArgument<int>("axis", 0);
|
||||
}
|
||||
|
|
@ -39,20 +39,20 @@ class Int8ConcatOp final : public Operator<CPUContext> {
|
|||
Y->zero_point = X0.zero_point;
|
||||
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_offset, X0.zero_point);
|
||||
CHECK_EQ(Y_scale, X0.scale);
|
||||
CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
TORCH_CHECK_EQ(Y_offset, X0.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X0.scale);
|
||||
TORCH_CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
TORCH_CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
auto Y_dims = X0.t.sizes().vec();
|
||||
if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
|
||||
CHECK_EQ(Y_dims.size(), 4);
|
||||
TORCH_CHECK_EQ(Y_dims.size(), 4);
|
||||
}
|
||||
for (const auto i : c10::irange(1, InputSize())) {
|
||||
const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
|
||||
CHECK_EQ(Xi.t.dim(), Y_dims.size());
|
||||
TORCH_CHECK_EQ(Xi.t.dim(), Y_dims.size());
|
||||
for (const auto j : c10::irange(Y_dims.size())) {
|
||||
if (j != axis_) {
|
||||
CHECK_EQ(Xi.t.size(j), Y_dims[j]);
|
||||
TORCH_CHECK_EQ(Xi.t.size(j), Y_dims[j]);
|
||||
}
|
||||
}
|
||||
Y_dims[axis_] += Xi.t.size(axis_);
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class Int8ConvOp final : public ConvPoolOpBase<CPUContext> {
|
|||
const bool isDepthwise = this->group_ > 1 && this->group_ == M &&
|
||||
this->group_ == C && KC == 1 && KH * KW == 9 && dilation_w() == 1;
|
||||
|
||||
CHECK_EQ(Y->t.dim32(3), M);
|
||||
TORCH_CHECK_EQ(Y->t.dim32(3), M);
|
||||
runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
|
||||
initQNNPACK();
|
||||
|
||||
|
|
|
|||
|
|
@ -47,14 +47,14 @@ class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
|
|||
|
||||
const auto IC = X.t.size(3);
|
||||
|
||||
CHECK_EQ(IC, W.t.size(0));
|
||||
TORCH_CHECK_EQ(IC, W.t.size(0));
|
||||
const auto KH = W.t.size(1);
|
||||
const auto KW = W.t.size(2);
|
||||
const auto OC = W.t.size(3);
|
||||
|
||||
auto sizes = ConvTransposeUnpoolBase<CPUContext>::GetOutputSize(X.t, OC);
|
||||
ReinitializeTensor(&(Y->t), sizes, at::dtype<uint8_t>().device(CPU));
|
||||
CHECK_EQ(OC, Y->t.size(3));
|
||||
TORCH_CHECK_EQ(OC, Y->t.size(3));
|
||||
|
||||
runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
|
||||
initQNNPACK();
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ class Int8FCOp final : public Operator<CPUContext> {
|
|||
// (NxHxW)xC == MxK x (NxK) -> MxN
|
||||
const auto K = X.t.size_from_dim(1);
|
||||
const auto N = W.t.size(0);
|
||||
CHECK_EQ(K, W.t.size(1));
|
||||
CHECK_EQ(N, B.t.numel());
|
||||
TORCH_CHECK_EQ(K, W.t.size(1));
|
||||
TORCH_CHECK_EQ(N, B.t.numel());
|
||||
const auto M = X.t.numel() / K;
|
||||
ReinitializeTensor(&Y->t, {M, N}, at::dtype<uint8_t>().device(CPU));
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ class Int8FlattenOp : public Operator<CPUContext> {
|
|||
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
|
||||
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
Y->scale = Y_scale;
|
||||
Y->zero_point = Y_offset;
|
||||
CAFFE_ENFORCE_GE(
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class Int8GivenTensorFillOp final : public Operator<CPUContext> {
|
|||
}
|
||||
|
||||
bool Fill(Int8TensorCPU* output) {
|
||||
DCHECK_EQ(output->t.numel(), values_.numel())
|
||||
TORCH_DCHECK_EQ(output->t.numel(), values_.numel())
|
||||
<< "output size: " << output->t.numel()
|
||||
<< " given size: " << values_.numel();
|
||||
auto* data = output->t.template mutable_data<uint8_t>();
|
||||
|
|
@ -98,7 +98,7 @@ class Int8GivenIntTensorFillOp final : public Operator<CPUContext> {
|
|||
}
|
||||
|
||||
bool Fill(Int8TensorCPU* output) {
|
||||
DCHECK_EQ(output->t.numel(), values_.numel())
|
||||
TORCH_DCHECK_EQ(output->t.numel(), values_.numel())
|
||||
<< "output size: " << output->t.numel()
|
||||
<< " given size: " << values_.numel();
|
||||
auto* data = output->t.template mutable_data<int32_t>();
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ class Int8LeakyReluOp final : public Operator<CPUContext> {
|
|||
const int32_t Y_zero_point =
|
||||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_GE(Y_zero_point, std::numeric_limits<uint8_t>::min());
|
||||
CHECK_LE(Y_zero_point, std::numeric_limits<uint8_t>::max());
|
||||
TORCH_CHECK_GE(Y_zero_point, std::numeric_limits<uint8_t>::min());
|
||||
TORCH_CHECK_LE(Y_zero_point, std::numeric_limits<uint8_t>::max());
|
||||
|
||||
/*
|
||||
* Record quantization parameters for the input, because if the op is
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ class Int8MaxPoolOp final : public ConvPoolOpBase<CPUContext> {
|
|||
const int32_t Y_zero_point =
|
||||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_zero_point, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_zero_point, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
|
||||
CHECK_EQ(X.t.dim(), 4);
|
||||
TORCH_CHECK_EQ(X.t.dim(), 4);
|
||||
const int channels = X.t.dim32(3);
|
||||
ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
|
||||
|
||||
|
|
|
|||
|
|
@ -34,14 +34,14 @@ class Int8ReluOp final : public Operator<CPUContext> {
|
|||
Y->t.ResizeLike(X.t);
|
||||
Y->scale = X.scale;
|
||||
Y->zero_point = X.zero_point;
|
||||
CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
TORCH_CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
|
||||
TORCH_CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
|
||||
const int32_t Y_offset =
|
||||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale =
|
||||
this->template GetSingleArgument<float>("Y_scale", 1.0f);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
|
||||
initQNNPACK();
|
||||
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ class Int8ReshapeOp final : public ReshapeOp<uint8_t, CPUContext> {
|
|||
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
|
||||
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
Y->scale = Y_scale;
|
||||
Y->zero_point = Y_offset;
|
||||
DoRunWithTypeImpl<T>(X.t, &Y->t);
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ class Int8ResizeNearestOp final : public Operator<CPUContext> {
|
|||
|
||||
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
|
||||
const uint8_t* Xdata = X.t.data<uint8_t>();
|
||||
uint8_t* Ydata = Y->t.mutable_data<uint8_t>();
|
||||
|
|
|
|||
|
|
@ -281,10 +281,10 @@ class Int8RoIAlignOp final : public Operator<CPUContext> {
|
|||
sampling_ratio_(
|
||||
this->template GetSingleArgument<int>("sampling_ratio", -1)),
|
||||
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
|
||||
DCHECK_GT(spatial_scale_, 0);
|
||||
DCHECK_GT(pooled_height_, 0);
|
||||
DCHECK_GT(pooled_width_, 0);
|
||||
DCHECK_GE(sampling_ratio_, 0);
|
||||
TORCH_DCHECK_GT(spatial_scale_, 0);
|
||||
TORCH_DCHECK_GT(pooled_height_, 0);
|
||||
TORCH_DCHECK_GT(pooled_width_, 0);
|
||||
TORCH_DCHECK_GE(sampling_ratio_, 0);
|
||||
// only supports NHWC
|
||||
CAFFE_ENFORCE(order_ == StorageOrder::NHWC);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ class Int8SigmoidOp final : public Operator<CPUContext> {
|
|||
const int32_t Y_zero_point =
|
||||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_zero_point, 0);
|
||||
CHECK_EQ(Y_scale, 1.0f / 256.0f);
|
||||
TORCH_CHECK_EQ(Y_zero_point, 0);
|
||||
TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f);
|
||||
|
||||
/*
|
||||
* Record quantization parameters for the input, because if the op is
|
||||
|
|
|
|||
|
|
@ -76,8 +76,8 @@ class Int8SliceOp final : public SliceOp<CPUContext> {
|
|||
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
|
||||
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_offset, X.zero_point);
|
||||
CHECK_EQ(Y_scale, X.scale);
|
||||
TORCH_CHECK_EQ(Y_offset, X.zero_point);
|
||||
TORCH_CHECK_EQ(Y_scale, X.scale);
|
||||
Y->scale = Y_scale;
|
||||
Y->zero_point = Y_offset;
|
||||
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ class Int8SoftmaxOp final : public Operator<CPUContext> {
|
|||
const int32_t Y_zero_point =
|
||||
this->template GetSingleArgument<int>("Y_zero_point", 0);
|
||||
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
|
||||
CHECK_EQ(Y_zero_point, 0);
|
||||
CHECK_EQ(Y_scale, 1.0f / 256.0f);
|
||||
TORCH_CHECK_EQ(Y_zero_point, 0);
|
||||
TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f);
|
||||
|
||||
/*
|
||||
* Record quantization parameters for the input, because if the op is
|
||||
|
|
|
|||
|
|
@ -341,7 +341,7 @@ TEST(Int8, SumRelu) {
|
|||
}
|
||||
|
||||
void setq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
|
||||
CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
|
||||
TORCH_CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
|
||||
for (auto i = 0U; i < vs.size(); ++i) {
|
||||
uint8_t vq = std::max(
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
|
|
@ -354,7 +354,7 @@ void setq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
|
|||
}
|
||||
|
||||
void biassetq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
|
||||
CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
|
||||
TORCH_CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
|
||||
for (auto i = 0U; i < vs.size(); ++i) {
|
||||
int32_t vq = std::max(
|
||||
std::numeric_limits<int32_t>::min(),
|
||||
|
|
|
|||
|
|
@ -91,8 +91,8 @@ inline void QuantizeMultiplierSmallerThanOne(
|
|||
q_fixed /= 2;
|
||||
--*right_shift;
|
||||
}
|
||||
CHECK_GE(*right_shift, 0);
|
||||
CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
|
||||
TORCH_CHECK_GE(*right_shift, 0);
|
||||
TORCH_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
|
||||
*quantized_multiplier = static_cast<int32_t>(q_fixed);
|
||||
}
|
||||
|
||||
|
|
@ -108,8 +108,8 @@ inline void QuantizeMultiplierGreaterThanOne(
|
|||
q_fixed /= 2;
|
||||
++*left_shift;
|
||||
}
|
||||
CHECK_GE(*left_shift, 0);
|
||||
CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
|
||||
TORCH_CHECK_GE(*left_shift, 0);
|
||||
TORCH_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
|
||||
*quantized_multiplier = static_cast<int32_t>(q_fixed);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ class BaseReducer {
|
|||
}
|
||||
|
||||
void observeInput(int input, const Tensor& value, int skip_dims) {
|
||||
DCHECK_EQ(0, input);
|
||||
TORCH_DCHECK_EQ(0, input);
|
||||
auto dims = value.sizes();
|
||||
computeMeta(dims, skip_dims);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ bool SumElementsGradientOp<T, Context>::RunOnDevice()
|
|||
Tensor sum_grad(Input(1), CPU);
|
||||
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<T>());
|
||||
DCHECK_EQ(sum_grad.numel(), 1);
|
||||
TORCH_DCHECK_EQ(sum_grad.numel(), 1);
|
||||
math::Set<T, Context>(
|
||||
dX->numel(),
|
||||
static_cast<T>(
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ template <>
|
|||
bool SumElementsGradientOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
DCHECK_EQ(dY.numel(), 1);
|
||||
TORCH_DCHECK_EQ(dY.numel(), 1);
|
||||
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
SumElementsGradientKernel<float>
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ class RoIAlignGradientOp final : public Operator<Context> {
|
|||
sampling_ratio_(
|
||||
this->template GetSingleArgument<int>("sampling_ratio", -1)),
|
||||
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
|
||||
DCHECK_GT(spatial_scale_, 0);
|
||||
DCHECK_GT(pooled_height_, 0);
|
||||
DCHECK_GT(pooled_width_, 0);
|
||||
DCHECK_GE(sampling_ratio_, 0);
|
||||
TORCH_DCHECK_GT(spatial_scale_, 0);
|
||||
TORCH_DCHECK_GT(pooled_height_, 0);
|
||||
TORCH_DCHECK_GT(pooled_width_, 0);
|
||||
TORCH_DCHECK_GE(sampling_ratio_, 0);
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ class RoIAlignOp final : public Operator<Context> {
|
|||
OP_SINGLE_ARG(int, "pooled_w", pooled_w_, 1),
|
||||
OP_SINGLE_ARG(int, "sampling_ratio", sampling_ratio_, -1),
|
||||
OP_SINGLE_ARG(bool, "aligned", aligned_, false) {
|
||||
DCHECK_GT(spatial_scale_, 0.0f);
|
||||
DCHECK_GT(pooled_h_, 0);
|
||||
DCHECK_GT(pooled_w_, 0);
|
||||
TORCH_DCHECK_GT(spatial_scale_, 0.0f);
|
||||
TORCH_DCHECK_GT(pooled_h_, 0);
|
||||
TORCH_DCHECK_GT(pooled_w_, 0);
|
||||
DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ class RoIAlignRotatedGradientOp final : public Operator<Context> {
|
|||
sampling_ratio_(
|
||||
this->template GetSingleArgument<int>("sampling_ratio", -1)),
|
||||
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
|
||||
DCHECK_GT(spatial_scale_, 0);
|
||||
DCHECK_GT(pooled_height_, 0);
|
||||
DCHECK_GT(pooled_width_, 0);
|
||||
DCHECK_GE(sampling_ratio_, 0);
|
||||
TORCH_DCHECK_GT(spatial_scale_, 0);
|
||||
TORCH_DCHECK_GT(pooled_height_, 0);
|
||||
TORCH_DCHECK_GT(pooled_width_, 0);
|
||||
TORCH_DCHECK_GE(sampling_ratio_, 0);
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
|
|
|
|||
|
|
@ -27,10 +27,10 @@ class RoIAlignRotatedOp final : public Operator<Context> {
|
|||
sampling_ratio_(
|
||||
this->template GetSingleArgument<int>("sampling_ratio", -1)),
|
||||
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
|
||||
DCHECK_GT(spatial_scale_, 0);
|
||||
DCHECK_GT(pooled_height_, 0);
|
||||
DCHECK_GT(pooled_width_, 0);
|
||||
DCHECK_GE(sampling_ratio_, 0);
|
||||
TORCH_DCHECK_GT(spatial_scale_, 0);
|
||||
TORCH_DCHECK_GT(pooled_height_, 0);
|
||||
TORCH_DCHECK_GT(pooled_width_, 0);
|
||||
TORCH_DCHECK_GE(sampling_ratio_, 0);
|
||||
DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
|
||||
}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
|
|
|||
|
|
@ -137,10 +137,10 @@ bool SliceImpl(
|
|||
src_offset_bytes + i * src_block_size_bytes;
|
||||
char* local_dst_offset_bytes =
|
||||
dst_offset_bytes + i * dst_block_size_bytes;
|
||||
DCHECK_LE(
|
||||
TORCH_DCHECK_LE(
|
||||
static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
|
||||
static_cast<void*>(src_bytes + src_nbytes));
|
||||
DCHECK_LE(
|
||||
TORCH_DCHECK_LE(
|
||||
static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
|
||||
static_cast<void*>(dst_bytes + dst_nbytes));
|
||||
context->CopyItemsSameDevice(
|
||||
|
|
@ -183,10 +183,10 @@ bool SliceImpl(
|
|||
src_offset_bytes + i * src_block_size_bytes;
|
||||
char* local_dst_offset_bytes =
|
||||
dst_offset_bytes + i * dst_block_size_bytes;
|
||||
DCHECK_LE(
|
||||
TORCH_DCHECK_LE(
|
||||
local_src_offset_bytes + src_block_size_bytes,
|
||||
src_bytes + src_nbytes);
|
||||
DCHECK_LE(
|
||||
TORCH_DCHECK_LE(
|
||||
local_dst_offset_bytes + src_block_size_bytes,
|
||||
dst_bytes + dst_nbytes);
|
||||
context->CopyItemsSameDevice(
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
|
|||
const int N = Y.size_to_dim(canonical_axis);
|
||||
const int D = Y.size_from_dim(canonical_axis);
|
||||
|
||||
CHECK_EQ(Y.sizes(), dY.sizes());
|
||||
TORCH_CHECK_EQ(Y.sizes(), dY.sizes());
|
||||
auto* dX = Output(0, Y.sizes(), at::dtype<T>());
|
||||
auto* dX_data = dX->template mutable_data<T>();
|
||||
if (N == 0 || D == 0) {
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ bool SoftplusGradientOp<float, CPUContext>::RunOnDevice() {
|
|||
auto& Y = Input(0);
|
||||
auto& dY = Input(1);
|
||||
|
||||
DCHECK_EQ(dY.numel(), Y.numel());
|
||||
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
|
||||
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
|
||||
|
||||
const float* Ydata = Y.data<float>();
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ template <>
|
|||
bool SoftplusOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
|
||||
DCHECK_GT(X.numel(), 0);
|
||||
TORCH_DCHECK_GT(X.numel(), 0);
|
||||
auto* Y = Output(0, X.sizes(), at::dtype<float>());
|
||||
SoftplusKernel<float>
|
||||
<<<CAFFE_GET_BLOCKS(X.numel()),
|
||||
|
|
@ -42,8 +42,8 @@ bool SoftplusGradientOp<float, CUDAContext>::RunOnDevice() {
|
|||
auto& Y = Input(0);
|
||||
auto& dY = Input(1);
|
||||
|
||||
DCHECK_GT(Y.numel(), 0);
|
||||
DCHECK_EQ(dY.numel(), Y.numel());
|
||||
TORCH_DCHECK_GT(Y.numel(), 0);
|
||||
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
|
||||
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
|
||||
SoftplusGradientKernel<float>
|
||||
<<<CAFFE_GET_BLOCKS(Y.numel()),
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ template<>
|
|||
bool SummarizeOp<float, CUDAContext>::RunOnDevice() {
|
||||
auto& X = Input(0);
|
||||
const int N = X.numel();
|
||||
DCHECK_GT(N, 0);
|
||||
TORCH_DCHECK_GT(N, 0);
|
||||
|
||||
// TODO(Yangqing): Any better way to avoid having to const cast?
|
||||
thrust::device_ptr<float> Xdata(const_cast<float*>(X.data<float>()));
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user