Replace all CHECK_ and DCHECK_ with TORCH_* macros (#82032)

Avoid exposing defines that conflict with Google logging (glog), since this blocks external usage of libtorch in certain cases.

All the 'interesting' changes should be in these two files; the rest should just be mechanical changes via sed:
c10/util/logging_is_not_google_glog.h
c10/util/logging_is_google_glog.h

Fixes https://github.com/pytorch/pytorch/issues/81415
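
For context, here is a minimal standalone sketch (hypothetical code, not the real c10 headers) of the collision this avoids: before this change, c10/util/logging_is_not_google_glog.h defined CHECK_EQ and friends itself, so an application that also pulled in glog hit macro redefinitions. Prefixed names keep the two sets of macros disjoint:

#include <cstdio>
#include <cstdlib>

// The embedding application's logging library (think glog) defines this first:
#define CHECK_EQ(a, b)                                  \
  do {                                                  \
    if ((a) != (b)) {                                   \
      std::fprintf(stderr, "app: CHECK_EQ failed\n");   \
      std::abort();                                     \
    }                                                   \
  } while (0)

// A prefixed macro (illustration only, not the real definition) cannot
// collide with the application's CHECK_EQ above:
#define TORCH_CHECK_EQ(a, b)                            \
  do {                                                  \
    if ((a) != (b)) {                                   \
      std::fprintf(stderr, "torch: CHECK_EQ failed\n"); \
      std::abort();                                     \
    }                                                   \
  } while (0)

int main() {
  CHECK_EQ(1, 1);        // application macro, untouched
  TORCH_CHECK_EQ(2, 2);  // libtorch-style macro, no redefinition
  return 0;
}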

cc @miladm @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82032
Approved by: https://github.com/soumith, https://github.com/miladm
Will Constable 2022-07-26 01:20:44 +00:00 committed by PyTorch MergeBot
parent cab819222a
commit 4f34cd6d1e
150 changed files with 680 additions and 612 deletions

View File

@ -35,7 +35,7 @@ void LayerNormKernelImplInternal(
Tensor* rstd) {
using T_ACC = vec::vec_scalar_t<T>;
using Vec = vec::Vectorized<T_ACC>;
DCHECK_EQ(X.numel(), M * N);
TORCH_DCHECK_EQ(X.numel(), M * N);
DCHECK(!gamma.defined() || gamma.numel() == N);
DCHECK(!beta.defined() || beta.numel() == N);
const T* X_data = X.data_ptr<T>();
@ -117,10 +117,10 @@ void LayerNormBackwardKernelImplInternal(
Tensor* dbeta) {
using T_ACC = vec::vec_scalar_t<T>;
using Vec = vec::Vectorized<T_ACC>;
DCHECK_EQ(dY.numel(), M * N);
DCHECK_EQ(X.numel(), M * N);
DCHECK_EQ(mean.numel(), M);
DCHECK_EQ(rstd.numel(), M);
TORCH_DCHECK_EQ(dY.numel(), M * N);
TORCH_DCHECK_EQ(X.numel(), M * N);
TORCH_DCHECK_EQ(mean.numel(), M);
TORCH_DCHECK_EQ(rstd.numel(), M);
DCHECK(!gamma.defined() || gamma.numel() == N);
const T* dY_data = dY.template data_ptr<T>();
const T* X_data = X.template data_ptr<T>();

View File

@ -94,7 +94,7 @@ Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) {
TORCH_CHECK(
status == pytorch_qnnp_status_success,
"failed to create QNNPACK Softmax operator");
CHECK_NOTNULL(softargmax);
TORCH_CHECK_NOTNULL(softargmax);
status = pytorch_qnnp_setup_softargmax_nc_q8(
softargmax, batch_size, input, input_stride, output, output_stride);

View File

@ -47,7 +47,7 @@ static void BM_deep_wide_jit_graph_executor(benchmark::State& state) {
std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);
TORCH_CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);
mod.forward(inputs);
for (auto _ : state) {
@ -65,7 +65,7 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);
TORCH_CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);
mod.forward(inputs);
for (auto _ : state) {

View File

@ -173,7 +173,7 @@ int loadInput(
LOG(INFO) << "Running on GPU.";
#ifdef __CUDA_ARCH__
caffe2::TensorCUDA* tensor = blob->GetMutable<caffe2::TensorCUDA>();
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
tensor->mutable_data<uint8_t>();
@ -189,17 +189,17 @@ int loadInput(
if (input_type_list[i] == "uint8_t") {
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else if (input_type_list[i] == "int") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<int>();
} else {
@ -495,7 +495,7 @@ int benchmark(
net_def.set_name("benchmark");
}
caffe2::NetBase* net = workspace->CreateNet(net_def);
CHECK_NOTNULL(net);
TORCH_CHECK_NOTNULL(net);
runNetwork(
workspace,
net,

View File

@ -591,7 +591,7 @@ void runNetwork(
}
caffe2::NetBase* net = workspace->CreateNet(net_def);
CHECK_NOTNULL(net);
TORCH_CHECK_NOTNULL(net);
LOG(INFO) << "Starting benchmark.";
caffe2::ObserverConfig::initSampleRate(1, 1, 1, run_individual, warmup);

View File

@ -251,7 +251,7 @@ void ConvertImageDataset(
// Synthesize key for this entry
auto key_len = snprintf(
key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str());
DCHECK_LE(key_len, sizeof(key_cstr));
TORCH_DCHECK_LE(key_len, sizeof(key_cstr));
// Put in db
transaction->Put(string(key_cstr), std::move(value));

View File

@ -136,12 +136,12 @@ int main(int argc, char** argv) {
if (input_type_list[i] == "uint8_t") {
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
TORCH_CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else {
@ -184,7 +184,7 @@ int main(int argc, char** argv) {
}
caffe2::NetBase* net = workspace->CreateNet(net_def);
CHECK_NOTNULL(net);
TORCH_CHECK_NOTNULL(net);
CAFFE_ENFORCE(net->Run());
net->TEST_Benchmark(FLAGS_warmup, FLAGS_iter, FLAGS_run_individual);

View File

@ -141,7 +141,7 @@ TEST(LoggingTest, Join) {
TEST(LoggingTest, TestDanglingElse) {
if (true)
DCHECK_EQ(1, 1);
TORCH_DCHECK_EQ(1, 1);
else
GTEST_FAIL();
}

View File

@ -180,7 +180,7 @@ using EnforceNotMet = ::c10::Error;
* With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))`
*
* Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided
* too. Please use them instead of CHECK_EQ and friends for failures in
* too. Please use them instead of TORCH_CHECK_EQ and friends for failures in
* user-provided input.
*/

View File

@ -64,10 +64,10 @@ class Registry {
const RegistryPriority priority = REGISTRY_DEFAULT) {
std::lock_guard<std::mutex> lock(register_mutex_);
// The if statement below is essentially the same as the following line:
// CHECK_EQ(registry_.count(key), 0) << "Key " << key
// TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
// << " registered twice.";
// However, CHECK_EQ depends on google logging, and since registration is
// carried out at static initialization time, we do not want to have an
// However, TORCH_CHECK_EQ depends on google logging, and since registration
// is carried out at static initialization time, we do not want to have an
// explicit dependency on glog's initialization function.
if (registry_.count(key) != 0) {
auto cur_priority = priority_[key];

View File

@ -50,6 +50,71 @@ INSTANTIATE_FOR_CONTAINER(set)
#include <glog/logging.h>
// Additional macros on top of glog
#ifndef NDEBUG
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_CHECK_EQ(val1, val2) \
while (false) \
CHECK_EQ(val1, val2)
#define TORCH_CHECK_NE(val1, val2) \
while (false) \
CHECK_NE(val1, val2)
#define TORCH_CHECK_LE(val1, val2) \
while (false) \
CHECK_LE(val1, val2)
#define TORCH_CHECK_LT(val1, val2) \
while (false) \
CHECK_LT(val1, val2)
#define TORCH_CHECK_GE(val1, val2) \
while (false) \
CHECK_GE(val1, val2)
#define TORCH_CHECK_GT(val1, val2) \
while (false) \
CHECK_GT(val1, val2)
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
DCHECK_GT(val1, val2)
#endif // NDEBUG
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
DCHECK_NOTNULL(val)
#endif // NDEBUG
// Log with source location information override (to be used in generic
// warning/error handlers implemented as functions, not macros)
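
Aside: the while (false) prefix above is what lets the release-mode versions compile away while still parsing as a single statement, including any streamed message, which is exactly what the LoggingTest.TestDanglingElse hunk earlier exercises. A self-contained sketch of the pattern, using a made-up macro name rather than the real ones:

#include <iostream>

// Hypothetical stand-in for the real no-op DCHECK variants; illustration only.
#define MY_DCHECK_EQ(val1, val2) \
  while (false)                  \
  std::cerr << "Check failed: " << (val1) << " vs. " << (val2)

int main() {
  if (true)
    MY_DCHECK_EQ(1, 2) << " (never evaluated)";  // dead code, still one statement
  else
    std::cerr << "the else still binds to the if, not to the macro\n";
  return 0;
}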

View File

@ -61,8 +61,8 @@ void LogMessageFatal(const char* file, int line, const T& message) {
MessageLogger(file, line, GLOG_FATAL).stream() << message;
}
// Helpers for CHECK_NOTNULL(). Two are necessary to support both raw pointers
// and smart pointers.
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
// pointers and smart pointers.
template <typename T>
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
if (t == nullptr) {
@ -136,63 +136,63 @@ static_assert(
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#endif // NDEBUG
#define CHECK_OP(val1, val2, op) \
#define TORCH_CHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << ") "
// Check_op macro definitions
#define CHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==)
#define CHECK_NE(val1, val2) CHECK_OP(val1, val2, !=)
#define CHECK_LE(val1, val2) CHECK_OP(val1, val2, <=)
#define CHECK_LT(val1, val2) CHECK_OP(val1, val2, <)
#define CHECK_GE(val1, val2) CHECK_OP(val1, val2, >=)
#define CHECK_GT(val1, val2) CHECK_OP(val1, val2, >)
// TORCH_CHECK_OP macro definitions
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#ifndef NDEBUG
// Debug only versions of CHECK_OP macros.
#define DCHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==)
#define DCHECK_NE(val1, val2) CHECK_OP(val1, val2, !=)
#define DCHECK_LE(val1, val2) CHECK_OP(val1, val2, <=)
#define DCHECK_LT(val1, val2) CHECK_OP(val1, val2, <)
#define DCHECK_GE(val1, val2) CHECK_OP(val1, val2, >=)
#define DCHECK_GT(val1, val2) CHECK_OP(val1, val2, >)
// Debug only versions of TORCH_CHECK_OP macros.
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define DCHECK_EQ(val1, val2) \
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
CHECK_OP(val1, val2, ==)
#define DCHECK_NE(val1, val2) \
TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
CHECK_OP(val1, val2, !=)
#define DCHECK_LE(val1, val2) \
TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
CHECK_OP(val1, val2, <=)
#define DCHECK_LT(val1, val2) \
TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
CHECK_OP(val1, val2, <)
#define DCHECK_GE(val1, val2) \
TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
CHECK_OP(val1, val2, >=)
#define DCHECK_GT(val1, val2) \
TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
CHECK_OP(val1, val2, >)
TORCH_CHECK_OP(val1, val2, >)
#endif // NDEBUG
// Check that a pointer is not null.
#define CHECK_NOTNULL(val) \
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#ifndef NDEBUG
// Debug only version of CHECK_NOTNULL
#define DCHECK_NOTNULL(val) \
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#else // !NDEBUG
// Optimized version - generates no code.
#define DCHECK_NOTNULL(val) \
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
CHECK_NOTNULL(val)
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ---------------------- Support for std objects --------------------------
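
Aside: TORCH_CHECK_NOTNULL expands to a call to ::c10::CheckNotNull, which returns its argument, so the macro can be used inside expressions (for example, return TORCH_CHECK_NOTNULL(contexts[key].get()); in the NCCL hunk later in this diff). A standalone sketch of that pass-through shape, with hypothetical names:

#include <cstdio>
#include <cstdlib>

// Single-overload sketch; the real helper has two overloads so that both raw
// pointers and smart pointers work (see the comment in the hunk above).
template <typename T>
T& CheckNotNull(const char* file, int line, const char* msg, T& t) {
  if (t == nullptr) {
    std::fprintf(stderr, "%s:%d %s\n", file, line, msg);
    std::abort();
  }
  return t;  // pass the checked value through
}

#define MY_CHECK_NOTNULL(val) \
  CheckNotNull(__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))

int main() {
  int x = 42;
  int* p = &x;
  int* q = MY_CHECK_NOTNULL(p);  // the macro yields the checked pointer itself
  std::printf("%d\n", *q);
  return 0;
}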

View File

@ -78,7 +78,7 @@ class Fp16FCAccOp final : public Operator<Context> {
Y_shape_cache_ = X.sizes().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
Y_shape_cache_.resize(canonical_axis + 1);
Y_shape_cache_[canonical_axis] = N;
Y->Resize(Y_shape_cache_);

View File

@ -91,7 +91,7 @@ NCCLContext* getNCCLContext(const NCCLExecution& ex) {
LOG(INFO) << "Creating NCCLContext for key: " << key;
contexts[key].reset(new NCCLContext(ex));
}
return CHECK_NOTNULL(contexts[key].get());
return TORCH_CHECK_NOTNULL(contexts[key].get());
}
template <typename T>
@ -153,7 +153,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
auto& comm = comms[i];
auto& stream = streams[i];
DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
TORCH_DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
CUDA_ENFORCE(cudaStreamWaitEvent(stream, context->master_event_, 0));
f(ctx, comm, stream);
}

View File

@ -36,7 +36,7 @@ class OpenCLContext final {
public:
explicit OpenCLContext();
explicit OpenCLContext(const DeviceOption& option) {
DCHECK_EQ(option.device_type(), PROTO_OPENCL);
TORCH_DCHECK_EQ(option.device_type(), PROTO_OPENCL);
OpenCLContext();
}
~OpenCLContext() {}

View File

@ -9,7 +9,7 @@ CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() {
// New it (never delete) to avoid calling the destructors on process
// exit and racing against the CUDA shutdown sequence.
static auto* p = new CuDNNWrapper::PerGPUCuDNNStates();
CHECK_NOTNULL(p);
TORCH_CHECK_NOTNULL(p);
return *p;
}

View File

@ -437,7 +437,7 @@ CUDAContext::CUDAContext(const DeviceOption& option)
option.has_random_seed() ? option.random_seed()
: RandomNumberSeed()) {
static Caffe2CudaInitializerHelper g_cuda_initializer_;
DCHECK_EQ(option.device_type(), PROTO_CUDA);
TORCH_DCHECK_EQ(option.device_type(), PROTO_CUDA);
}
CUDAContext::~CUDAContext() {

View File

@ -230,7 +230,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_ENFORCE(
curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_));
CHECK_NOTNULL(curand_generator_);
TORCH_CHECK_NOTNULL(curand_generator_);
}
CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream()));
return curand_generator_;

View File

@ -65,7 +65,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) {
cudaStream_t getStreamForHandle(cublasHandle_t handle) {
cudaStream_t stream = nullptr;
CUBLAS_ENFORCE(cublasGetStream(handle, &stream));
CHECK_NOTNULL(stream);
TORCH_CHECK_NOTNULL(stream);
return stream;
}

View File

@ -172,7 +172,7 @@ class CuDNNWrapper {
if (!sync_state.state.get()) {
sync_state.state.reset(new CuDNNState(context_->device_id()));
}
CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
}
protected:

View File

@ -25,7 +25,7 @@ MIOPENWrapper::PerGPUMIOPENStates& MIOPENWrapper::miopen_states()
// New it (never delete) to avoid calling the destructors on process
// exit and racing against the CUDA shutdown sequence.
static auto* p = new MIOPENWrapper::PerGPUMIOPENStates();
CHECK_NOTNULL(p);
TORCH_CHECK_NOTNULL(p);
return *p;
}

View File

@ -138,7 +138,7 @@ class MIOPENWrapper
{
sync_state.state.reset(new MIOPENState(context_->device_id()));
}
CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f);
TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f);
}
protected:

View File

@ -79,7 +79,7 @@ void checkChainingAndRun(
net_def.set_num_workers(4);
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
CHECK_NOTNULL(dag);
TORCH_CHECK_NOTNULL(dag);
const auto& chains = dag->TEST_execution_chains();
EXPECT_EQ(chains, expected);
testExecution(net, net_def.op().size());

View File

@ -152,7 +152,7 @@ void checkChainingAndRun(
net_def.set_num_workers(4);
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
CHECK_NOTNULL(dag);
TORCH_CHECK_NOTNULL(dag);
const auto& chains = dag->TEST_execution_chains();
EXPECT_TRUE(chains == expected);
testExecution(net, net_def.op().size());
@ -175,7 +175,7 @@ void checkNumChainsAndRun(const char* spec, const int expected_num_chains) {
{
std::unique_ptr<NetBase> net(CreateNet(net_def, &ws));
auto* dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
CHECK_NOTNULL(dag);
TORCH_CHECK_NOTNULL(dag);
const auto& chains = dag->TEST_execution_chains();
EXPECT_EQ(expected_num_chains, chains.size());
testExecution(net, net_def.op().size());
@ -1108,7 +1108,7 @@ void testProfDAGNetErrorCase(bool test_error) {
// with failing op - prof_dag handles invalid runs and returns empty stats,
// without - returns stats for each op
auto* prof_dag = dynamic_cast_if_rtti<AsyncNetBase*>(net.get());
CHECK_NOTNULL(prof_dag);
TORCH_CHECK_NOTNULL(prof_dag);
auto stats_proto = prof_dag->GetPerOperatorCost();
ASSERT_EQ(
stats_proto.stats_size(), test_error ? 0 : net->GetOperators().size());

View File

@ -82,7 +82,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
outputs_.reserve(operator_def.output_size());
for (const string& output_str : operator_def.output()) {
outputs_.push_back(CHECK_NOTNULL(ws->CreateBlob(output_str)));
outputs_.push_back(TORCH_CHECK_NOTNULL(ws->CreateBlob(output_str)));
}
type_ = operator_def.type();

View File

@ -157,7 +157,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
!std::is_same<T, Tensor>::value,
"You should use Input<Tensor>(int, DeviceType) for "
"Tensor.");
DCHECK_LT((size_t)idx, inputs_.size());
TORCH_DCHECK_LT((size_t)idx, inputs_.size());
try {
return inputs_.at(idx)->template Get<T>();
} catch (::caffe2::EnforceNotMet& enf) {
@ -178,7 +178,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
static_assert(
std::is_same<T, Tensor>::value,
"Input(int, DeviceType) is only available for Tensor");
DCHECK_LT((size_t)idx, inputs_.size());
TORCH_DCHECK_LT((size_t)idx, inputs_.size());
try {
// TODO(jerryzh): We'll need to check device type in Get<T>() later
// Get<T>() -> Get<T>(type)
@ -193,7 +193,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
}
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
DCHECK_LT(0U, newstyle_inputs_.size());
TORCH_DCHECK_LT(0U, newstyle_inputs_.size());
IValue ival;
if (newstyle_inputs_[0].isTensorList()) {
// if the first input is a tensor list, we get input tensors by indexing
@ -201,12 +201,12 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
// are accessible as inputs. any hypothetical input tensors that come
// after the list are not accessible.
auto tensorList = newstyle_inputs_[0].toTensorVector();
DCHECK_LT((size_t)idx, tensorList.size());
TORCH_DCHECK_LT((size_t)idx, tensorList.size());
ival = tensorList[idx];
} else {
// if the first input is not a tensor list, we get input tensors by
// indexing into the inputs.
DCHECK_LT((size_t)idx, newstyle_inputs_.size());
TORCH_DCHECK_LT((size_t)idx, newstyle_inputs_.size());
ival = newstyle_inputs_[idx];
}
CAFFE_ENFORCE(

View File

@ -24,8 +24,8 @@ class SleepOp final : public Operator<CPUContext> {
SleepOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
ms_(OperatorBase::GetSingleArgument<int>("ms", 1000)) {
DCHECK_GT(ms_, 0);
DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?";
TORCH_DCHECK_GT(ms_, 0);
TORCH_DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?";
}
bool RunOnDevice() override {

View File

@ -187,8 +187,8 @@ class C10_EXPORT QTensor {
* Returns the i-th dimension of the qtensor in int.
*/
inline int dim32(const int i) const {
DCHECK_LT(i, static_cast<int>(dims_.size())) << "Exceeding ndim limit " << dims_.size();
DCHECK_GE(i, 0) << "Cannot have negative index";
TORCH_DCHECK_LT(i, static_cast<int>(dims_.size())) << "Exceeding ndim limit " << dims_.size();
TORCH_DCHECK_GE(i, 0) << "Cannot have negative index";
CAFFE_ENFORCE_LT(dims_[i], std::numeric_limits<int>::max());
return static_cast<int>(dims_[i]);
}

View File

@ -163,8 +163,8 @@ string MaxPoolRTCFunction::GetSource(
stride_w,
pad_t,
pad_l);
DCHECK_GE(nbytes, 0);
DCHECK_LT(nbytes, 65536);
TORCH_DCHECK_GE(nbytes, 0);
TORCH_DCHECK_LT(nbytes, 65536);
return string(buffer);
}
@ -202,8 +202,8 @@ string MaxPoolGradientRTCFunction::GetSource(
stride_w,
pad_t,
pad_l);
DCHECK_GE(nbytes, 0);
DCHECK_LT(nbytes, 65536);
TORCH_DCHECK_GE(nbytes, 0);
TORCH_DCHECK_LT(nbytes, 65536);
return string(buffer);
}

View File

@ -55,7 +55,7 @@ FileStoreHandler::FileStoreHandler(
auto ret = mkdir(basePath_.c_str(), 0777);
#endif // defined(_MSC_VER)
if (ret == -1) {
CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno);
TORCH_CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno);
}
}
@ -71,7 +71,7 @@ std::string FileStoreHandler::realPath(const std::string& path) {
std::array<char, PATH_MAX> buf;
auto ret = realpath(path.c_str(), buf.data());
#endif
CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno);
TORCH_CHECK_EQ(buf.data(), ret) << "realpath: " << strerror(errno);
return std::string(buf.data());
}
@ -152,7 +152,7 @@ bool FileStoreHandler::check(const std::vector<std::string>& names) {
if (fd == -1) {
// Only deal with files that don't exist.
// Anything else is a problem.
CHECK_EQ(errno, ENOENT);
TORCH_CHECK_EQ(errno, ENOENT);
// One of the paths doesn't exist; return early
return false;

View File

@ -145,10 +145,10 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
const auto& U = Input(1);
const auto& V = Input(2);
const auto& dY = Input(3);
DCHECK_GE(X.dim(), 1);
DCHECK_GE(U.dim(), 2);
DCHECK_GE(V.dim(), 2);
DCHECK_LE(dY.dim(), 2);
TORCH_DCHECK_GE(X.dim(), 1);
TORCH_DCHECK_GE(U.dim(), 2);
TORCH_DCHECK_GE(V.dim(), 2);
TORCH_DCHECK_LE(dY.dim(), 2);
// batch size
int M = X.dim() > 1 ? X.dim32(0) : 1;
// Feature dimension
@ -156,13 +156,13 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
// number of outputs.
int N = U.dim32(0);
int middle = U.dim32(1);
DCHECK_EQ(K, V.dim32(0));
TORCH_DCHECK_EQ(K, V.dim32(0));
if (dY.dim() > 1) {
DCHECK_EQ(M, dY.dim32(0));
DCHECK_EQ(N, dY.dim32(1));
TORCH_DCHECK_EQ(M, dY.dim32(0));
TORCH_DCHECK_EQ(N, dY.dim32(1));
} else {
DCHECK_EQ(X.dim(), 1);
DCHECK_EQ(N, dY.numel());
TORCH_DCHECK_EQ(X.dim(), 1);
TORCH_DCHECK_EQ(N, dY.numel());
}
auto* dU = Output(0, U.sizes(), at::dtype<T>());

View File

@ -17,6 +17,7 @@
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
#include <c10/util/Logging.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
@ -249,9 +250,9 @@ class FullyConnectedPruneGradientOp : public Operator<Context> {
auto& thres = Input(6);
// TODO(wyiming): check comp_lb is a float
auto& comp_lb = Input(7);
DCHECK_GE(X.dim(), 1);
DCHECK_GE(W.dim(), 2);
DCHECK_LE(dY.dim(), 2);
TORCH_DCHECK_GE(X.dim(), 1);
TORCH_DCHECK_GE(W.dim(), 2);
TORCH_DCHECK_LE(dY.dim(), 2);
// batch size
int M = X.dim() > 1 ? X.dim32(0) : 1;
// Feature dimension
@ -263,17 +264,17 @@ class FullyConnectedPruneGradientOp : public Operator<Context> {
// TODO(wyiming): this threshold should be
// based on distribution of the layer weight
float thr = 0.01;
DCHECK_EQ(Mask.dim32(0), W.dim32(0));
DCHECK_EQ(Mask.dim32(1), W.dim32(1));
DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
DCHECK_EQ(K, W.numel() / W.dim32(0));
TORCH_DCHECK_EQ(Mask.dim32(0), W.dim32(0));
TORCH_DCHECK_EQ(Mask.dim32(1), W.dim32(1));
TORCH_DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
TORCH_DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
TORCH_DCHECK_EQ(K, W.numel() / W.dim32(0));
if (dY.dim() > 1) {
DCHECK_EQ(M, dY.dim32(0));
DCHECK_EQ(N, dY.dim32(1));
TORCH_DCHECK_EQ(M, dY.dim32(0));
TORCH_DCHECK_EQ(N, dY.dim32(1));
} else {
DCHECK_EQ(X.dim(), 1);
DCHECK_EQ(N, dY.numel());
TORCH_DCHECK_EQ(X.dim(), 1);
TORCH_DCHECK_EQ(N, dY.numel());
}
auto* dW = Output(0, W.sizes(), at::dtype<T>());

View File

@ -15,10 +15,10 @@ class IDEEPLRNOp final : public IDEEPOperator {
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
DCHECK_GT(size_, 0);
DCHECK_EQ(size_ % 2, 1);
DCHECK_GT(alpha_, 0);
DCHECK_GT(beta_, 0);
TORCH_DCHECK_GT(size_, 0);
TORCH_DCHECK_EQ(size_ % 2, 1);
TORCH_DCHECK_GT(alpha_, 0);
TORCH_DCHECK_GT(beta_, 0);
}
~IDEEPLRNOp() override = default;
@ -52,10 +52,10 @@ class IDEEPLRNGradientOp final : public IDEEPOperator {
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
DCHECK_GT(size_, 0);
DCHECK_EQ(size_ % 2, 1);
DCHECK_GT(alpha_, 0);
DCHECK_GT(beta_, 0);
TORCH_DCHECK_GT(size_, 0);
TORCH_DCHECK_EQ(size_ % 2, 1);
TORCH_DCHECK_GT(alpha_, 0);
TORCH_DCHECK_GT(beta_, 0);
}
~IDEEPLRNGradientOp() override = default;

View File

@ -60,7 +60,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
parent_name += "_cpu_output_blob_" + base_def_.type();
}
local_output_blobs_.push_back(ws->CreateBlob(parent_name));
CHECK_NOTNULL(local_output_blobs_.back());
TORCH_CHECK_NOTNULL(local_output_blobs_.back());
forwarded_output_blobs[base_def_.output(i)] = parent_name;
output_inplace_.push_back(false);
for (const string &input_name : base_def_.input()) {
@ -74,7 +74,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
// Set up the symbols for the local workspace.
for (const string& name : base_def_.input()) {
local_input_blobs_.push_back(local_ws_->CreateBlob(name));
CHECK_NOTNULL(local_input_blobs_.back());
TORCH_CHECK_NOTNULL(local_input_blobs_.back());
}
input_share_.resize(local_input_blobs_.size(), false);
base_op_.reset(new CPUOp(base_def_, local_ws_.get()));

View File

@ -51,7 +51,7 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator {
auto data_type = zero_point_ == 0 ? idtype::u8 : idtype::s8;
output->init({shape_, data_type});
DCHECK_EQ(output->get_nelems(), values_.numel())
TORCH_DCHECK_EQ(output->get_nelems(), values_.numel())
<< "output size: " << output->get_nelems()
<< " given size: " << values_.numel();
@ -121,7 +121,7 @@ class IDEEPInt8GivenIntTensorFillOp final : public IDEEPOperator {
auto* output = Output(OUTPUT);
output->init({shape_, idtype::s32});
output->set_scale(ConvertScales(scales_));
DCHECK_EQ(output->get_nelems(), values_.numel())
TORCH_DCHECK_EQ(output->get_nelems(), values_.numel())
<< "output size: " << output->get_nelems()
<< " given size: " << values_.numel();

View File

@ -30,10 +30,10 @@ class IDEEPSpatialBNOp final : public IDEEPOperator {
const auto& bias = Input(BIAS);
auto* Y = Output(OUTPUT);
DCHECK_EQ(scale.ndims(), 1);
DCHECK_EQ(bias.ndims(), 1);
DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
TORCH_DCHECK_EQ(scale.ndims(), 1);
TORCH_DCHECK_EQ(bias.ndims(), 1);
TORCH_DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
TORCH_DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
if (is_test_) {
const auto& est_mean = Input(EST_MEAN);

View File

@ -26,7 +26,7 @@ OPERATOR_SCHEMA(ImageInput)
int batch_size = helper.GetSingleArgument<int>("batch_size", 0);
int crop = helper.GetSingleArgument<int>("crop", -1);
int color = helper.GetSingleArgument<int>("color", 1);
CHECK_GT(crop, 0);
TORCH_CHECK_GT(crop, 0);
out[0] = CreateTensorShape(
vector<int>{batch_size, crop, crop, color ? 3 : 1},
TensorProto::FLOAT);

View File

@ -530,8 +530,8 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
if (protos.protos_size() == end + 1) {
// We have bounding box information
const TensorProto& bounding_proto = protos.protos(end);
DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
DCHECK_EQ(bounding_proto.int32_data_size(), 4);
TORCH_DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
TORCH_DCHECK_EQ(bounding_proto.int32_data_size(), 4);
info.bounding_params.valid = true;
info.bounding_params.ymin = bounding_proto.int32_data(0);
info.bounding_params.xmin = bounding_proto.int32_data(1);
@ -541,7 +541,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
if (image_proto.data_type() == TensorProto::STRING) {
// encoded image string.
DCHECK_EQ(image_proto.string_data_size(), 1);
TORCH_DCHECK_EQ(image_proto.string_data_size(), 1);
const string& encoded_image_str = image_proto.string_data(0);
int encoded_size = encoded_image_str.size();
// We use a cv::Mat to wrap the encoded str so we do not need a copy.
@ -582,7 +582,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
// TODO: if image decoding was unsuccessful, set label to 0
if (label_proto.data_type() == TensorProto::FLOAT) {
if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
DCHECK_EQ(label_proto.float_data_size(), 1);
TORCH_DCHECK_EQ(label_proto.float_data_size(), 1);
prefetched_label_.mutable_data<float>()[item_id] =
label_proto.float_data(0);
} else if (label_type_ == MULTI_LABEL_SPARSE) {
@ -614,7 +614,7 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
}
} else if (label_proto.data_type() == TensorProto::INT32) {
if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
DCHECK_EQ(label_proto.int32_data_size(), 1);
TORCH_DCHECK_EQ(label_proto.int32_data_size(), 1);
prefetched_label_.mutable_data<int>()[item_id] =
label_proto.int32_data(0);
} else if (label_type_ == MULTI_LABEL_SPARSE) {

View File

@ -284,7 +284,7 @@ constexpr int computeMPSAlignOffset(int kernel, int pad) {
size_t ComputeStartIndex(
const TensorCPU& tensor,
const std::vector<int>& index) {
DCHECK_EQ(index.size(), tensor.dim());
TORCH_DCHECK_EQ(index.size(), tensor.dim());
size_t ret = 0;
for (int i = 0; i < index.size(); i++) {
@ -299,7 +299,7 @@ template <class T>
utils::ConstTensorView<T> GetSubTensorView(
const TensorCPU& tensor,
int dim0_start_index) {
DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));
TORCH_DCHECK_EQ(tensor.meta().itemsize(), sizeof(T));
if (tensor.size() == 0) {
return utils::ConstTensorView<T>(nullptr, {});
@ -1490,7 +1490,7 @@ class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
caffe2::Timer consT;
std::vector<float> refilter(kH * kW * output_channels * input_channels);
refilter.assign(kH * kW * output_channels * input_channels, 0.0f);
DCHECK_EQ(refilter.size(), filter.size());
TORCH_DCHECK_EQ(refilter.size(), filter.size());
auto* filter_ = filter.template data<float>();
// For iOS11+ Reformat weights from WT[IC][OC][kH][kW] to
// W[OC][kH][kW][IC]; For previous versions, reformat weights
@ -1512,14 +1512,14 @@ class MPSCNNConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
kw * output_channels * input_channels +
oc * input_channels + ic;
}
DCHECK_LT(inputIdx, filter.size());
DCHECK_LT(outputIdx, filter.size());
TORCH_DCHECK_LT(inputIdx, filter.size());
TORCH_DCHECK_LT(outputIdx, filter.size());
refilter[outputIdx] = filter_[inputIdx];
}
}
}
}
DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
TORCH_DCHECK_EQ(filter.size(), input_channels * output_channels * kH * kW);
// initialize data structures
if (runtimeAtLeastIOS11) {
MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor
@ -2225,7 +2225,7 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
auto keep =
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
DCHECK_LE(keep.size(), scores.size());
TORCH_DCHECK_LE(keep.size(), scores.size());
// 4. sort all (proposal, score) pairs by score from highest to lowest
// 5. take top pre_nms_topN (e.g. 6000)

View File

@ -126,7 +126,7 @@ void testMPSCNN() {
CAFFE_ENFORCE_EQ(t1.sizes(), t2.sizes());
for (auto i = 0; i < t1.size(); ++i) {
// FP16 <-> FP32 round trip.
CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
TORCH_CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
}
}
}
@ -197,7 +197,7 @@ void testMPSCNN() {
CAFFE_ENFORCE_EQ(t1.size(), t2.size());
for (auto i = 0; i < t1.size(); ++i) {
// FP16 <-> FP32 round trip.
CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
TORCH_CHECK_NEAR(t1.data<float>()[i], t2.data<float>()[i], 1e-2);
}
}
}
@ -274,7 +274,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -467,7 +467,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -560,7 +560,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -651,7 +651,7 @@ void testMPSCNN() {
const float t2_i = t2.data<float>()[i];
// LOG(INFO) << "i: " << i << ", cpu: " << t1_i << ", mtl: " <<
// t2_i;
CHECK_NEAR(t1_i, t2_i, 0.7);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.7);
}
}
}
@ -763,7 +763,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -851,7 +851,7 @@ void testMPSCNN() {
const float t2_i = t2.data<float>()[i];
// LOG(INFO) << "i: " << i << ", " << "CPU: " << t1_i << ", MTL: " <<
// t2_i;
CHECK_NEAR(t1_i, t2_i, 0.01);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
}
}
}
@ -932,7 +932,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -991,7 +991,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<uint8_t>()[i];
const float t2_i = t2.data<uint8_t>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -1050,7 +1050,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<uint8_t>()[i];
const float t2_i = t2.data<uint8_t>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -1166,7 +1166,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.2);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.2);
}
}
}
@ -1264,7 +1264,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.3);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.3);
}
}
}
@ -1378,7 +1378,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -1481,7 +1481,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -1589,7 +1589,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -1713,7 +1713,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -1784,7 +1784,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.02);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.02);
}
}
@ -1849,7 +1849,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.01);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
}
}
@ -1914,7 +1914,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.01);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.01);
}
}
@ -2003,7 +2003,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.05);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.05);
}
}
@ -2057,7 +2057,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.02);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.02);
}
}
}
@ -2123,7 +2123,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -2237,7 +2237,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -2349,7 +2349,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -2428,7 +2428,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -2570,7 +2570,7 @@ void testMPSCNN() {
const float t3_i = t3.data<float>()[i / 5];
if (t3_i - HALF_MIN_VAL * 2 > 0) {
LOG(INFO) << i << " " << t1_i << " " << t2_i << " " << t3_i;
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
@ -2579,7 +2579,7 @@ void testMPSCNN() {
const float t3_i = t3.data<float>()[i];
const float t4_i = t4.data<float>()[i];
LOG(INFO) << i << " " << t3_i;
CHECK_NEAR(t3_i, t4_i, 0.1);
TORCH_CHECK_NEAR(t3_i, t4_i, 0.1);
}
}
@ -2634,7 +2634,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -2875,7 +2875,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -2943,7 +2943,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -3030,7 +3030,7 @@ void testMPSCNN() {
// FP16 <-> FP32 round trip, accumulation, etc.
const float t1_i = t1.data<float>()[i];
const float t2_i = t2.data<float>()[i];
CHECK_NEAR(t1_i, t2_i, 0.1);
TORCH_CHECK_NEAR(t1_i, t2_i, 0.1);
}
}
}
@ -3072,10 +3072,10 @@ void testMPSCNN() {
}
return arg->i();
};
CHECK_EQ(rc(0), 1);
CHECK_EQ(rc(1), 2);
CHECK_EQ(rc(2), 1);
CHECK_EQ(rc(3), 1);
TORCH_CHECK_EQ(rc(0), 1);
TORCH_CHECK_EQ(rc(1), 2);
TORCH_CHECK_EQ(rc(2), 1);
TORCH_CHECK_EQ(rc(3), 1);
}
{
@ -3117,18 +3117,18 @@ void testMPSCNN() {
auto ty = [&](size_t i) { return netdef.op(i).type(); };
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
CHECK_EQ(netdef.op_size(), 4);
CHECK_EQ(ty(0), "CopyToMPSCNN");
CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu"));
CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu"));
CHECK_EQ(ty(3), "CopyFromMPSCNN");
CHECK_EQ(i0(0), "X");
CHECK_EQ(i0(1), o0(0));
CHECK_EQ(i0(2), "X2");
CHECK_EQ(o0(2), i0(3));
CHECK_EQ(o0(3), "Y");
CHECK_EQ(netdef.external_input(0), "X");
CHECK_EQ(netdef.external_output(0), "Y");
TORCH_CHECK_EQ(netdef.op_size(), 4);
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
TORCH_CHECK_EQ(ty(1), std::string("MPSCNN") + computeOp + std::string("Relu"));
TORCH_CHECK_EQ(ty(2), std::string("MPSCNN") + computeOp + std::string("Relu"));
TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN");
TORCH_CHECK_EQ(i0(0), "X");
TORCH_CHECK_EQ(i0(1), o0(0));
TORCH_CHECK_EQ(i0(2), "X2");
TORCH_CHECK_EQ(o0(2), i0(3));
TORCH_CHECK_EQ(o0(3), "Y");
TORCH_CHECK_EQ(netdef.external_input(0), "X");
TORCH_CHECK_EQ(netdef.external_output(0), "Y");
}
}
@ -3195,18 +3195,18 @@ void testMPSCNN() {
op.add_output("Z");
}
netdef = rewriteForMetal(netdef);
CHECK_EQ(netdef.op_size(), 4);
TORCH_CHECK_EQ(netdef.op_size(), 4);
auto ty = [&](size_t i) { return netdef.op(i).type(); };
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
CHECK_EQ(ty(0), "CopyToMPSCNN");
CHECK_EQ(ty(1), "MPSCNNConvRelu");
CHECK_EQ(ty(2), "MPSCNNRelu");
CHECK_EQ(ty(3), "CopyFromMPSCNN");
CHECK_EQ(i0(1), o0(0));
CHECK_EQ(o0(1), "Z");
CHECK_EQ(i0(2), "Z");
CHECK_EQ(o0(2), i0(3));
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
TORCH_CHECK_EQ(ty(1), "MPSCNNConvRelu");
TORCH_CHECK_EQ(ty(2), "MPSCNNRelu");
TORCH_CHECK_EQ(ty(3), "CopyFromMPSCNN");
TORCH_CHECK_EQ(i0(1), o0(0));
TORCH_CHECK_EQ(o0(1), "Z");
TORCH_CHECK_EQ(i0(2), "Z");
TORCH_CHECK_EQ(o0(2), i0(3));
}
{
@ -3235,21 +3235,21 @@ void testMPSCNN() {
op.add_output("Z");
}
netdef = rewriteForMetal(netdef);
CHECK_EQ(netdef.op_size(), 5);
TORCH_CHECK_EQ(netdef.op_size(), 5);
auto ty = [&](size_t i) { return netdef.op(i).type(); };
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
CHECK_EQ(ty(0), "CopyToMPSCNN");
CHECK_EQ(ty(1), "MPSCNNConv");
CHECK_EQ(ty(2), "MPSCNNRelu");
CHECK_EQ(ty(3), "MPSCNNRelu");
CHECK_EQ(ty(4), "CopyFromMPSCNN");
CHECK_EQ(i0(1), o0(0));
CHECK_EQ(o0(1), "Y");
CHECK_EQ(i0(2), o0(1));
CHECK_EQ(o0(2), "Z");
CHECK_EQ(i0(3), o0(1));
CHECK_EQ(o0(3), i0(4));
TORCH_CHECK_EQ(ty(0), "CopyToMPSCNN");
TORCH_CHECK_EQ(ty(1), "MPSCNNConv");
TORCH_CHECK_EQ(ty(2), "MPSCNNRelu");
TORCH_CHECK_EQ(ty(3), "MPSCNNRelu");
TORCH_CHECK_EQ(ty(4), "CopyFromMPSCNN");
TORCH_CHECK_EQ(i0(1), o0(0));
TORCH_CHECK_EQ(o0(1), "Y");
TORCH_CHECK_EQ(i0(2), o0(1));
TORCH_CHECK_EQ(o0(2), "Z");
TORCH_CHECK_EQ(i0(3), o0(1));
TORCH_CHECK_EQ(o0(3), i0(4));
}
{
@ -3277,14 +3277,14 @@ void testMPSCNN() {
auto ty = [&](size_t i) { return netdef.op(i).type(); };
auto i0 = [&](size_t i) { return netdef.op(i).input(0); };
auto o0 = [&](size_t i) { return netdef.op(i).output(0); };
CHECK_EQ(netdef.op_size(), 3);
CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess");
CHECK_EQ(ty(1), "MPSCNNRelu");
CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess");
CHECK_EQ(i0(0), "X");
CHECK_EQ(i0(1), o0(0));
CHECK_EQ(i0(2), o0(1));
CHECK_EQ(o0(2), "Z");
TORCH_CHECK_EQ(netdef.op_size(), 3);
TORCH_CHECK_EQ(ty(0), "MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocess");
TORCH_CHECK_EQ(ty(1), "MPSCNNRelu");
TORCH_CHECK_EQ(ty(2), "MPSCNNBRGNCHWCToPackedInt8BGRAStylizerDeprocess");
TORCH_CHECK_EQ(i0(0), "X");
TORCH_CHECK_EQ(i0(1), o0(0));
TORCH_CHECK_EQ(i0(2), o0(1));
TORCH_CHECK_EQ(o0(2), "Z");
}
LOG(INFO) << "All MPSCNN tests passed.";
}
@ -3296,12 +3296,12 @@ NetDef truncateAfter(NetDef def, size_t idx) {
for (auto i = 0; i < toRemove; ++i) {
def.mutable_op()->RemoveLast();
}
CHECK_EQ(def.op_size(), idx + 1);
TORCH_CHECK_EQ(def.op_size(), idx + 1);
return def;
}
NetDef addMPSCNNCopyFinalizer(NetDef def) {
CHECK_GE(def.op_size(), 1);
TORCH_CHECK_GE(def.op_size(), 1);
const auto name = def.mutable_op(def.op_size() - 1)->output(0);
def.mutable_op(def.op_size() - 1)->set_output(0, "METAL_COPIER");
{
@ -3315,7 +3315,7 @@ NetDef addMPSCNNCopyFinalizer(NetDef def) {
void compareModels(const NetDef& initNet, NetDef predictNet) {
auto* arg = predictNet.mutable_op(0)->mutable_arg(0);
CHECK_EQ(arg->name(), "noise_std");
TORCH_CHECK_EQ(arg->name(), "noise_std");
arg->set_f(0.000001);
NetDef metalPredictNet;
@ -3365,7 +3365,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
{
const auto& mt = mws.GetBlob(name)->Get<TensorCPU>();
const auto& ct = cws.GetBlob(name)->Get<TensorCPU>();
CHECK_EQ(mt.sizes(), ct.sizes());
TORCH_CHECK_EQ(mt.sizes(), ct.sizes());
for (auto j = 0; j < mt.size(); ++j) {
if (mt.IsType<float>()) {
if (j < 10) {
@ -3373,7 +3373,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
<< ", CPU: " << ct.data<float>()[j]
<< ", MTL: " << mt.data<float>()[j];
}
CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
TORCH_CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
} else {
CHECK(mt.IsType<uint8_t>());
if (j < 10) {
@ -3381,7 +3381,7 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
<< ", CPU: " << ct.data<uint8_t>()[j]
<< ", MTL: " << mt.data<uint8_t>()[j];
}
CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
TORCH_CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
}
}
}
@ -3428,7 +3428,7 @@ void verifyRewrite(
LOG(INFO) << "One of the operator failed.";
return;
}
// CHECK_EQ(mt.sizes(), ct.sizes());
// TORCH_CHECK_EQ(mt.sizes(), ct.sizes());
for (auto j = 0; j < fmin(mt.size(), ct.size()); ++j) {
if (mt.IsType<float>()) {
if (j < 10) {
@ -3437,7 +3437,7 @@ void verifyRewrite(
<< ", MTL: " << mt.data<float>()[j];
}
// Disabling check for now because of precision issues
// CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
// TORCH_CHECK_NEAR(mt.data<float>()[j], ct.data<float>()[j], 5);
} else {
LOG(INFO) << "Type uint8_t";
CHECK(mt.IsType<uint8_t>());
@ -3447,7 +3447,7 @@ void verifyRewrite(
<< ", MTL: " << mt.data<uint8_t>()[j];
}
// Disabling check for now.
// CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
// TORCH_CHECK_NEAR(mt.data<uint8_t>()[j], ct.data<uint8_t>()[j], 5);
}
}
}

View File

@ -170,7 +170,7 @@ void filterNormalization11(const TensorCPU& WQ, TensorCPU* WQN) {
for (auto j = 0; j < WQs; ++j) {
bitSum += __builtin_popcount(WQdata[f * WQs + j]);
}
DCHECK_LE(bitSum, WQbits);
TORCH_DCHECK_LE(bitSum, WQbits);
WQNdata[f] = 2 * bitSum - WQbits;
}
}

View File

@ -19,7 +19,7 @@ inline void quantize2bNeon(size_t QC,
float offset,
float inter_center_distance,
std::array<uint8_t*, k2b1bXBits> XQdata) {
DCHECK_EQ(QC % 8, 0);
TORCH_DCHECK_EQ(QC % 8, 0);
const auto offset_plus_2_inter_center_distance = vdupq_n_f32(offset + 2 * inter_center_distance);
const auto offset_plus_inter_center_distance = vdupq_n_f32(offset + inter_center_distance);
const auto offset_ = vdupq_n_f32(offset);
@ -291,7 +291,7 @@ void qgess_packed(const uint8_t* __restrict__ Ablock,
F&& f) {
static_assert(kUnrollN % 8 == 0, "");
static_assert(TileDepthBytes == 16, "");
DCHECK_EQ(QK % 16, 0);
TORCH_DCHECK_EQ(QK % 16, 0);
uint16x8_t acc[kUnrollM][kUnrollN / 8];
for (size_t mm = 0; mm < kUnrollM; ++mm) {
for (size_t nn = 0; nn < kUnrollN / 8; ++nn) {

View File

@ -19,7 +19,7 @@ void conv(const ConvArgs& args,
(X.dim32(1) - KH + args.pad_t + args.pad_b) / args.stride_h + 1,
(X.dim32(2) - KW + args.pad_l + args.pad_r) / args.stride_w + 1,
W.dim32(0));
CHECK_EQ(W.dim32(3), X.dim32(3));
TORCH_CHECK_EQ(W.dim32(3), X.dim32(3));
const auto OH = Y->dim32(1);
const auto OW = Y->dim32(2);
const auto OC = Y->dim32(3);
@ -155,7 +155,7 @@ inline void gemmNT(int M, int N, int K, const float* A, const float* B, float* C
}
inline void qgemmNT(int M, int N, int K, const uint8_t* A, const uint8_t* B, float* C) {
CHECK_EQ(K % 8, 0);
TORCH_CHECK_EQ(K % 8, 0);
const int QK = K / 8;
for (auto m = 0; m < M; ++m) {
for (auto n = 0; n < N; ++n) {

View File

@ -12,7 +12,7 @@ void APMeterOp<float, CPUContext>::BufferPredictions(
// Initialize the buffer
buffers_.resize(D, std::vector<BufferDataType>(buffer_size_));
}
DCHECK_EQ(buffers_.size(), D);
TORCH_DCHECK_EQ(buffers_.size(), D);
// Fill at most buffer_size_ data at a time, so truncate the input if needed
if (N > buffer_size_) {
@ -48,12 +48,12 @@ bool APMeterOp<float, CPUContext>::RunOnDevice() {
auto& label = Input(LABEL);
// Check dimensions
DCHECK_EQ(X.dim(), 2);
TORCH_DCHECK_EQ(X.dim(), 2);
int N = X.dim32(0);
int D = X.dim32(1);
DCHECK_EQ(label.dim(), 2);
DCHECK_EQ(label.dim32(0), N);
DCHECK_EQ(label.dim32(1), D);
TORCH_DCHECK_EQ(label.dim(), 2);
TORCH_DCHECK_EQ(label.dim32(0), N);
TORCH_DCHECK_EQ(label.dim32(1), D);
auto* Y = Output(0, {D}, at::dtype<float>());
const auto* Xdata = X.data<float>();

View File

@ -124,8 +124,8 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
if (K > 1) {
TileArrayIntoVector(lambda1_ptr, D, K, &b.lambda1_);
TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_);
DCHECK_EQ(K * D, b.lambda1_.size());
DCHECK_EQ(K * D, b.lambda2_.size());
TORCH_DCHECK_EQ(K * D, b.lambda1_.size());
TORCH_DCHECK_EQ(K * D, b.lambda2_.size());
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
BoxCoxNonzeroLambda(
K * D,
@ -144,7 +144,7 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
int64_t i = 0;
if (K > 1) {
TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_z_);
DCHECK_EQ(K * D, b.lambda2_z_.size());
TORCH_DCHECK_EQ(K * D, b.lambda2_z_.size());
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
BoxCoxZeroLambda(
K * D, data_ptr, b.lambda2_z_.data(), k_eps, output_ptr);
@ -176,9 +176,9 @@ bool BatchBoxCoxOp<CPUContext>::DoRunWithType() {
zeros_.resize(D - n);
TileIndicesInPlace(&nonzeros_, D, K);
TileIndicesInPlace(&zeros_, D, K);
DCHECK_EQ(nonzeros_.size(), b.lambda1_.size());
DCHECK_EQ(nonzeros_.size(), b.lambda2_.size());
DCHECK_EQ(zeros_.size(), b.lambda2_z_.size());
TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda1_.size());
TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda2_.size());
TORCH_DCHECK_EQ(zeros_.size(), b.lambda2_z_.size());
for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) {
BoxCoxMixedLambda(
data_ptr,

View File

@ -100,7 +100,7 @@ bool BBoxTransformOp<float, CPUContext>::RunOnDevice() {
CAFFE_ENFORCE_EQ(iminfo_in.dim32(1), 3);
const int batch_size = iminfo_in.dim32(0);
DCHECK_EQ(weights_.size(), 4);
TORCH_DCHECK_EQ(weights_.size(), 4);
Eigen::Map<const ERArrXXf> boxes0(
roi_in.data<float>(), roi_in.dim32(0), roi_in.dim32(1));

View File

@ -64,7 +64,7 @@ TEST(BooleanUnmaskTest, Test) {
auto& unmasked_data = unmasked_data_blob->Get<TensorCPU>();
EXPECT_EQ(unmasked_data.numel(), 1);
CHECK_EQ(unmasked_data.data<float>()[0], 1.0f);
TORCH_CHECK_EQ(unmasked_data.data<float>()[0], 1.0f);
}
} // namespace caffe2

View File

@ -161,14 +161,14 @@ const auto& tscores = Input(0);
// Pick the first `detections_per_im_` boxes with highest scores
auto all_scores_sorted = get_all_scores_sorted();
DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
TORCH_DCHECK_GT(all_scores_sorted.size(), detections_per_im_);
// Reconstruct keeps from `all_scores_sorted`
for (auto& cur_keep : keeps) {
cur_keep.clear();
}
for (int i = 0; i < detections_per_im_; i++) {
DCHECK_GT(all_scores_sorted.size(), i);
TORCH_DCHECK_GT(all_scores_sorted.size(), i);
auto& cur = all_scores_sorted[i];
keeps[cur.first].push_back(cur.second);
}

View File

@ -272,8 +272,8 @@ bool MakeTwoClassOp<float, CPUContext>::RunOnDevice() {
const auto* Xdata = X.data<float>();
auto* Ydata = Y->template mutable_data<float>();
for (int64_t i = 0; i < N; ++i) {
DCHECK_GE(Xdata[i], 0.0);
DCHECK_LE(Xdata[i], 1.0);
TORCH_DCHECK_GE(Xdata[i], 0.0);
TORCH_DCHECK_LE(Xdata[i], 1.0);
Ydata[i * 2] = 1.0 - Xdata[i];
Ydata[i * 2 + 1] = Xdata[i];
}

View File

@ -308,7 +308,7 @@ void DeformConvOpBase<DType, Context>::DeformableIm2col(
at::IntArrayRef im_shape,
at::IntArrayRef col_shape,
DType* data_col) {
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
CAFFE_ENFORCE_EQ(pad_t(), pad_b());
CAFFE_ENFORCE_EQ(pad_l(), pad_r());
const int pad_h = pad_t();
@ -444,7 +444,7 @@ void DeformConvOpBase<DType, Context>::DeformableCol2im(
index_t channel_per_deformable_group = im_shape[1] / deformable_group_;
index_t num_kernels = size_from_dim_(0, col_shape);
// num_axes should be smaller than block size
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
// NOLINT_NEXT_LINE(whitespace/operators)
@ -592,7 +592,7 @@ void DeformConvOpBase<DType, Context>::DeformableCol2imCoord(
kernel_w() * deformable_group_;
index_t channel_per_deformable_group = col_shape[0] / deformable_group_;
// num_axes should be smaller than block size
CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
TORCH_CHECK_LT(2, CAFFE_CUDA_NUM_THREADS);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
// NOLINT_NEXT_LINE(whitespace/operators)

View File

@ -444,7 +444,7 @@ class GaussianFillOp final : public FillerOp<Context> {
: FillerOp<Context>(std::forward<Args>(args)...),
mean_(this->template GetSingleArgument<float>("mean", 0)),
std_(this->template GetSingleArgument<float>("std", 1)) {
DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
TORCH_DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
}
bool Fill(Tensor* output) override {

View File

@ -73,7 +73,7 @@ class FullyConnectedOp final : public Operator<Context> {
Y_shape_cache_ = X.sizes().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
TORCH_DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
Y_shape_cache_.resize(canonical_axis + 1);
Y_shape_cache_[canonical_axis] = N;
auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());

View File

@ -11,7 +11,7 @@ namespace {
size_t ComputeStartIndex(
const TensorCPU& tensor,
const std::vector<int>& index) {
DCHECK_EQ(index.size(), tensor.dim());
TORCH_DCHECK_EQ(index.size(), tensor.dim());
size_t ret = 0;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
@ -27,7 +27,7 @@ template <class T>
utils::ConstTensorView<T> GetSubTensorView(
const TensorCPU& tensor,
int dim0_start_index) {
DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T));
TORCH_DCHECK_EQ(tensor.dtype().itemsize(), sizeof(T));
if (tensor.numel() == 0) {
return utils::ConstTensorView<T>(nullptr, {});
@ -244,7 +244,7 @@ void GenerateProposalsOp<CPUContext>::ProposalsForOneImage(
// 3. remove predicted boxes with either height or width < min_size
auto keep =
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
DCHECK_LE(keep.size(), scores_sorted.size());
TORCH_DCHECK_LE(keep.size(), scores_sorted.size());
// 6. apply loose nms (e.g. threshold = 0.7)
// 7. take after_nms_topN (e.g. 300)

View File

@ -28,7 +28,7 @@ class ConstTensorView {
return dims_;
}
int dim(int i) const {
DCHECK_LE(i, dims_.size());
TORCH_DCHECK_LE(i, dims_.size());
return dims_[i];
}
const T* data() const {

View File

@ -316,7 +316,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0GPU) {
// Add angle in bbox deltas
int num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.
@ -516,7 +516,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedGPU) {
// Add angle in bbox deltas
int num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.

View File

@@ -494,7 +494,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
// Add angle in bbox deltas
auto num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.
@@ -667,7 +667,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
// Add angle in bbox deltas
auto num_boxes = scores.size();
CHECK_EQ(bbx.size() / 4, num_boxes);
TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);
vector<float> bbx_with_angle(num_boxes * box_dim);
// bbx (deltas) is in shape (A * 4, H, W). Insert angle delta
// at each spatial location for each anchor.

View File
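
All four rotated-proposal test hunks share one precondition: bbx carries exactly four deltas per box, so bbx.size() / 4 must equal num_boxes before an angle delta can be spliced in to widen each box to box_dim = 5. A simplified sketch of that widening (per-box layout for readability; the actual tests interleave along the (A * 4, H, W) channel layout):

  #include <cstddef>
  #include <vector>
  #include <c10/util/Logging.h>

  std::vector<float> add_angle_delta(
      const std::vector<float>& bbx, size_t num_boxes, float angle_delta) {
    TORCH_CHECK_EQ(bbx.size() / 4, num_boxes);  // exactly 4 deltas per box
    const size_t box_dim = 5;
    std::vector<float> out(num_boxes * box_dim);
    for (size_t i = 0; i < num_boxes; ++i) {
      for (size_t j = 0; j < 4; ++j) {
        out[i * box_dim + j] = bbx[i * 4 + j];  // original (dx, dy, dw, dh)
      }
      out[i * box_dim + 4] = angle_delta;  // the spliced-in angle
    }
    return out;
  }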

@@ -36,7 +36,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
}
bool Fill(Tensor* output) override {
DCHECK_EQ(output->numel(), values_.numel())
TORCH_DCHECK_EQ(output->numel(), values_.numel())
<< "output size: " << output->numel()
<< " given size: " << values_.numel();
auto* data = output->template mutable_data<uint8_t>();
@@ -51,7 +51,7 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
private:
void Extract() {
auto source_values = this->template GetRepeatedArgument<string>("values");
DCHECK_EQ(source_values.size(), 1)
TORCH_DCHECK_EQ(source_values.size(), 1)
<< "expected size: 1 "
<< " given size: " << source_values.size();

View File

@@ -7,7 +7,7 @@ bool LRNOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
// Note(Yangqing): this one is copied from my Caffe implementation.
auto& X = Input(0);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int C = X.dim32(1);
const int H = X.dim32(2);
@@ -81,7 +81,7 @@ bool LRNOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
// variants have I written...?
auto& X = Input(0);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int H = X.dim32(1);
const int W = X.dim32(2);
@@ -135,7 +135,7 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
auto& Y = Input(1);
auto& dY = Input(2);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int C = X.dim32(1);
const int H = X.dim32(2);
@@ -143,8 +143,8 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
const int image_size = C * H * W;
// Loosely checking the size, assuming that the shapes will be the same as
// long as the sizes check out.
DCHECK_EQ(X.numel(), Y.numel());
DCHECK_EQ(X.numel(), dY.numel());
TORCH_DCHECK_EQ(X.numel(), Y.numel());
TORCH_DCHECK_EQ(X.numel(), dY.numel());
auto* dX = Output(0, X.sizes(), at::dtype<float>());
const float* Xdata = X.data<float>();
@@ -248,7 +248,7 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
auto& Y = Input(1);
auto& dY = Input(2);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int H = X.dim32(1);
const int W = X.dim32(2);
@@ -257,8 +257,8 @@ bool LRNGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
const float* Xdata = X.data<float>();
// Loosely checking the size, assuming that the shapes will be the same as
// long as the sizes check out.
DCHECK_EQ(X.numel(), Y.numel());
DCHECK_EQ(X.numel(), dY.numel());
TORCH_DCHECK_EQ(X.numel(), Y.numel());
TORCH_DCHECK_EQ(X.numel(), dY.numel());
auto* dX = Output(0, X.sizes(), at::dtype<float>());
if (!scale_) {
scale_ = &local_scale_tensor_;

View File
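
The LRN gradient hunks repeat the same hedge on the CPU and CUDA paths: as their comment says, they only compare element counts (X, Y and dY must all have the same numel), trusting that shapes produced in lockstep agree whenever the sizes do, and as TORCH_DCHECK_* calls the comparisons cost nothing in release builds. The pattern in isolation:

  #include <cstdint>
  #include <c10/util/Logging.h>

  // Loose size agreement between forward values and the incoming gradient.
  void check_grad_sizes(int64_t x_numel, int64_t y_numel, int64_t dy_numel) {
    TORCH_DCHECK_EQ(x_numel, y_numel);
    TORCH_DCHECK_EQ(x_numel, dy_numel);
  }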

@@ -177,7 +177,7 @@ template<>
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& X = Input(0);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int C = X.dim32(1);
const int H = X.dim32(2);
@@ -214,7 +214,7 @@ template<>
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& X = Input(0);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int H = X.dim32(1);
const int W = X.dim32(2);
@@ -252,15 +252,15 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& Y = Input(1);
auto& dY = Input(2);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int C = X.dim32(1);
const int H = X.dim32(2);
const int W = X.dim32(3);
// Loosely checking the size, assuming that the shapes will be the same as
// long as the sizes check out.
DCHECK_EQ(X.numel(), Y.numel());
DCHECK_EQ(X.numel(), dY.numel());
TORCH_DCHECK_EQ(X.numel(), Y.numel());
TORCH_DCHECK_EQ(X.numel(), dY.numel());
auto* dX = Output(0, X.sizes(), at::dtype<float>());
const float* Xdata = X.data<float>();
@@ -295,7 +295,7 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& Y = Input(1);
auto& dY = Input(2);
DCHECK_EQ(X.dim(), 4);
TORCH_DCHECK_EQ(X.dim(), 4);
const int N = X.dim32(0);
const int H = X.dim32(1);
const int W = X.dim32(2);
@@ -303,8 +303,8 @@ bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
const float* Xdata = X.data<float>();
// Loosely checking the size, assuming that the shapes will be the same as
// long as the sizes check out.
DCHECK_EQ(X.numel(), Y.numel());
DCHECK_EQ(X.numel(), dY.numel());
TORCH_DCHECK_EQ(X.numel(), Y.numel());
TORCH_DCHECK_EQ(X.numel(), dY.numel());
auto* dX = Output(0, X.sizes(), at::dtype<float>());
if (!scale_) {
scale_ = &local_scale_tensor_;

View File

@@ -22,10 +22,10 @@ class LRNOpBase : public Operator<Context> {
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
pre_pad_((size_ - 1) / 2) {
DCHECK_GT(size_, 0);
DCHECK_EQ(size_ % 2, 1);
DCHECK_GT(alpha_, 0);
DCHECK_GT(beta_, 0);
TORCH_DCHECK_GT(size_, 0);
TORCH_DCHECK_EQ(size_ % 2, 1);
TORCH_DCHECK_GT(alpha_, 0);
TORCH_DCHECK_GT(beta_, 0);
}
bool RunOnDevice() override {

View File
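
The LRNOpBase hunk above encodes the local-response-normalization hyper-parameter contract as debug checks: a positive, odd window size (so the (size_ - 1) / 2 pre-padding centers it on each element) and strictly positive alpha and beta. The same validation as a standalone sketch:

  #include <c10/util/Logging.h>

  void validate_lrn_args(int size, float alpha, float beta) {
    TORCH_DCHECK_GT(size, 0);
    TORCH_DCHECK_EQ(size % 2, 1);  // odd window => symmetric padding
    TORCH_DCHECK_GT(alpha, 0);
    TORCH_DCHECK_GT(beta, 0);
    const int pre_pad = (size - 1) / 2;  // e.g. size 5 -> 2 on each side
    (void)pre_pad;
  }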

@@ -11,11 +11,11 @@ class GetGPUMemoryUsageOp final : public Operator<CUDAContext> {
~GetGPUMemoryUsageOp() override {}
bool RunOnDevice() override {
CHECK_EQ(InputSize(), 0);
CHECK_EQ(OutputSize(), 1);
TORCH_CHECK_EQ(InputSize(), 0);
TORCH_CHECK_EQ(OutputSize(), 1);
std::vector<long> total_by_gpu = CUDAContext::TotalMemoryByGpu();
std::vector<long> max_by_gpu = CUDAContext::MaxMemoryByGpu();
CHECK_EQ(total_by_gpu.size(), max_by_gpu.size());
TORCH_CHECK_EQ(total_by_gpu.size(), max_by_gpu.size());
auto* stats = Output(0, {2, static_cast<int64_t>(total_by_gpu.size())}, at::dtype<long>());

View File

@@ -7,13 +7,13 @@ bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(PREDICTION);
auto& label = Input(LABEL);
DCHECK_EQ(X.dim(), 2);
TORCH_DCHECK_EQ(X.dim(), 2);
// amount, number of instances
int N = X.dim32(0);
// dimension, number of classes
int D = X.dim32(1);
DCHECK_EQ(label.dim(), 1);
DCHECK_EQ(label.dim32(0), N);
TORCH_DCHECK_EQ(label.dim(), 1);
TORCH_DCHECK_EQ(label.dim32(0), N);
auto* Y0 = Output(0, {D}, at::dtype<float>());
auto* Y1 = Output(1, {D}, at::dtype<int>());
@@ -34,7 +34,7 @@ bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
}
}
int labelid = labeldata[i];
DCHECK_LT(labelid, D);
TORCH_DCHECK_LT(labelid, D);
if (maxid == labelid) {
accuracies[labelid]++;
}

View File
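
The MultiClassAccuracy hunks pin down the expected shapes (an N x D prediction matrix and a length-N label vector) and then guard each label with TORCH_DCHECK_LT(labelid, D) before it indexes the per-class counters. A hedged sketch of the per-class accuracy loop the checks protect:

  #include <vector>
  #include <c10/util/Logging.h>

  // correct/total must each hold D entries, one per class.
  void per_class_accuracy(const float* X, const int* label, int N, int D,
                          std::vector<int>& correct, std::vector<int>& total) {
    for (int i = 0; i < N; ++i) {
      int maxid = 0;  // argmax over the D class scores of instance i
      for (int j = 1; j < D; ++j) {
        if (X[i * D + j] > X[i * D + maxid]) {
          maxid = j;
        }
      }
      const int labelid = label[i];
      TORCH_DCHECK_LT(labelid, D);  // label must name a valid class
      if (maxid == labelid) {
        correct[labelid]++;
      }
      total[labelid]++;
    }
  }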

@@ -40,13 +40,13 @@ bool MultiClassAccuracyOp<float, CUDAContext>::RunOnDevice() {
auto& label = Input(LABEL);
DCHECK_EQ(X.dim(), 2);
TORCH_DCHECK_EQ(X.dim(), 2);
// amount, number of instances
int N = X.dim32(0);
// dimension, number of classes
int D = X.dim32(1);
DCHECK_EQ(label.dim(), 1);
DCHECK_EQ(label.dim32(0), N);
TORCH_DCHECK_EQ(label.dim(), 1);
TORCH_DCHECK_EQ(label.dim32(0), N);
auto* Y0 = Output(0, {D}, at::dtype<float>());
auto* Y1 = Output(1, {D}, at::dtype<int>());

View File

@@ -52,12 +52,12 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
// Set up the symbols for the local workspace.
for (const string& name : def.input()) {
local_input_blobs_.push_back(local_ws_.CreateBlob(name));
CHECK_NOTNULL(local_input_blobs_.back());
TORCH_CHECK_NOTNULL(local_input_blobs_.back());
}
base_op_ = CreateOperator(base_def_, &local_ws_);
for (const string& name : def.output()) {
local_output_blobs_.push_back(local_ws_.GetBlob(name));
CHECK_NOTNULL(local_output_blobs_.back());
TORCH_CHECK_NOTNULL(local_output_blobs_.back());
}
}

View File
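
The fallback hunk above bridges a GPU op to a CPU implementation through a private workspace: every declared input and output name gets a blob in local_ws_, and TORCH_CHECK_NOTNULL turns a failed blob creation or lookup into an immediate error instead of a null dereference inside the wrapped op. A rough sketch of the wiring, with the Workspace reduced to a map:

  #include <map>
  #include <string>
  #include <vector>
  #include <c10/util/Logging.h>

  struct Blob {};

  struct MiniWorkspace {  // simplified stand-in for caffe2::Workspace
    std::map<std::string, Blob> blobs;
    Blob* CreateBlob(const std::string& name) { return &blobs[name]; }
  };

  void wire_inputs(MiniWorkspace& ws, const std::vector<std::string>& names) {
    std::vector<Blob*> local_inputs;
    for (const std::string& name : names) {
      local_inputs.push_back(ws.CreateBlob(name));
      TORCH_CHECK_NOTNULL(local_inputs.back());  // fail fast, not on first use
    }
  }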

@@ -6,7 +6,7 @@ template <>
bool PerplexityOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
DCHECK_EQ(X.dim(), 1);
TORCH_DCHECK_EQ(X.dim(), 1);
int N = X.dim32(0);
auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());

View File

@@ -21,7 +21,7 @@ template <>
bool PerplexityOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
DCHECK_EQ(X.dim(), 1);
TORCH_DCHECK_EQ(X.dim(), 1);
int N = X.dim32(0);
auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());

View File

@@ -176,7 +176,7 @@ bool PReluGradientOp<float, CPUContext>::RunOnDevice() {
CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU");
DCHECK_EQ(dY.numel(), Y.numel());
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
auto* dW = Output(1, W.sizes(), at::dtype<float>());

View File

@@ -212,7 +212,7 @@ bool PReluGradientOp<float, CUDAContext>::RunOnDevice() {
CAFFE_ENFORCE(&Y != &X, "Cannot backpropagate through an in-place PReLU");
DCHECK_EQ(dY.numel(), Y.numel());
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
auto* dW = Output(1, W.sizes(), at::dtype<float>());

View File

@@ -37,7 +37,7 @@ void Decode(
int sz = output->numel();
for (C10_UNUSED const auto i : c10::irange(sz)) {
DCHECK_LE(*code_ptr, cb_size);
TORCH_DCHECK_LE(*code_ptr, cb_size);
*out_ptr++ = cb_ptr[*code_ptr++];
}
} else {
@@ -49,7 +49,7 @@ void Decode(
CAFFE_ENFORCE_EQ(cb_size, output->numel());
auto* out_ptr = output->template mutable_data<CodebookT>();
while (gradient_ptr < gradient_end) {
DCHECK_LE(*code_ptr, cb_size);
TORCH_DCHECK_LE(*code_ptr, cb_size);
out_ptr[*code_ptr++] += *gradient_ptr++;
}
}

View File
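
The Decode hunks above show the two directions of codebook decoding, each bounds-checking every code against the codebook size: the forward pass gathers (out[i] = cb[code[i]]) and the gradient pass scatter-adds (out[code[i]] += grad[i]). A hedged sketch of both loops (with a strict bound for safe indexing, where the hunks check LE):

  #include <cstdint>
  #include <vector>
  #include <c10/util/Logging.h>

  // Forward: gather codebook entries by code.
  void decode(const std::vector<float>& cb, const std::vector<uint8_t>& codes,
              std::vector<float>& out) {
    for (size_t i = 0; i < codes.size(); ++i) {
      TORCH_DCHECK_LT(codes[i], cb.size());
      out[i] = cb[codes[i]];
    }
  }

  // Backward: scatter-add gradient rows into codebook slots.
  void decode_grad(const std::vector<uint8_t>& codes,
                   const std::vector<float>& grad, std::vector<float>& out) {
    for (size_t i = 0; i < codes.size(); ++i) {
      TORCH_DCHECK_LT(codes[i], out.size());
      out[codes[i]] += grad[i];
    }
  }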

@@ -43,7 +43,7 @@ class Int8AveragePoolOp final : public ConvPoolOpBase<CPUContext> {
Y->scale = Y_scale;
Y->zero_point = Y_zero_point;
CHECK_EQ(X.t.dim(), 4);
TORCH_CHECK_EQ(X.t.dim(), 4);
const int channels = X.t.dim32(3);
ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);

View File

@@ -42,10 +42,10 @@ class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale =
this->template GetSingleArgument<float>("Y_scale", 1.0f);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
TORCH_CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
const auto C = X.t.dim32(3);
const auto G = this->group_;

View File

@@ -20,13 +20,13 @@ class Int8ConcatOp final : public Operator<CPUContext> {
if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
// Default to C axis
axis_ = this->template GetSingleArgument<int>("axis", 3);
CHECK_GE(axis_, 0);
CHECK_LT(axis_, 4);
TORCH_CHECK_GE(axis_, 0);
TORCH_CHECK_LT(axis_, 4);
} else if (
this->template GetSingleArgument<string>("order", "") == "NCHW") {
axis_ = this->template GetSingleArgument<int>("axis", 1);
CHECK_GE(axis_, 0);
CHECK_LT(axis_, 4);
TORCH_CHECK_GE(axis_, 0);
TORCH_CHECK_LT(axis_, 4);
} else {
axis_ = this->template GetSingleArgument<int>("axis", 0);
}
@@ -39,20 +39,20 @@ class Int8ConcatOp final : public Operator<CPUContext> {
Y->zero_point = X0.zero_point;
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_offset, X0.zero_point);
CHECK_EQ(Y_scale, X0.scale);
CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
TORCH_CHECK_EQ(Y_offset, X0.zero_point);
TORCH_CHECK_EQ(Y_scale, X0.scale);
TORCH_CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
TORCH_CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
auto Y_dims = X0.t.sizes().vec();
if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
CHECK_EQ(Y_dims.size(), 4);
TORCH_CHECK_EQ(Y_dims.size(), 4);
}
for (const auto i : c10::irange(1, InputSize())) {
const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
CHECK_EQ(Xi.t.dim(), Y_dims.size());
TORCH_CHECK_EQ(Xi.t.dim(), Y_dims.size());
for (const auto j : c10::irange(Y_dims.size())) {
if (j != axis_) {
CHECK_EQ(Xi.t.size(j), Y_dims[j]);
TORCH_CHECK_EQ(Xi.t.size(j), Y_dims[j]);
}
}
Y_dims[axis_] += Xi.t.size(axis_);

View File
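
The Int8Concat hunks enforce three rules before concatenating quantized tensors: the requested output quantization must equal X0's (zero point and scale pass through untouched, so no requantization happens), the axis must be valid for the NHWC/NCHW layout, and every later input must agree with X0 on all non-concat dimensions. The dimension bookkeeping in isolation:

  #include <cstdint>
  #include <vector>
  #include <c10/util/Logging.h>

  // Y_dims starts as X0's sizes; each further input must match on every
  // axis except `axis`, whose extents accumulate into the output.
  void accumulate_concat_dims(std::vector<int64_t>& Y_dims,
                              const std::vector<int64_t>& Xi_dims, size_t axis) {
    TORCH_CHECK_EQ(Xi_dims.size(), Y_dims.size());
    for (size_t j = 0; j < Y_dims.size(); ++j) {
      if (j != axis) {
        TORCH_CHECK_EQ(Xi_dims[j], Y_dims[j]);
      }
    }
    Y_dims[axis] += Xi_dims[axis];
  }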

@@ -56,7 +56,7 @@ class Int8ConvOp final : public ConvPoolOpBase<CPUContext> {
const bool isDepthwise = this->group_ > 1 && this->group_ == M &&
this->group_ == C && KC == 1 && KH * KW == 9 && dilation_w() == 1;
CHECK_EQ(Y->t.dim32(3), M);
TORCH_CHECK_EQ(Y->t.dim32(3), M);
runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
initQNNPACK();

View File

@@ -47,14 +47,14 @@ class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
const auto IC = X.t.size(3);
CHECK_EQ(IC, W.t.size(0));
TORCH_CHECK_EQ(IC, W.t.size(0));
const auto KH = W.t.size(1);
const auto KW = W.t.size(2);
const auto OC = W.t.size(3);
auto sizes = ConvTransposeUnpoolBase<CPUContext>::GetOutputSize(X.t, OC);
ReinitializeTensor(&(Y->t), sizes, at::dtype<uint8_t>().device(CPU));
CHECK_EQ(OC, Y->t.size(3));
TORCH_CHECK_EQ(OC, Y->t.size(3));
runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
initQNNPACK();

View File

@@ -39,8 +39,8 @@ class Int8FCOp final : public Operator<CPUContext> {
// (NxHxW)xC == MxK x (NxK) -> MxN
const auto K = X.t.size_from_dim(1);
const auto N = W.t.size(0);
CHECK_EQ(K, W.t.size(1));
CHECK_EQ(N, B.t.numel());
TORCH_CHECK_EQ(K, W.t.size(1));
TORCH_CHECK_EQ(N, B.t.numel());
const auto M = X.t.numel() / K;
ReinitializeTensor(&Y->t, {M, N}, at::dtype<uint8_t>().device(CPU));

View File
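
The Int8FC hunk is a pure shape contract for the quantized GEMM its comment describes, (NxHxW)xC == MxK x (NxK) -> MxN: with K = X.size_from_dim(1) and N = W.size(0), the weight's inner dimension must equal K and the bias needs one entry per output channel, giving an M x N output where M = X.numel() / K. A worked instance under assumed sizes:

  // Assumed example: X is 2 x 3 x 4 x 8 (NHWC), W is 16 x 96, B has 16 entries.
  // K = 3 * 4 * 8 = 96            // all X dims after the first, flattened
  // N = 16                        // output channels, i.e. rows of W
  // TORCH_CHECK_EQ(K, W.size(1)); // 96 == 96
  // TORCH_CHECK_EQ(N, B.numel()); // 16 == 16
  // M = X.numel() / K = 192 / 96 = 2, so Y is M x N = 2 x 16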

@@ -22,8 +22,8 @@ class Int8FlattenOp : public Operator<CPUContext> {
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
Y->scale = Y_scale;
Y->zero_point = Y_offset;
CAFFE_ENFORCE_GE(

View File

@@ -46,7 +46,7 @@ class Int8GivenTensorFillOp final : public Operator<CPUContext> {
}
bool Fill(Int8TensorCPU* output) {
DCHECK_EQ(output->t.numel(), values_.numel())
TORCH_DCHECK_EQ(output->t.numel(), values_.numel())
<< "output size: " << output->t.numel()
<< " given size: " << values_.numel();
auto* data = output->t.template mutable_data<uint8_t>();
@@ -98,7 +98,7 @@ class Int8GivenIntTensorFillOp final : public Operator<CPUContext> {
}
bool Fill(Int8TensorCPU* output) {
DCHECK_EQ(output->t.numel(), values_.numel())
TORCH_DCHECK_EQ(output->t.numel(), values_.numel())
<< "output size: " << output->t.numel()
<< " given size: " << values_.numel();
auto* data = output->t.template mutable_data<int32_t>();

View File

@@ -38,8 +38,8 @@ class Int8LeakyReluOp final : public Operator<CPUContext> {
const int32_t Y_zero_point =
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_GE(Y_zero_point, std::numeric_limits<uint8_t>::min());
CHECK_LE(Y_zero_point, std::numeric_limits<uint8_t>::max());
TORCH_CHECK_GE(Y_zero_point, std::numeric_limits<uint8_t>::min());
TORCH_CHECK_LE(Y_zero_point, std::numeric_limits<uint8_t>::max());
/*
* Record quantization parameters for the input, because if the op is

View File

@@ -38,10 +38,10 @@ class Int8MaxPoolOp final : public ConvPoolOpBase<CPUContext> {
const int32_t Y_zero_point =
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_zero_point, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_zero_point, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
CHECK_EQ(X.t.dim(), 4);
TORCH_CHECK_EQ(X.t.dim(), 4);
const int channels = X.t.dim32(3);
ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);

View File

@@ -34,14 +34,14 @@ class Int8ReluOp final : public Operator<CPUContext> {
Y->t.ResizeLike(X.t);
Y->scale = X.scale;
Y->zero_point = X.zero_point;
CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
TORCH_CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
TORCH_CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
const int32_t Y_offset =
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale =
this->template GetSingleArgument<float>("Y_scale", 1.0f);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
initQNNPACK();

View File

@@ -32,8 +32,8 @@ class Int8ReshapeOp final : public ReshapeOp<uint8_t, CPUContext> {
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
Y->scale = Y_scale;
Y->zero_point = Y_offset;
DoRunWithTypeImpl<T>(X.t, &Y->t);

View File

@@ -49,8 +49,8 @@ class Int8ResizeNearestOp final : public Operator<CPUContext> {
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
const uint8_t* Xdata = X.t.data<uint8_t>();
uint8_t* Ydata = Y->t.mutable_data<uint8_t>();

View File

@@ -281,10 +281,10 @@ class Int8RoIAlignOp final : public Operator<CPUContext> {
sampling_ratio_(
this->template GetSingleArgument<int>("sampling_ratio", -1)),
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
DCHECK_GT(spatial_scale_, 0);
DCHECK_GT(pooled_height_, 0);
DCHECK_GT(pooled_width_, 0);
DCHECK_GE(sampling_ratio_, 0);
TORCH_DCHECK_GT(spatial_scale_, 0);
TORCH_DCHECK_GT(pooled_height_, 0);
TORCH_DCHECK_GT(pooled_width_, 0);
TORCH_DCHECK_GE(sampling_ratio_, 0);
// only supports NHWC
CAFFE_ENFORCE(order_ == StorageOrder::NHWC);
}

View File

@@ -33,8 +33,8 @@ class Int8SigmoidOp final : public Operator<CPUContext> {
const int32_t Y_zero_point =
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_zero_point, 0);
CHECK_EQ(Y_scale, 1.0f / 256.0f);
TORCH_CHECK_EQ(Y_zero_point, 0);
TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f);
/*
* Record quantization parameters for the input, because if the op is

View File

@@ -76,8 +76,8 @@ class Int8SliceOp final : public SliceOp<CPUContext> {
auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_offset, X.zero_point);
CHECK_EQ(Y_scale, X.scale);
TORCH_CHECK_EQ(Y_offset, X.zero_point);
TORCH_CHECK_EQ(Y_scale, X.scale);
Y->scale = Y_scale;
Y->zero_point = Y_offset;

View File

@@ -34,8 +34,8 @@ class Int8SoftmaxOp final : public Operator<CPUContext> {
const int32_t Y_zero_point =
this->template GetSingleArgument<int>("Y_zero_point", 0);
const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
CHECK_EQ(Y_zero_point, 0);
CHECK_EQ(Y_scale, 1.0f / 256.0f);
TORCH_CHECK_EQ(Y_zero_point, 0);
TORCH_CHECK_EQ(Y_scale, 1.0f / 256.0f);
/*
* Record quantization parameters for the input, because if the op is

View File
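
The Int8Sigmoid and Int8Softmax hunks above both hard-require Y_zero_point == 0 and Y_scale == 1/256 rather than honoring arbitrary output quantization: their outputs live in [0, 1), so the 256 uint8 levels map as real = q / 256 with no offset. The arithmetic, for illustration:

  // With scale = 1/256 and zero_point = 0: real = (q - 0) * (1/256)
  //   q = 0   -> 0.0
  //   q = 128 -> 0.5
  //   q = 255 -> 0.99609375   // largest representable value, just under 1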

@@ -341,7 +341,7 @@ TEST(Int8, SumRelu) {
}
void setq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
TORCH_CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
for (auto i = 0U; i < vs.size(); ++i) {
uint8_t vq = std::max(
std::numeric_limits<uint8_t>::min(),
@@ -354,7 +354,7 @@ void setq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
}
void biassetq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
TORCH_CHECK_EQ(vs.size(), static_cast<size_t>(dst->t.numel()));
for (auto i = 0U; i < vs.size(); ++i) {
int32_t vq = std::max(
std::numeric_limits<int32_t>::min(),

View File

@@ -91,8 +91,8 @@ inline void QuantizeMultiplierSmallerThanOne(
q_fixed /= 2;
--*right_shift;
}
CHECK_GE(*right_shift, 0);
CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
TORCH_CHECK_GE(*right_shift, 0);
TORCH_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
*quantized_multiplier = static_cast<int32_t>(q_fixed);
}
@@ -108,8 +108,8 @@ inline void QuantizeMultiplierGreaterThanOne(
q_fixed /= 2;
++*left_shift;
}
CHECK_GE(*left_shift, 0);
CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
TORCH_CHECK_GE(*left_shift, 0);
TORCH_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
*quantized_multiplier = static_cast<int32_t>(q_fixed);
}

View File
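
The two QuantizeMultiplier hunks normalize a real-valued multiplier into 32-bit fixed point plus a shift: q_fixed is halved (and the shift adjusted) whenever rounding pushes it to 2^31, and the new always-on checks confirm the shift stayed non-negative and the mantissa fits in int32. A hedged sketch of the smaller-than-one case, with rounding details simplified:

  #include <cmath>
  #include <cstdint>
  #include <limits>
  #include <c10/util/Logging.h>

  // real in (0, 1) is encoded as quantized_multiplier * 2^(-31 - right_shift).
  void quantize_multiplier_lt1(double real, int32_t* quantized_multiplier,
                               int* right_shift) {
    int exponent = 0;
    const double mantissa = std::frexp(real, &exponent);  // mantissa in [0.5, 1)
    *right_shift = -exponent;  // exponent <= 0 for inputs below 1
    int64_t q_fixed = static_cast<int64_t>(std::llround(mantissa * (1LL << 31)));
    if (q_fixed == (1LL << 31)) {  // rounding overflowed: renormalize
      q_fixed /= 2;
      --*right_shift;
    }
    TORCH_CHECK_GE(*right_shift, 0);
    TORCH_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
    *quantized_multiplier = static_cast<int32_t>(q_fixed);
  }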

@@ -343,7 +343,7 @@ class BaseReducer {
}
void observeInput(int input, const Tensor& value, int skip_dims) {
DCHECK_EQ(0, input);
TORCH_DCHECK_EQ(0, input);
auto dims = value.sizes();
computeMeta(dims, skip_dims);
}

View File

@@ -305,7 +305,7 @@ bool SumElementsGradientOp<T, Context>::RunOnDevice()
Tensor sum_grad(Input(1), CPU);
auto* dX = Output(0, X.sizes(), at::dtype<T>());
DCHECK_EQ(sum_grad.numel(), 1);
TORCH_DCHECK_EQ(sum_grad.numel(), 1);
math::Set<T, Context>(
dX->numel(),
static_cast<T>(

View File

@@ -83,7 +83,7 @@ template <>
bool SumElementsGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto& dY = Input(1);
DCHECK_EQ(dY.numel(), 1);
TORCH_DCHECK_EQ(dY.numel(), 1);
auto* dX = Output(0, X.sizes(), at::dtype<float>());
SumElementsGradientKernel<float>

View File

@@ -25,10 +25,10 @@ class RoIAlignGradientOp final : public Operator<Context> {
sampling_ratio_(
this->template GetSingleArgument<int>("sampling_ratio", -1)),
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
DCHECK_GT(spatial_scale_, 0);
DCHECK_GT(pooled_height_, 0);
DCHECK_GT(pooled_width_, 0);
DCHECK_GE(sampling_ratio_, 0);
TORCH_DCHECK_GT(spatial_scale_, 0);
TORCH_DCHECK_GT(pooled_height_, 0);
TORCH_DCHECK_GT(pooled_width_, 0);
TORCH_DCHECK_GE(sampling_ratio_, 0);
}
USE_OPERATOR_CONTEXT_FUNCTIONS;

View File

@@ -25,9 +25,9 @@ class RoIAlignOp final : public Operator<Context> {
OP_SINGLE_ARG(int, "pooled_w", pooled_w_, 1),
OP_SINGLE_ARG(int, "sampling_ratio", sampling_ratio_, -1),
OP_SINGLE_ARG(bool, "aligned", aligned_, false) {
DCHECK_GT(spatial_scale_, 0.0f);
DCHECK_GT(pooled_h_, 0);
DCHECK_GT(pooled_w_, 0);
TORCH_DCHECK_GT(spatial_scale_, 0.0f);
TORCH_DCHECK_GT(pooled_h_, 0);
TORCH_DCHECK_GT(pooled_w_, 0);
DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
}

View File

@@ -22,10 +22,10 @@ class RoIAlignRotatedGradientOp final : public Operator<Context> {
sampling_ratio_(
this->template GetSingleArgument<int>("sampling_ratio", -1)),
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
DCHECK_GT(spatial_scale_, 0);
DCHECK_GT(pooled_height_, 0);
DCHECK_GT(pooled_width_, 0);
DCHECK_GE(sampling_ratio_, 0);
TORCH_DCHECK_GT(spatial_scale_, 0);
TORCH_DCHECK_GT(pooled_height_, 0);
TORCH_DCHECK_GT(pooled_width_, 0);
TORCH_DCHECK_GE(sampling_ratio_, 0);
}
USE_OPERATOR_CONTEXT_FUNCTIONS;

View File

@@ -27,10 +27,10 @@ class RoIAlignRotatedOp final : public Operator<Context> {
sampling_ratio_(
this->template GetSingleArgument<int>("sampling_ratio", -1)),
aligned_(this->template GetSingleArgument<bool>("aligned", false)) {
DCHECK_GT(spatial_scale_, 0);
DCHECK_GT(pooled_height_, 0);
DCHECK_GT(pooled_width_, 0);
DCHECK_GE(sampling_ratio_, 0);
TORCH_DCHECK_GT(spatial_scale_, 0);
TORCH_DCHECK_GT(pooled_height_, 0);
TORCH_DCHECK_GT(pooled_width_, 0);
TORCH_DCHECK_GE(sampling_ratio_, 0);
DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
}
USE_OPERATOR_CONTEXT_FUNCTIONS;

View File

@@ -137,10 +137,10 @@ bool SliceImpl(
src_offset_bytes + i * src_block_size_bytes;
char* local_dst_offset_bytes =
dst_offset_bytes + i * dst_block_size_bytes;
DCHECK_LE(
TORCH_DCHECK_LE(
static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
static_cast<void*>(src_bytes + src_nbytes));
DCHECK_LE(
TORCH_DCHECK_LE(
static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
static_cast<void*>(dst_bytes + dst_nbytes));
context->CopyItemsSameDevice(
@@ -183,10 +183,10 @@ bool SliceImpl(
src_offset_bytes + i * src_block_size_bytes;
char* local_dst_offset_bytes =
dst_offset_bytes + i * dst_block_size_bytes;
DCHECK_LE(
TORCH_DCHECK_LE(
local_src_offset_bytes + src_block_size_bytes,
src_bytes + src_nbytes);
DCHECK_LE(
TORCH_DCHECK_LE(
local_dst_offset_bytes + src_block_size_bytes,
dst_bytes + dst_nbytes);
context->CopyItemsSameDevice(

View File
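
The SliceImpl hunks copy raw bytes between tensors, so each iteration re-derives the end of the source and destination windows and debug-checks that both stay inside their allocations before handing the block to CopyItemsSameDevice. A stripped-down sketch of the guard around one block copy:

  #include <cstddef>
  #include <cstring>
  #include <c10/util/Logging.h>

  void copy_block(char* dst, size_t dst_nbytes, size_t dst_off,
                  const char* src, size_t src_nbytes, size_t src_off,
                  size_t block_bytes) {
    // Both the read window and the write window must end inside their buffers.
    TORCH_DCHECK_LE(src_off + block_bytes, src_nbytes);
    TORCH_DCHECK_LE(dst_off + block_bytes, dst_nbytes);
    std::memcpy(dst + dst_off, src + src_off, block_bytes);
  }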

@@ -98,7 +98,7 @@ class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
const int N = Y.size_to_dim(canonical_axis);
const int D = Y.size_from_dim(canonical_axis);
CHECK_EQ(Y.sizes(), dY.sizes());
TORCH_CHECK_EQ(Y.sizes(), dY.sizes());
auto* dX = Output(0, Y.sizes(), at::dtype<T>());
auto* dX_data = dX->template mutable_data<T>();
if (N == 0 || D == 0) {

View File

@@ -23,7 +23,7 @@ bool SoftplusGradientOp<float, CPUContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
DCHECK_EQ(dY.numel(), Y.numel());
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
const float* Ydata = Y.data<float>();

View File

@@ -24,7 +24,7 @@ template <>
bool SoftplusOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
DCHECK_GT(X.numel(), 0);
TORCH_DCHECK_GT(X.numel(), 0);
auto* Y = Output(0, X.sizes(), at::dtype<float>());
SoftplusKernel<float>
<<<CAFFE_GET_BLOCKS(X.numel()),
@@ -42,8 +42,8 @@ bool SoftplusGradientOp<float, CUDAContext>::RunOnDevice() {
auto& Y = Input(0);
auto& dY = Input(1);
DCHECK_GT(Y.numel(), 0);
DCHECK_EQ(dY.numel(), Y.numel());
TORCH_DCHECK_GT(Y.numel(), 0);
TORCH_DCHECK_EQ(dY.numel(), Y.numel());
auto* dX = Output(0, Y.sizes(), at::dtype<float>());
SoftplusGradientKernel<float>
<<<CAFFE_GET_BLOCKS(Y.numel()),

View File

@@ -76,7 +76,7 @@ template<>
bool SummarizeOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
const int N = X.numel();
DCHECK_GT(N, 0);
TORCH_DCHECK_GT(N, 0);
// TODO(Yangqing): Any better way to avoid having to const cast?
thrust::device_ptr<float> Xdata(const_cast<float*>(X.data<float>()));

Some files were not shown because too many files have changed in this diff.