mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: In our #better-engineering quest of removing all uses of catch in favor of gtest, this PR ports JIT tests to gtest. After #11846 lands, we will be able to delete catch. I don't claim to use/write these tests much (though I wrote the custom operator tests) so please do scrutinize whether you will want to write tests in the way I propose. Basically: 1. One function declaration per "test case" in test/cpp/jit/test.h 2. One definition in test/cpp/jit/test.cpp 3. If you want to be able to run it in Python, add it to `runJitTests()` which is called from Python tests 4. If you want to be able to run it in C++, add a `JIT_TEST` line in test/cpp/jit/gtest.cpp Notice also I was able to share support code between C++ frontend and JIT tests, which is healthy. ezyang apaszke zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/12030 Differential Revision: D10207745 Pulled By: goldsborough fbshipit-source-id: d4bae087e4d03818b72b8853cd5802d79a4cf32e
43 lines
904 B
C++
43 lines
904 B
C++
#pragma once
|
|
|
|
#include <test/cpp/common/support.h>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <torch/nn/cloneable.h>
|
|
#include <torch/tensor.h>
|
|
#include <torch/utils.h>
|
|
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
namespace torch {
|
|
namespace test {
|
|
|
|
// Lets you use a container without making a new class,
|
|
// for experimental implementations
|
|
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
|
|
public:
|
|
void reset() override {}
|
|
|
|
template <typename ModuleHolder>
|
|
ModuleHolder add(
|
|
ModuleHolder module_holder,
|
|
std::string name = std::string()) {
|
|
return Module::register_module(std::move(name), module_holder);
|
|
}
|
|
};
|
|
|
|
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
|
|
return first.data<float>() == second.data<float>();
|
|
}
|
|
|
|
/// GoogleTest fixture that calls `torch::manual_seed(0)` in its constructor,
/// so the seed is reset to 0 every time the fixture is instantiated.
struct SeedingFixture : ::testing::Test {
  SeedingFixture() {
    torch::manual_seed(0);
  }
};
|
|
|
|
} // namespace test
|
|
} // namespace torch
|