mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR is a large codemod that rewrites all C++ API tests to use GoogleTest (gtest) instead of Catch. You can largely trust me to have code-modded the tests correctly, so it is not necessary to review every one of the 2000+ changed lines. However, I also made these additional changes: 1. Moved the CMake parts for these tests into their own `CMakeLists.txt` under `test/cpp/api`, which is now pulled in via `add_subdirectory` from `torch/CMakeLists.txt`. 2. Fixed the DataParallel tests, which were not being compiled because `USE_CUDA` was not being set correctly (or at all). 3. Updated the README. ezyang ebetica Pull Request resolved: https://github.com/pytorch/pytorch/pull/11953 Differential Revision: D9998883 Pulled By: goldsborough fbshipit-source-id: affe3f320b0ca63e7e0019926a59076bb943db80
92 lines
2.3 KiB
C++
92 lines
2.3 KiB
C++
#pragma once
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <torch/nn/cloneable.h>
|
|
#include <torch/tensor.h>
|
|
#include <torch/utils.h>
|
|
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#ifndef WIN32
|
|
#include <unistd.h>
|
|
#endif
|
|
|
|
namespace torch {
|
|
namespace test {
|
|
|
|
// Lets you use a container without making a new class,
|
|
// for experimental implementations
|
|
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
|
|
public:
|
|
void reset() override {}
|
|
|
|
template <typename ModuleHolder>
|
|
ModuleHolder add(
|
|
ModuleHolder module_holder,
|
|
std::string name = std::string()) {
|
|
return Module::register_module(std::move(name), module_holder);
|
|
}
|
|
};
|
|
|
|
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
|
|
return first.data<float>() == second.data<float>();
|
|
}
|
|
|
|
#ifdef WIN32
|
|
struct TempFile {
|
|
TempFile() : filename_(std::tmpnam(nullptr)) {}
|
|
const std::string& str() const {
|
|
return filename_;
|
|
}
|
|
std::string filename_;
|
|
};
|
|
#else
|
|
struct TempFile {
|
|
TempFile() {
|
|
// http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html
|
|
char filename[] = "/tmp/fileXXXXXX";
|
|
fd_ = mkstemp(filename);
|
|
AT_CHECK(fd_ != -1, "Error creating tempfile");
|
|
filename_.assign(filename);
|
|
}
|
|
|
|
~TempFile() {
|
|
close(fd_);
|
|
}
|
|
|
|
const std::string& str() const {
|
|
return filename_;
|
|
}
|
|
|
|
std::string filename_;
|
|
int fd_;
|
|
};
|
|
#endif
|
|
|
|
// Base gtest fixture for tests that need reproducible randomness: each test
// using it starts with torch's global RNG seeded to 0 (gtest constructs a
// fresh fixture object for every test).
struct SeedingFixture : ::testing::Test {
  SeedingFixture() {
    // Seeding happens in the constructor rather than SetUp(), so subclasses
    // that override SetUp() without calling the base still get a seeded RNG.
    torch::manual_seed(/*seed=*/0);
  }
};
|
|
|
|
// Asserts that evaluating `statement` throws an exception derived from
// std::exception whose what() message contains `prefix` as a substring
// (despite the name, the match is not anchored at the start of the message).
//
// Reports a gtest fatal failure via FAIL() (which returns from the enclosing
// function) if the statement does not throw, or if the message does not
// contain `prefix`. Exceptions not derived from std::exception are not caught
// here and propagate to the caller.
//
// Notes on the expansion:
// - It is wrapped in do { ... } while (false) so it acts as one statement
//   (safe inside an unbraced if/else; requires a trailing semicolon).
// - Only `statement` sits inside the try, so a FAIL() that throws (gtest
//   throw-on-failure mode) is never swallowed by the catch.
// - Locals use deliberately unusual names so they cannot shadow caller
//   variables referenced by `prefix` (e.g. one named `message`).
#define ASSERT_THROWS_WITH(statement, prefix)                              \
  do {                                                                     \
    bool assert_throws_with_threw_ = false;                                \
    try {                                                                  \
      (void)(statement);                                                   \
    } catch (const std::exception& assert_throws_with_error_) {            \
      assert_throws_with_threw_ = true;                                    \
      const std::string assert_throws_with_message_ =                      \
          assert_throws_with_error_.what();                                \
      if (assert_throws_with_message_.find(prefix) == std::string::npos) { \
        FAIL() << "Error message \"" << assert_throws_with_message_        \
               << "\" did not match expected prefix \"" << (prefix)        \
               << "\"";                                                    \
      }                                                                    \
    }                                                                      \
    if (!assert_throws_with_threw_) {                                      \
      FAIL() << "Expected statement `" #statement                          \
                "` to throw an exception, but it did not";                 \
    }                                                                      \
  } while (false)
|
|
|
|
} // namespace test
|
|
} // namespace torch
|