pytorch/torch/csrc/jit/mobile/nnc/backend.cpp
Jiakai Liu b4a098f1fb [pytorch][nnc] mobile nnc backend skeleton (#56852)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56852

This is part of the changes to enable NNC AOT compilation for mobile.
It introduces a custom backend for NNC, which uses the components defined in the stacked PRs to load and execute an NNC-compiled model.
ghstack-source-id: 128285801

Test Plan:
- On X86 host:
```
buck build //xplat/caffe2/fb/lite_predictor:lite_predictor_nnc
buck-out/last/lite_predictor_nnc --model xplat/pytorch_models/build/pytorch_dev_linear/v1/nnc/compiled.pt --print_output true --input_dims '4,4' --input_type float
```
- On Android:
```
buck build fbsource//fbandroid/mode/gnustl //xplat/caffe2/fb/lite_predictor:lite_predictor_nncAndroid#android-armv7
adb push buck-out/last/lite_predictor_nncAndroid#android-armv7 /data/local/tmp
adb push xplat/pytorch_models/build/pytorch_dev_linear/v1/nnc/compiled.pt /data/local/tmp
adb shell 'cd /data/local/tmp; ./lite_predictor_nncAndroid\#android-armv7 --model compiled.pt --print_output true --input_dims "4,4" --input_type float'
```

Reviewed By: kimishpatel, raziel

Differential Revision: D27897153

fbshipit-source-id: 8e039089d1602782582747adfd75b31496b525ca
2021-05-06 03:25:18 -07:00

62 lines
1.6 KiB
C++

#include <vector>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/script.h>
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {
// Mobile backend that executes models ahead-of-time compiled with NNC.
// A deserialized CompilationUnit holds the compiled methods; each method
// is addressed at execute() time by an opaque handle that is simply the
// method's name.
class NNCBackend : public PyTorchBackendInterface {
 public:
  explicit NNCBackend() = default;
  ~NNCBackend() override = default;

  // The NNC runtime is always available on this build.
  bool is_available() override {
    return true;
  }

  // Loads the pre-compiled payload (`processed`) into a CompilationUnit
  // and hands back one backend handle per method.
  //
  // Input `method_compile_spec`:
  //   key:   method name
  //   value: compile spec for that method
  // Output:
  //   key:   method name
  //   value: backend handle for that method (here, the name itself)
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    cu_ = std::make_shared<CompilationUnit>(processed);

    auto typed_spec =
        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
    c10::Dict<std::string, std::string> method_handles;
    for (const auto& entry : typed_spec) {
      // Each method's handle is the method name itself.
      method_handles.insert(entry.key(), entry.key());
    }
    return c10::impl::toGenericDict(method_handles);
  }

  // Runs the compiled method named by `handle` with the given inputs
  // and returns its outputs.
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    const auto& method_name = handle.toStringRef();
    return cu_->run(c10::QualifiedName(method_name), inputs);
  }

 private:
  // Deserialized compilation unit; populated by compile(), used by execute().
  std::shared_ptr<CompilationUnit> cu_;
};
namespace {
// Registers the "nnc" backend with the JIT backend registry at static
// initialization time. `static` would be redundant here: definitions in
// an anonymous namespace already have internal linkage.
const auto cls = torch::jit::backend<NNCBackend>("nnc");
} // namespace
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch