pytorch/test/cpp/api/support.h
Peter Goldsborough 825181ea9d Rewrite C++ API tests in gtest (#11953)
Summary:
This PR is a large codemod to rewrite all C++ API tests with GoogleTest (gtest) instead of Catch.

You can largely trust me to have code-modded the tests correctly, so there is no need to review every one of the 2000+ changed lines. However, I also made the following additional changes:

1. Moved the CMake parts for these tests into their own `CMakeLists.txt` under `test/cpp/api`, which is now pulled in by calling `add_subdirectory` from `torch/CMakeLists.txt`.
2. Fixed the DataParallel tests, which weren't being compiled because `USE_CUDA` wasn't being set correctly (or at all).
3. Updated the README.

ezyang ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11953

Differential Revision: D9998883

Pulled By: goldsborough

fbshipit-source-id: affe3f320b0ca63e7e0019926a59076bb943db80
2018-09-21 21:28:16 -07:00

92 lines
2.3 KiB
C++

#pragma once
#include <gtest/gtest.h>
#include <torch/nn/cloneable.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <utility>
#ifndef WIN32
#include <unistd.h>
#endif
namespace torch {
namespace test {
// Lets you use a container without making a new class,
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
 public:
  /// Required override of `Cloneable::reset()`. This container registers
  /// nothing of its own here, so it is a no-op.
  void reset() override {}

  /// Registers `module_holder` as a submodule under `name` via
  /// `Module::register_module` and returns the registered holder.
  template <typename ModuleHolder>
  ModuleHolder add(
      ModuleHolder module_holder,
      std::string name = std::string()) {
    // `module_holder` is a by-value sink parameter: move it on rather than
    // making a second copy of the holder just to hand it over.
    return Module::register_module(std::move(name), std::move(module_holder));
  }
};
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
return first.data<float>() == second.data<float>();
}
#ifdef WIN32
/// Reserves a unique temporary file path for the duration of a test.
///
/// Only a name is generated (via `std::tmpnam`); no file is created until the
/// test writes to `str()`. Any file present at that path is removed when the
/// `TempFile` is destroyed. Move-only, because the object owns the path.
struct TempFile {
  TempFile() {
    // Use a caller-owned buffer: `std::tmpnam(nullptr)` hands back a shared
    // static buffer, and its nullptr failure result must never be passed to
    // the std::string constructor (undefined behavior).
    char buffer[L_tmpnam];
    const char* const name = std::tmpnam(buffer);
    AT_CHECK(name != nullptr, "Error creating tempfile");
    filename_.assign(name);
  }

  // Copying would let two objects delete the same file.
  TempFile(const TempFile&) = delete;
  TempFile& operator=(const TempFile&) = delete;

  TempFile(TempFile&& other) noexcept : filename_(std::move(other.filename_)) {
    other.filename_.clear();
  }

  TempFile& operator=(TempFile&& other) noexcept {
    // Swap, so `other` cleans up whatever this object owned before.
    filename_.swap(other.filename_);
    return *this;
  }

  ~TempFile() {
    if (!filename_.empty()) {
      // Best effort: the file may never have been created.
      std::remove(filename_.c_str());
    }
  }

  const std::string& str() const {
    return filename_;
  }

  std::string filename_;
};
#else
/// Creates a unique, empty temporary file under /tmp for the duration of a
/// test. On destruction the descriptor is closed and the file is deleted.
/// Move-only, because the object owns both the descriptor and the file.
struct TempFile {
  TempFile() {
    // http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html
    char filename[] = "/tmp/fileXXXXXX";
    fd_ = mkstemp(filename);
    AT_CHECK(fd_ != -1, "Error creating tempfile");
    filename_.assign(filename);
  }

  // Copying would close the same descriptor (and delete the file) twice.
  TempFile(const TempFile&) = delete;
  TempFile& operator=(const TempFile&) = delete;

  TempFile(TempFile&& other) noexcept
      : filename_(std::move(other.filename_)), fd_(other.fd_) {
    other.filename_.clear();
    other.fd_ = -1;
  }

  TempFile& operator=(TempFile&& other) noexcept {
    // Swap, so `other` releases whatever this object owned before.
    filename_.swap(other.filename_);
    std::swap(fd_, other.fd_);
    return *this;
  }

  ~TempFile() {
    if (fd_ != -1) {
      close(fd_);
    }
    if (!filename_.empty()) {
      std::remove(filename_.c_str());
    }
  }

  const std::string& str() const {
    return filename_;
  }

  std::string filename_;
  // -1 means "no descriptor owned" (moved-from object).
  int fd_ = -1;
};
#endif
/// gtest fixture whose constructor seeds torch's random number generation
/// with the fixed value 0, so every test built on it starts from the same
/// seed.
struct SeedingFixture : ::testing::Test {
  SeedingFixture() {
    const int fixed_seed = 0;
    torch::manual_seed(fixed_seed);
  }
};
/// Fatally fails the current gtest test unless evaluating the expression
/// `statement` throws an exception derived from `std::exception` whose
/// `what()` message contains `prefix`.
///
/// Despite its name, `prefix` is matched anywhere in the message (via
/// `std::string::find`), not only at its start. Exceptions that do not derive
/// from `std::exception` are not caught here and propagate out of the test.
///
/// The expansion is one `do { ... } while (false)` statement, so the macro
/// composes with unbraced `if`/`else`. Terminate it with a semicolon like any
/// other gtest assertion. The "did not throw" failure is raised outside the
/// `try`, so a throwing gtest failure mode cannot be swallowed by the `catch`.
#define ASSERT_THROWS_WITH(statement, prefix)                                \
  do {                                                                       \
    bool assert_throws_with_threw_ = false;                                  \
    std::string assert_throws_with_message_;                                 \
    try {                                                                    \
      (void)(statement);                                                     \
    } catch (const std::exception& assert_throws_with_error_) {              \
      assert_throws_with_threw_ = true;                                      \
      assert_throws_with_message_ = assert_throws_with_error_.what();        \
    }                                                                        \
    if (!assert_throws_with_threw_) {                                        \
      FAIL() << "Expected statement `" #statement                            \
                "` to throw an exception, but it did not";                   \
    } else if (                                                              \
        assert_throws_with_message_.find(prefix) == std::string::npos) {     \
      FAIL() << "Error message \"" << assert_throws_with_message_            \
             << "\" did not match expected prefix \"" << (prefix) << "\"";   \
    }                                                                        \
  } while (false)
} // namespace test
} // namespace torch