mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24801 This is to fix the ODR-violations in fbcode static builds, which have been broken for several months. This PR is unfortunately quite large, but the changes are only mechanical: 1. Tests defined in header files -> tests defined in cpp files 2. Remove the `torch::jit::testing` namespace -> `torch::jit`. 3. Single `test.h` file that aggregates all tests. 4. Separate out files for gtest and python versions of the tests instead of using a build flag 5. Add a readme for how to add a new test, and explaining a bit about why the cpp tests are the way they are. Test Plan: Imported from OSS Differential Revision: D16878605 Pulled By: suo fbshipit-source-id: 27b5c077dadd990a5f74e25d01731f9c1f491603
63 lines
1.7 KiB
C++
63 lines
1.7 KiB
C++
#pragma once

// Shared helpers for the JIT cpp tests: stack construction, tensor
// comparison, interpreter/gradient drivers, and a reference LSTM used
// as ground truth when validating graph-based implementations.

#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/irparser.h"
#include "torch/csrc/jit/symbolic_variable.h"

// Make the header self-contained: it names std::vector/pair/tuple/shared_ptr
// below, so include their headers directly instead of relying on transitive
// includes.
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

using Var = SymbolicVariable;
using tensor_list = std::vector<at::Tensor>;
// NOTE(review): a using-directive in a header leaks torch::autograd names
// into every includer; kept as-is because the test files depend on it.
using namespace torch::autograd;

// work around the fact that variable_tensor_list doesn't duplicate all
// of std::vector's constructors.
// most constructors are never used in the implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list);

// Asserts that `a` and `b` match element-wise within a relative tolerance
// (presumably via checkRtol below — see its comment).
void assertAllClose(const tensor_list& a, const tensor_list& b);

// Runs `interp` on `inputs` and returns the tensors it produces.
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);

// Executes the forward/backward graphs described by `grad_spec`;
// returns {forward outputs, gradients w.r.t. the inputs}.
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);

// Appends one LSTM cell's ops to `g` and returns the resulting
// {hidden, cell} symbolic values.
std::tuple<Var, Var> build_lstm_body(
    Graph& g,
    Var input,
    Var hx,
    Var cx,
    Var w_ih,
    Var w_hh);

// Builds a graph containing a single LSTM cell (via build_lstm_body).
std::shared_ptr<Graph> build_lstm();

// Tensor layout helpers for tests; exact semantics are defined in the
// corresponding .cpp (NOTE(review): presumably "use"/"def" transpose
// wrappers — confirm against the implementation).
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);

// given the difference of output vs expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way of
// matching tensor values up to certain precision
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
// True when `a` and `b` are equal within the relative tolerance above.
bool almostEqual(const at::Tensor& a, const at::Tensor& b);

// True when `a` and `b` are bitwise/value identical (no tolerance).
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);

// Reference eager-mode LSTM cell; returns {hy, cy}. Serves as the
// expected output for the graph-based LSTM built above.
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);

} // namespace jit
} // namespace torch
|