pytorch/test/cpp/jit/test_utils.h
Martin Yuan b2520ab3dc Add a demo backend with compiler (#52603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52603

This PR introduces a backend with minimal compilation capability into the to_<backend> flow. The goals are:

- Demonstrate the end-to-end flow of adding a backend -> compilation -> runtime
- Show how backend compilation errors are surfaced to the user, together with the original model's source code information. (C++ only in this PR. Python APIs will be demonstrated in a following PR.)

Changes:

- Compilation

1. A backend with minimal compilation features, "backend_with_compiler_demo", is added.
2. The compilation happens AOT in the ```pre_process``` function registered to this backend.
3. Compiled results are stored in a string blob for each method. They are serialized to the lowered module with the ```__get_state__``` function.
4. For features not handled by the backend compiler, an error message containing the model source code is thrown.

- Runtime

1. The compiled blob is loaded in the ```__set_state__``` method.
2. The ```compile``` function of the backend passes the AOT compiled blob through unchanged. (TODO: parsing the blob into a format that the backend can understand can happen here.)
3. The ```execute``` function of the backend executes the specified method (handle).
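
The compilation and runtime steps above can be sketched in plain C++ without any PyTorch dependency. This is only an illustration of the data flow (AOT pre-processing into a per-method blob, a pass-through compile step, and execution by method name). Every name here (`toyPreprocess`, `toyCompile`, `toyExecute`, the `Blob` alias) is hypothetical, and the set of supported ops is an assumption made for the example, not the demo backend's actual implementation.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins; none of these names are the PyTorch backend API.
using Blob = std::string; // AOT-compiled result for one method
using CompiledMethods = std::map<std::string, Blob>;

// AOT step (the role of pre_process): turn each method's op list into a
// string blob and reject ops the toy compiler does not handle.
// Assumption: only "aten::add" and "aten::sub" are supported here.
CompiledMethods toyPreprocess(
    const std::map<std::string, std::vector<std::string>>& methods) {
  CompiledMethods compiled;
  for (const auto& entry : methods) {
    Blob blob;
    for (const auto& op : entry.second) {
      if (op != "aten::add" && op != "aten::sub") {
        throw std::runtime_error(
            "The node of " + op + " is not supported in this compiler.");
      }
      blob += op + ";";
    }
    compiled[entry.first] = blob;
  }
  return compiled;
}

// Runtime compile step: passes the AOT blob through unchanged as the handle.
CompiledMethods toyCompile(const CompiledMethods& processed) {
  return processed;
}

// Runtime execute step: interpret the blob of the requested method (handle),
// applying each op in order to x with operand h.
double toyExecute(
    const CompiledMethods& handles,
    const std::string& method,
    double x,
    double h) {
  const Blob& blob = handles.at(method);
  double result = x;
  std::size_t start = 0;
  while (start < blob.size()) {
    const std::size_t end = blob.find(';', start);
    const std::string op = blob.substr(start, end - start);
    result = (op == "aten::add") ? result + h : result - h;
    start = end + 1;
  }
  return result;
}
```

In this sketch, an unsupported op is rejected during the AOT step, so the failure is reported before anything reaches the runtime, which mirrors the ordering described in the lists above.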

Test Plan:
- ```BackendTest.TestCompiler```: the C++ end-to-end demonstration on a supported model. After compilation and execution, the lowered model produces the same result as the original TorchScript model.
- ```BackendTest.TestCompilerNotSupport```: demonstrates the error message raised by the AOT compilation when the input module uses an unsupported feature. The error message looks like:

```
"The node of aten::mul is not supported in this compiler. Source code:   File "<string>", line 3

    def forward(self, x, h):
        return x * h
               ~~~~~ <--- HERE
```
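
A test for this kind of error typically checks that the exception text contains an expected substring, which is what the `ASSERT_THROWS_WITH_MESSAGE` macro in this header does with gtest assertions. A minimal gtest-free sketch of the same idea follows. The helper `throwsWithMessage` and the stand-in `compileUnsupported` are hypothetical names introduced only for illustration.

```cpp
#include <exception>
#include <stdexcept>
#include <string>

// Returns true if calling fn throws a std::exception whose what() contains
// substring. Returns false if nothing is thrown or the text does not match.
template <typename Fn>
bool throwsWithMessage(Fn&& fn, const std::string& substring) {
  try {
    fn();
  } catch (const std::exception& e) {
    return std::string(e.what()).find(substring) != std::string::npos;
  }
  return false;
}

// Hypothetical stand-in for a compile step that rejects aten::mul.
void compileUnsupported() {
  throw std::runtime_error(
      "The node of aten::mul is not supported in this compiler. Source code:");
}
```

Matching on a substring rather than the full message keeps the check stable when the source-location part of the message (file name, line number, highlighted range) changes.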

Reviewed By: raziel

Differential Revision: D26593968

Pulled By: iseeyuan

fbshipit-source-id: 8f264f60a0470e9f07e36fdeccbf17da6c1d7cd7
2021-02-26 11:53:34 -08:00


#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
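// Asserts that `statement` throws a std::exception whose what() contains
// `substring`. Relies on gtest's FAIL() and ASSERT_NE at the point of use.
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)                 \
  try {                                                                  \
    (void)statement;                                                     \
    FAIL();                                                              \
  } catch (const std::exception& e) {                                    \
    ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
  }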
namespace torch {
namespace jit {
using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;
// work around the fact that variable_tensor_list doesn't duplicate all
// of std::vector's constructors.
// most constructors are never used in the implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list);
void assertAllClose(const tensor_list& a, const tensor_list& b);
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs);
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in);
std::shared_ptr<Graph> build_lstm();
at::Tensor t_use(at::Tensor x);
at::Tensor t_def(at::Tensor x);
// given the difference of output vs expected tensor, check whether the
// difference is within a relative tolerance range. This is a standard way of
// matching tensor values up to certain precision
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs);
bool almostEqual(const at::Tensor& a, const at::Tensor& b);
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh);
} // namespace jit
} // namespace torch
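
The comment on `checkRtol` describes a relative-tolerance comparison: the largest absolute difference is compared against a small fraction of the largest input magnitude. The sketch below illustrates that idea on plain `std::vector<double>` values instead of tensors. The names `maxAbs`, `checkRtolSketch`, and `almostEqualSketch`, the default tolerance of `2e-6`, and the strict `<` comparison are all assumptions made for this example, not a statement of how the declared functions are implemented.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Largest absolute value in v (0.0 for an empty vector).
double maxAbs(const std::vector<double>& v) {
  double m = 0.0;
  for (double x : v) {
    m = std::max(m, std::fabs(x));
  }
  return m;
}

// Hypothetical sketch: the difference passes if its largest magnitude is
// below rtol times the largest magnitude found in any input.
// Assumption: rtol = 2e-6 is chosen only for illustration.
bool checkRtolSketch(
    const std::vector<double>& diff,
    const std::vector<std::vector<double>>& inputs,
    double rtol = 2e-6) {
  double maxValue = 0.0;
  for (const auto& input : inputs) {
    maxValue = std::max(maxValue, maxAbs(input));
  }
  return maxAbs(diff) < rtol * maxValue;
}

// Element-wise difference of a and b, checked against both inputs.
bool almostEqualSketch(
    const std::vector<double>& a,
    const std::vector<double>& b) {
  if (a.size() != b.size()) {
    return false;
  }
  std::vector<double> diff(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    diff[i] = a[i] - b[i];
  }
  return checkRtolSketch(diff, {a, b});
}
```

One consequence of the strict comparison in this sketch is that two all-zero inputs do not compare as almost equal, because the bound `rtol * maxValue` is itself zero. A variant that adds an absolute tolerance term would avoid that edge case.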