#pragma once

#include <cmath>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include <ATen/ATen.h>

#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/symbolic_variable.h"

namespace torch {
namespace jit {
namespace test {

using Var = SymbolicVariable;
using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;

// work around the fact that variable_tensor_list doesn't duplicate all
// of std::vector's constructors.
// most constructors are never used in the implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list) {
  return Stack(
      std::make_move_iterator(list.begin()),
      std::make_move_iterator(list.end()));
}

void assertAllClose(const tensor_list& a, const tensor_list& b) {
  ASSERT_EQ(a.size(), b.size());
  for (size_t i = 0; i < a.size(); ++i) {
    ASSERT_TRUE(a[i].is_same_size(b[i]));
    ASSERT_TRUE(a[i].allclose(b[i]));
  }
}

std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs) {
  std::vector<IValue> stack(inputs.begin(), inputs.end());
  interp.run(stack);
  return fmap(stack, [](const IValue& i) { return i.toTensor(); });
}

std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  static const auto as_tensorlist = [](const Stack& stack) {
    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
  };
  Code f_code{grad_spec.f}, df_code{grad_spec.df};
  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

  auto f_stack = fmap<IValue>(tensors_in);
  f_interpreter.run(f_stack);

  Stack df_stack;
  df_stack.insert(
      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
  for (auto offset : grad_spec.df_input_captured_inputs)
    df_stack.push_back(tensors_in[offset]);
  for (auto offset : grad_spec.df_input_captured_outputs)
    df_stack.push_back(f_stack[offset]);
  df_interpreter.run(df_stack);

  // Outputs of f need to be sliced off: entries past f_real_outputs exist
  // only as captures for df.
  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}

std::tuple<Var, Var> build_lstm_body(
    Graph& g,
    Var input,
    Var hx,
    Var cx,
    Var w_ih,
    Var w_hh) {
  auto gates = input.mm(w_ih);
  gates = gates + hx.mm(w_hh);
  auto outputs = gates.chunk(4, 1);
  auto ingate = outputs[0];
  auto forgetgate = outputs[1];
  auto cellgate = outputs[2];
  auto outgate = outputs[3];
  ingate = ingate.sigmoid();
  outgate = outgate.sigmoid();
  cellgate = cellgate.tanh();
  forgetgate = forgetgate.sigmoid();

  auto cy = forgetgate * cx;
  cy = cy + ingate * cellgate;
  auto hy = outgate * cy.tanh();

  return std::make_tuple(hy, cy);
}

std::shared_ptr<Graph> build_lstm() {
  auto r = std::make_shared<Graph>();
  auto& g = *r;
  Value* input = g.addInput();
  Value* hx = g.addInput();
  Value* cx = g.addInput();
  Value* w_ih = g.addInput();
  Value* w_hh = g.addInput();

  Var hy;
  Var cy;
  std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);

  hy.addAsOutput();
  cy.addAsOutput();
  g.lint();

  return r;
}

at::Tensor t_use(at::Tensor x) {
  return x;
}
at::Tensor t_def(at::Tensor x) {
  return x.t();
}
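// A minimal usage sketch (not part of the original header) of how the helpers
// above fit together: build_lstm() makes a Graph, Code compiles it, and
// InterpreterState executes it over a Stack built with createStack. The
// function name and all sizes below are illustrative.
inline void exampleInterpretLSTM() {
  constexpr int64_t batch = 3, input_size = 10, hidden_size = 16;
  // build_lstm computes input.mm(w_ih) and hx.mm(w_hh), so the weights are
  // [input_size, 4 * hidden_size] and [hidden_size, 4 * hidden_size].
  auto stack = createStack(
      {at::randn({batch, input_size}),
       at::randn({batch, hidden_size}),
       at::randn({batch, hidden_size}),
       at::randn({input_size, 4 * hidden_size}),
       at::randn({hidden_size, 4 * hidden_size})});

  Code code{build_lstm()};
  InterpreterState interp{code};
  // run() mutates the stack in place: inputs are popped, outputs pushed.
  interp.run(stack);
  ASSERT_EQ(stack.size(), 2); // hy, cy as IValues
  auto hy = stack[0].toTensor();
  auto cy = stack[1].toTensor();
  ASSERT_TRUE(hy.is_same_size(cy));
}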
// given the difference of output vs expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way
// of matching tensor values up to a certain precision.
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
  double maxValue = 0.0;
  for (auto& tensor : inputs) {
    maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
  }
  return diff.abs().max().item<float>() < 2e-6 * maxValue;
}

bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
  return checkRtol(a - b, {a, b});
}

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
  return (a - b).abs().max().item<float>() == 0.f;
}

std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh) {
  auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));

  auto chunked_gates = gates.chunk(4, 1);
  auto ingate = chunked_gates[0];
  auto forgetgate = chunked_gates[1];
  auto cellgate = chunked_gates[2];
  auto outgate = chunked_gates[3];

  ingate = ingate.sigmoid();
  outgate = outgate.sigmoid();
  cellgate = cellgate.tanh();
  forgetgate = forgetgate.sigmoid();

  auto cy = (forgetgate * cx) + (ingate * cellgate);
  auto hy = outgate * cy.tanh();

  return {hy, cy};
}

} // namespace test
} // namespace jit
} // namespace torch
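// Illustrative cross-check (a sketch, not from the original header): run the
// graph from build_lstm() through the interpreter and compare against the
// eager `lstm` reference above with almostEqual. The namespaces are re-opened
// since this sits after the closing braces; exampleCrossCheckLSTM is a
// hypothetical name.
namespace torch {
namespace jit {
namespace test {

inline void exampleCrossCheckLSTM() {
  constexpr int64_t batch = 3, input_size = 10, hidden_size = 16;
  auto input = at::randn({batch, input_size});
  auto hx = at::randn({batch, hidden_size});
  auto cx = at::randn({batch, hidden_size});
  // t_def hands back transposed (non-contiguous) weights; both the graph and
  // the eager reference then see the same [input_size, 4 * hidden_size] view,
  // so this also exercises non-contiguous inputs.
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}));

  at::Tensor hy_ref, cy_ref;
  std::tie(hy_ref, cy_ref) = lstm(input, hx, cx, w_ih, w_hh);

  Code code{build_lstm()};
  InterpreterState interp{code};
  auto outputs = run(interp, {input, hx, cx, w_ih, w_hh});
  ASSERT_TRUE(almostEqual(outputs[0], hy_ref));
  ASSERT_TRUE(almostEqual(outputs[1], cy_ref));
}

} // namespace test
} // namespace jit
} // namespace torch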