pytorch/test/mobile/nnc/test_registry.cpp
Jiakai Liu d82333e92a [pytorch][nnc] protocol classes to persist the context for compiled functions (#56851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56851

This is part of the changes to enable NNC AOT compilation for mobile.
At the end of the ahead-of-time compilation the compiler produces two sets of artifacts:
1. "compiled assembly code" - kernel functions in assembly format optimized for target platforms;
2. "compiled model" - regular TorchScript model that contains serialized parameters (weights/bias/etc) and invokes kernel functions via "handles" (name/version id/input & output specs/etc of the kernel functions).

This PR introduces a set of classes to represent kernel functions (a.k.a "handles"), which can be serialized/deserialized into/from the "compiled model" as an IValue.
It also introduces APIs to register and look up the "compiled assembly code".
ghstack-source-id: 128285802

Test Plan:
- unit tests
- for FB build environment:
buck test //caffe2/test/mobile/nnc:mobile_nnc

Reviewed By: kimishpatel, raziel

Differential Revision: D27921866

fbshipit-source-id: 4c2a4d8a4d072fc259416ae674b3b494f0ca56f3
2021-05-06 03:24:15 -07:00


#include <gtest/gtest.h>
#include <torch/csrc/jit/mobile/nnc/registry.h>

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

// Stand-ins for AOT-generated kernels. Each returns a distinct value so the
// tests can tell which kernel a lookup resolved to.
extern "C" {
int generated_asm_kernel_foo(void**) {
  return 1;
}

int generated_asm_kernel_bar(void**) {
  return 2;
}
} // extern "C"

// Register each kernel under its id string.
REGISTER_NNC_KERNEL("foo:v1:VERTOKEN", generated_asm_kernel_foo)
REGISTER_NNC_KERNEL("bar:v1:VERTOKEN", generated_asm_kernel_bar)

// A registered kernel can be looked up by id and executed.
TEST(MobileNNCRegistryTest, FindAndRun) {
  auto foo_kernel = registry::get_nnc_kernel("foo:v1:VERTOKEN");
  EXPECT_EQ(foo_kernel->execute(nullptr), 1);

  auto bar_kernel = registry::get_nnc_kernel("bar:v1:VERTOKEN");
  EXPECT_EQ(bar_kernel->execute(nullptr), 2);
}

// An id that was never registered is reported as absent.
TEST(MobileNNCRegistryTest, NoKernel) {
  EXPECT_FALSE(registry::has_nnc_kernel("missing"));
}

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch
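
The test above relies on kernels being registered at static-initialization time and found later by id. The following is a minimal, standalone sketch of that general self-registration pattern. It is not the PyTorch implementation: the `sketch` namespace, `KernelRegistrar`, `kernel_table`, `has_kernel`, `get_kernel`, and the `SKETCH_REGISTER_KERNEL` macro are all hypothetical names invented for illustration, and the real `REGISTER_NNC_KERNEL` may work differently.

```cpp
#include <string>
#include <unordered_map>

namespace sketch {

// Same signature as the generated kernels in the test above.
using KernelFn = int (*)(void**);

// Function-local static: the table is constructed on first use, so
// registrars in any translation unit can safely insert into it during
// static initialization.
inline std::unordered_map<std::string, KernelFn>& kernel_table() {
  static std::unordered_map<std::string, KernelFn> table;
  return table;
}

// Constructing a registrar inserts the kernel under the given id.
// (Hypothetical helper; a duplicate id is silently ignored by emplace.)
struct KernelRegistrar {
  KernelRegistrar(const std::string& id, KernelFn fn) {
    kernel_table().emplace(id, fn);
  }
};

// Returns true if a kernel was registered under this id.
inline bool has_kernel(const std::string& id) {
  return kernel_table().count(id) > 0;
}

// Returns the registered kernel, or nullptr if the id is unknown.
inline KernelFn get_kernel(const std::string& id) {
  auto it = kernel_table().find(id);
  return it == kernel_table().end() ? nullptr : it->second;
}

} // namespace sketch

// Hypothetical registration macro: declares a uniquely named static
// registrar object whose constructor runs before main().
#define SKETCH_CONCAT_IMPL(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_IMPL(a, b)
#define SKETCH_REGISTER_KERNEL(id, fn) \
  static ::sketch::KernelRegistrar SKETCH_CONCAT(sketch_registrar_, __LINE__)(id, fn);

// Stand-in kernels, mirroring the ones in the test.
extern "C" int sketch_kernel_foo(void**) { return 1; }
extern "C" int sketch_kernel_bar(void**) { return 2; }

SKETCH_REGISTER_KERNEL("foo:v1", sketch_kernel_foo)
SKETCH_REGISTER_KERNEL("bar:v1", sketch_kernel_bar)
```

In this sketch a lookup of an unregistered id returns `nullptr` rather than throwing, so callers would check `has_kernel` first or test the returned pointer. That is a design choice of the sketch only.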