#pragma once

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace {

// Normalize a string for fuzzy matching: strip leading and trailing
// whitespace, delete newlines, and collapse runs of spaces to a single space.
static inline void trim(std::string& s) {
  // Strip leading whitespace.
  s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
            return !std::isspace(ch);
          }));
  // Strip trailing whitespace.
  s.erase(
      std::find_if(
          s.rbegin(),
          s.rend(),
          [](unsigned char ch) { return !std::isspace(ch); })
          .base(),
      s.end());
  // Delete every newline.
  for (int64_t i = 0; i < s.size(); ++i) {
    if (s[i] == '\n') {
      s.erase(i, 1);
      i--;
    }
  }
  // Collapse each run of consecutive spaces to a single space.
  for (int64_t i = 0; i < s.size(); ++i) {
    if (s[i] == ' ') {
      for (int64_t j = i + 1; j < s.size(); j++) {
        if (s[j] == ' ') {
          s.erase(j, 1);
          j--;
        } else {
          break;
        }
      }
    }
  }
}
} // namespace

// Asserts that evaluating `statement` throws a std::exception whose message
// contains `substring`. Both strings are normalized with trim() first so the
// comparison is insensitive to whitespace and line breaks.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)              \
  try {                                                               \
    (void)statement;                                                  \
    FAIL();                                                           \
  } catch (const std::exception& e) {                                 \
    std::string substring_s(substring);                               \
    trim(substring_s);                                                \
    auto exception_string = std::string(e.what());                    \
    trim(exception_string);                                           \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
  }
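
// Illustrative usage (a sketch, not part of the original header): the macro
// is meant to be invoked inside a gtest test body. The module, method name,
// and expected message below are hypothetical.
//
//   TEST(ExceptionTest, MessageSubstring) {
//     ASSERT_THROWS_WITH_MESSAGE(
//         module.run_method("forward", input),
//         "expected a value of type Tensor");
//   }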

namespace torch {
namespace jit {

using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;

// Work around the fact that variable_tensor_list doesn't duplicate all of
// std::vector's constructors; most of them are never used in the
// implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list);

void assertAllClose(const tensor_list& a, const tensor_list& b);

std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);
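
// Illustrative usage (a sketch, assuming the graph comes from one of the
// build_* helpers declared below; the function name passed to Code is
// arbitrary):
//
//   auto graph = build_lstm();
//   Code code(graph, /*function_name=*/"test");
//   InterpreterState interp(code);
//   auto outputs = run(interp, {input, hx, cx, w_ih, w_hh});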

std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);
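
// Illustrative usage (a sketch): the Gradient spec is typically produced by
// differentiate() from torch/csrc/jit/runtime/autodiff.h, then executed
// against concrete input tensors and output gradients.
//
//   auto grad_spec = differentiate(graph);
//   auto result = runGradient(grad_spec, tensors_in, tensor_grads_in);
//   // result.first: forward outputs; result.second: computed input gradients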

std::shared_ptr<Graph> build_lstm();
std::shared_ptr<Graph> build_mobile_export_analysis_graph();
std::shared_ptr<Graph> build_mobile_export_with_out();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();

at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);

// Given the difference of output vs. expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way of
// matching tensor values up to a certain precision.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
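
// A minimal sketch of the relationship between these helpers (an assumption
// about the .cpp definitions, shown here for orientation only): almostEqual
// can be expressed as a relative-tolerance check on the elementwise
// difference of the two tensors.
//
//   bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
//     return checkRtol(a - b, {a, b});
//   }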

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(
    const std::vector<at::Tensor>& a,
    const std::vector<at::Tensor>& b);

std::vector<at::Tensor> runGraph(
    std::shared_ptr<Graph> graph,
    const std::vector<at::Tensor>& inputs);

std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);
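
// For orientation (a sketch of the standard single-step LSTM cell this helper
// is expected to compute; the actual definition lives in the .cpp file):
//
//   auto gates = input.mm(w_ih.t()) + hx.mm(w_hh.t());
//   auto chunks = gates.chunk(4, /*dim=*/1);
//   auto i = chunks[0].sigmoid();  // input gate
//   auto f = chunks[1].sigmoid();  // forget gate
//   auto g = chunks[2].tanh();     // cell candidate
//   auto o = chunks[3].sigmoid();  // output gate
//   auto cy = f * cx + i * g;
//   auto hy = o * cy.tanh();
//   return {hy, cy};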

} // namespace jit
} // namespace torch