pytorch/test/cpp/jit/test_backend.cpp
Raziel Alvarez Guevara c5cd993add Adds a bool is_available() method to the backend contract (#53068)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53068

Adds a ```bool is_available()``` method to the backend contract: it returns ```true``` if ```compile()``` and ```execute()``` can be called; ```false``` otherwise.

It is used to implement the following changes in the ```LoweredModule```:
* ```compile()``` in ```__setstate__``` runs if ```is_available()``` returns true; otherwise ```__setstate__``` throws an exception ("Backend is not available.").
* ```compile()``` at ```LoweredModule``` creation runs if ```is_available()``` returns true; otherwise a warning is emitted.
* ```execute()``` is only called if ```is_available()``` returns true; otherwise an exception is thrown ("Backend is not available.").
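
A minimal C++ sketch of this contract and of where a lowered module would consult it. Everything here is hypothetical and simplified: `BackendSketch`, `SumDiffBackend`, `UnavailableBackend` and `LoweredModuleSketch` are illustrative names, not PyTorch types, and the real PyTorchBackendInterface exchanges TorchScript values rather than strings and doubles. Only the gating pattern and the exception message are taken from this change.

```cpp
#include <cassert>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for the backend contract.
class BackendSketch {
 public:
  virtual ~BackendSketch() = default;
  // True only if both compile() and execute() can be called.
  virtual bool is_available() = 0;
  // Turns the preprocessed module into per-method handles.
  virtual std::map<std::string, std::string> compile(const std::string& processed) = 0;
  // Runs the method identified by `handle` on `inputs`.
  virtual std::vector<double> execute(const std::string& handle,
                                      const std::vector<double>& inputs) = 0;
};

// Fully available backend; "forward" returns (x + h, x - h).
class SumDiffBackend : public BackendSketch {
 public:
  bool is_available() override { return true; }
  std::map<std::string, std::string> compile(const std::string&) override {
    return {{"forward", "sum_diff"}};
  }
  std::vector<double> execute(const std::string&,
                              const std::vector<double>& in) override {
    return {in.at(0) + in.at(1), in.at(0) - in.at(1)};
  }
};

// Backend that cannot run here; its compile()/execute() must never be reached.
class UnavailableBackend : public BackendSketch {
 public:
  bool is_available() override { return false; }
  std::map<std::string, std::string> compile(const std::string&) override {
    throw std::logic_error("compile() called on unavailable backend");
  }
  std::vector<double> execute(const std::string&,
                              const std::vector<double>&) override {
    throw std::logic_error("execute() called on unavailable backend");
  }
};

// Hypothetical lowered-module wrapper showing where is_available() is checked.
class LoweredModuleSketch {
 public:
  // Creation: compile if available, otherwise only warn, since compilation
  // and execution may still be possible on-target.
  LoweredModuleSketch(std::unique_ptr<BackendSketch> backend, std::string processed)
      : backend_(std::move(backend)), processed_(std::move(processed)) {
    if (backend_->is_available()) {
      handles_ = backend_->compile(processed_);
    } else {
      std::cerr << "WARNING: backend not available at creation\n";
      warned_ = true;
    }
  }

  // Load (the __setstate__ analogue): the module is unusable without
  // compile(), so this throws instead of warning.
  void load() {
    if (!backend_->is_available()) throw std::runtime_error("Backend is not available.");
    handles_ = backend_->compile(processed_);
  }

  // Execution is only attempted when the backend is available.
  std::vector<double> forward(const std::vector<double>& inputs) {
    if (!backend_->is_available()) throw std::runtime_error("Backend is not available.");
    return backend_->execute(handles_.at("forward"), inputs);
  }

  bool warned() const { return warned_; }

 private:
  std::unique_ptr<BackendSketch> backend_;
  std::string processed_;
  std::map<std::string, std::string> handles_;
  bool warned_ = false;
};
```

Gating every entry point on `is_available()` keeps `compile()` and `execute()` free of availability checks: an unavailable backend is never asked to do work it cannot do.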

The goal of these changes is to ensure well-defined behaviour for every combination of backend availability on-host and on-target.

More specifically, backends may have different capabilities to compile and/or execute the Module, depending on whether this happens on-host (i.e. where the program is being written) or on-target (where the program is being executed).

First of all, we know that "preprocess" always takes place, and that it only happens on-host, at creation time. So, if any compilation is needed or possible on-host, all of it can be pushed into preprocess.
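
As a sketch of that split, assume a hypothetical `preprocess()` that finishes all compilation on-host, and a target backend that can execute but not compile (the on-target No/Yes case). The on-target `compile()` then reduces to a no-op that only checks readiness. `ProcessedModule`, `preprocess`, `TargetBackend` and the `"compiled:"` tag are illustrative names, not PyTorch APIs.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical output of preprocess(), produced once on-host.
struct ProcessedModule {
  std::string payload;       // what the target will run
  bool ready_for_execution;  // true if all compilation finished on-host
};

// On-host, at lowering time: do all the compilation the host can do.
// "Compiling" is modelled as tagging the source; a real backend would lower it.
ProcessedModule preprocess(const std::string& source, bool host_can_compile) {
  if (!host_can_compile) {
    return {source, false};
  }
  return {"compiled:" + source, true};
}

// On-target backend that can execute but not compile: compile() is a no-op
// that only exposes handles, so is_available() can still return true.
struct TargetBackend {
  bool is_available() const { return true; }

  std::map<std::string, std::string> compile(const ProcessedModule& m) const {
    if (!m.ready_for_execution) {
      // Loading fails: nothing on the target can finish the compilation.
      throw std::runtime_error("preprocessed module is not ready for execution");
    }
    return {{"forward", m.payload}};
  }
};
```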

Overall, we want to ensure the following:

**On host**

| compile | execute | Outcome |
| -- | -- | -- |
| No | No | On module creation, the LoweredModule is generated with a warning (since compilation and execution can still take place on-target). On module load, an exception is thrown (since execution is not possible). |
| No | Yes | This configuration should not be possible. It implies the full compiler is not available: even if some work was done in preprocess, the program cannot be finalized for execution. |
| Yes | No | In this case, is_available() is expected to return false, and the compilation logic should move into preprocess. |
| Yes | Yes | All good. This is the only case in which is_available() should return true. |

**On target**

| compile | execute | Outcome |
| -- | -- | -- |
| No | No | Loading the LoweredModule throws an exception, since execution is not possible. |
| No | Yes | This is effectively another instance of Yes/Yes: compilation per se may not be possible on device, so compile() can be called without issue but is a no-op, and is_available() should return true. Consequently, loading the LoweredModule succeeds if the preprocessed module is ready for execution, and fails with an exception otherwise. |
| Yes | No | This configuration should not be possible. It is listed here only for completeness. |
| Yes | Yes | All good. This and the No/Yes case (which is just another instance of Yes/Yes, because compilation is assumed to have happened on-host) are the cases where is_available() should return true. |
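
The two tables can be condensed into a small decision function. This is a hypothetical helper for illustration only (no such function exists in the backend API): it returns the value `is_available()` should report, and `std::nullopt` for the configurations marked as not possible.

```cpp
#include <cassert>
#include <optional>

enum class Location { kHost, kTarget };

// Hypothetical helper encoding the on-host and on-target tables.
std::optional<bool> expected_is_available(Location where, bool can_compile,
                                          bool can_execute) {
  if (where == Location::kHost) {
    // No/Yes: cannot finalize the program without a compiler.
    if (!can_compile && can_execute) return std::nullopt;
    // Yes/No and No/No report false; only Yes/Yes is available.
    return can_compile && can_execute;
  }
  // On-target, Yes/No should not occur.
  if (can_compile && !can_execute) return std::nullopt;
  // No/Yes behaves like Yes/Yes, since compilation already happened on-host.
  return can_execute;
}
```

On-host the rule reduces to `compile && execute`, and on-target to `execute`, because any compilation the target cannot do is assumed to have been pushed into preprocess.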

**Refactoring existing code**
This change also updates the code of other backends (Glow) to implement the is_available() method so that they behave as they did before this change (i.e. always available).

This should not break backward compatibility with already saved models, since we are only adding a new method to the PyTorchBackendInterface.
Models saved with the old interface, which did not have is_available(), will still find the other two methods (compile and execute) in the bound object, and the saved LoweredModule logic will be the old one.

**Future**
We plan to use is_available() to implement support for fallback to the PyTorch interpreter.
ghstack-source-id: 123498571

Test Plan: Added C++ (test_backend.cpp) and Python (test_backends.py) tests to validate the exceptions.

Reviewed By: jackm321, spaugh, iseeyuan

Differential Revision: D26615833

fbshipit-source-id: 562e8b11db25784348b5f86bbc4179aedf15e0d3
2021-03-10 00:24:16 -08:00


#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/torch.h>

// Tests go in torch::jit
namespace torch {
namespace jit {

TEST(BackendTest, ToBackend) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs).toTuple()->elements();

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "test_backend", m, compile_spec, any_dict_ty);
  // lowered module code:
  /*
    class test_backendLoweredModule(Module):
      __parameters__ = []
      __buffers__ = []
      __processed_module : Any
      __method_compile_spec : Dict[str, Any]
      __backend : __torch__.torch.classes.__backends__.test_backend
      __handles : Dict[str, Any]
      def __create_backend(self: torch.jit.test_backendLoweredModule) -> None:
        _0 = __torch__.torch.classes.__backends__.test_backend.__new__(__torch__.torch.classes.__backends__.test_backend)
        _1 = (_0).__init__()
        self.__backend = _0
        return None
      def __getstate__(self: torch.jit.test_backendLoweredModule) -> Tuple[Dict[str, Any], Any]:
        _2 = (self.__method_compile_spec, self.__processed_module)
        return _2
      def __setstate__(self: torch.jit.test_backendLoweredModule, state: Tuple[Dict[str, Any], Any]) -> None:
        self.__method_compile_spec = (state)[0]
        self.__processed_module = (state)[1]
        _3 = (self).__create_backend()
        _4 = (self.__backend).compile(self.__processed_module, self.__method_compile_spec, )
        self.__handles = _4
        return None
      def forward(self: torch.jit.test_backendLoweredModule, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
        _5 = uninitialized(Tensor)
        typed_inputs = annotate(List[Any], [x, h])
        _6 = (self.__backend).execute((self.__handles)["forward"], typed_inputs, )
        _7, _8, = _6
        _9 = isinstance(_7, Tensor)
        if _9:
          _10 = unchecked_cast(Tensor, _7)
        else:
          ops.prim.RaiseException("AssertionError: ")
          _10 = _5
        _11 = isinstance(_8, Tensor)
        if _11:
          _12 = unchecked_cast(Tensor, _8)
        else:
          ops.prim.RaiseException("AssertionError: ")
          _12 = _5
        return (_10, _12)
  */
  auto res = lm.forward(inputs).toTuple()->elements();
  AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor()));
  AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
}

TEST(BackendTest, ToBackendNotAvailable) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs).toTuple()->elements();

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // Produce lowered module (backend not available).
  // Exception is not thrown at this point.
  auto lm = torch::jit::detail::codegen_backend_module(
      "test_backend_unavailable", m, compile_spec, any_dict_ty);
  // Validate exception is thrown when trying to execute and
  // the backend is not available.
  ASSERT_THROWS_WITH_MESSAGE(
      lm.forward(inputs).toTuple()->elements(), "Backend is not available.");
}

TEST(BackendTest, TestCompiler) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs);

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
  auto res = lm.forward(inputs);
  AT_ASSERT(res.toTensor().equal(ref.toTensor()));

  std::stringstream ss;
  lm._save_for_mobile(ss);
  auto mlm = _load_for_mobile(ss);
  auto mres = mlm.forward(inputs);
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestCompilerNotSupport) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x * h
  )");

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::detail::codegen_backend_module(
          "backend_with_compiler_demo", m, compile_spec, any_dict_ty),
      "The node of aten::mul is not supported in this compiler. Source code:");
}

} // namespace jit
} // namespace torch