[AI Bench] Add support for NLU model
Summary: Add support for NLU-specific input.
Test Plan:
Tested the new NLU input path:
```
buck run aibench:run_bench -- -b aibench/specifications/models/pytorch/fbnet/assistant_mobile_inference.json --platform android/full_jit --framework pytorch --remote --devices SM-G950U-7.0-24
```
Made sure it is compatible with the previous test:
```
buck run aibench:run_bench -- -b aibench/specifications/models/pytorch/fbnet/fbnet_mobile_inference.json --platform android/full_jit --framework pytorch --remote --devices SM-G950U-7.0-24
```
assistant_mobile_inference.json:
```
{
  "model": {
    "category": "CNN",
    "description": "Assistant Mobile Inference",
    "files": {
      "model": {
        "filename": "model.pt1",
        "location": "//everstore/GICWmAB2Znbi_mAAAB0P51IPW8UrbllgAAAP/model.pt1",
        "md5": "c0f4b29c442bbaeb0007fb0ce513ccb3"
      },
      "data": {
        "filename": "input.txt",
        "location": "/home/pengxia/test/input.txt",
        "md5": "c0f4b29c442bbaeb0007fb0ce513ccb3"
      }
    },
    "format": "pytorch",
    "framework": "pytorch",
    "kind": "deployment",
    "name": "Assistant Mobile Inference"
  },
  "tests": [
    {
      "command": "{program} --model {files.model} --input_dims \"1\" --input_type NLUType --warmup {warmup} --iter {iter} --input_file {files.data} --report_pep true",
      "identifier": "{ID}",
      "metric": "delay",
      "iter": 5,
      "warmup": 2,
      "log_output": true
    }
  ]
}
```
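The `--input_type NLUType` flag in the command template is what routes the benchmark binary into the new NLU input path added by this diff; the remaining placeholders (`{program}`, `{files.model}`, `{files.data}`, `{warmup}`, `{iter}`) are substituted by the aibench harness, as the test plan above exercises.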
input.txt:
```
what is weather today
what time it is
set a reminder for tomorrow
```
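Each line of input.txt is whitespace-tokenized and packed into the (tokens, empty placeholder, length) triple built by the new `nlu_process` helper in the diff below. Here is a minimal standalone sketch of that per-line conversion, assuming only libtorch; the sample utterance is taken from input.txt above:
```cpp
#include <torch/script.h> // pulls in c10::IValue and c10::List

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::string line = "what is weather today"; // one line of input.txt

  // Whitespace-tokenize the utterance, like split(' ', line) in the patch.
  c10::List<std::string> tokens;
  std::istringstream stream(line);
  for (std::string token; stream >> token;) {
    tokens.push_back(token);
  }

  // One benchmark input: token list, an empty (None) placeholder, and the
  // token count as an int64 scalar, mirroring nlu_process.
  std::vector<c10::IValue> nlu_input;
  nlu_input.push_back(tokens);
  nlu_input.push_back(c10::IValue()); // the patch writes push_back({})
  nlu_input.push_back(static_cast<int64_t>(tokens.size()));

  std::cout << "tokens: " << nlu_input[0]
            << ", len: " << nlu_input[2] << std::endl;
  return 0;
}
```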
Result:
https://our.intern.facebook.com/intern/aibench/details/137241352201417
Reviewed By: kimishpatel
Differential Revision: D20300947
fbshipit-source-id: 7c1619541a2e9514a560a9acb9029cfc4669f37a
Parent: bcfd348858
Commit: 91e922a338
Diff:
```diff
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <fstream>
 #include <string>
 #include <vector>
 
@@ -37,6 +38,7 @@ C10_DEFINE_string(
     "semicolon to separate the dimension of different "
     "tensors.");
 C10_DEFINE_string(input_type, "", "Input type (uint8_t/float)");
+C10_DEFINE_string(input_file, "", "Input file");
 C10_DEFINE_bool(
     print_output,
     false,
@@ -63,6 +65,22 @@ split(char separator, const std::string& string, bool ignore_empty = true) {
   return pieces;
 }
 
+std::vector<std::vector<c10::IValue>> nlu_process(std::string file_path) {
+  std::vector<std::vector<c10::IValue>> nlu_inputs;
+  std::ifstream input_file(FLAGS_input_file);
+  for (std::string line; getline(input_file, line);) {
+    std::vector<c10::IValue> nlu_input;
+    c10::List<std::string> tokens(split(' ', line));
+    nlu_input.push_back(tokens);
+    auto len = torch::jit::IValue(static_cast<int64_t>(tokens.size()));
+    nlu_input.push_back({});
+    nlu_input.push_back(len);
+    nlu_inputs.emplace_back(std::move(nlu_input));
+    std::cout << line << std::endl;
+  }
+  return nlu_inputs;
+}
+
 int main(int argc, char** argv) {
   c10::SetUsageMessage(
       "Run speed benchmark for pytorch model.\n"
@@ -88,27 +106,32 @@ int main(int argc, char** argv) {
       input_type_list.size(),
       "Input dims and type should have the same number of items.");
 
-  std::vector<c10::IValue> inputs;
-  for (size_t i = 0; i < input_dims_list.size(); ++i) {
-    auto input_dims_str = split(',', input_dims_list[i]);
-    std::vector<int64_t> input_dims;
-    for (const auto& s : input_dims_str) {
-      input_dims.push_back(c10::stoi(s));
-    }
-    if (input_type_list[i] == "float") {
-      inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
-    } else if (input_type_list[i] == "uint8_t") {
-      inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
-    } else if (input_type_list[i] == "int64") {
-      inputs.push_back(torch::ones(input_dims, torch::kI64));
-    } else {
-      CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
+  std::vector<std::vector<c10::IValue>> inputs;
+  if (input_type_list[0] == "NLUType"){
+    inputs = nlu_process(FLAGS_input_file);
+  } else {
+    inputs.push_back(std::vector<c10::IValue>());
+    for (size_t i = 0; i < input_dims_list.size(); ++i) {
+      auto input_dims_str = split(',', input_dims_list[i]);
+      std::vector<int64_t> input_dims;
+      for (const auto& s : input_dims_str) {
+        input_dims.push_back(c10::stoi(s));
+      }
+      if (input_type_list[i] == "float") {
+        inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Float));
+      } else if (input_type_list[i] == "uint8_t") {
+        inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Byte));
+      } else if (input_type_list[i] == "int64") {
+        inputs[0].push_back(torch::ones(input_dims, torch::kI64));
+      } else {
+        CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
+      }
     }
   }
 
   if (FLAGS_pytext_len > 0) {
     auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
-    inputs.push_back(stensor);
+    inputs[0].push_back(stensor);
   }
 
   auto qengines = at::globalContext().supportedQEngines();
@@ -121,7 +144,7 @@ int main(int argc, char** argv) {
 
   module.eval();
   if (FLAGS_print_output) {
-    std::cout << module.forward(inputs) << std::endl;
+    std::cout << module.forward(inputs[0]) << std::endl;
   }
 
   std::cout << "Starting benchmark." << std::endl;
@@ -131,8 +154,10 @@ int main(int argc, char** argv) {
       "Number of warm up runs should be non negative, provided ",
       FLAGS_warmup,
       ".");
-  for (int i = 0; i < FLAGS_warmup; ++i) {
-    module.forward(inputs);
+  for (unsigned int i = 0; i < FLAGS_warmup; ++i) {
+    for (const auto& input : inputs) {
+      module.forward(input);
+    }
   }
 
   std::cout << "Main runs." << std::endl;
@@ -146,7 +171,9 @@ int main(int argc, char** argv) {
   auto millis = timer.MilliSeconds();
   for (int i = 0; i < FLAGS_iter; ++i) {
     auto start = high_resolution_clock::now();
-    module.forward(inputs);
+    for (const std::vector<c10::IValue>& input: inputs) {
+      module.forward(input);
+    }
     auto stop = high_resolution_clock::now();
     auto duration = duration_cast<milliseconds>(stop - start);
     times.push_back(duration.count());
```
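For readers who want to try the new code path outside the benchmark harness, here is a hedged end-to-end sketch. It assumes libtorch, a TorchScript model whose forward accepts the (tokens, empty slot, length) triple like the assistant model above, and illustrative local paths `model.pt1` / `input.txt`:
```cpp
#include <torch/script.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Illustrative paths; the benchmark receives these via --model/--input_file.
  torch::jit::Module module = torch::jit::load("model.pt1");
  module.eval();

  std::ifstream input_file("input.txt");
  for (std::string line; std::getline(input_file, line);) {
    // Build the same per-line input triple that nlu_process produces.
    c10::List<std::string> tokens;
    std::istringstream stream(line);
    for (std::string token; stream >> token;) {
      tokens.push_back(token);
    }

    std::vector<c10::IValue> input;
    input.push_back(tokens);
    input.push_back(c10::IValue()); // empty placeholder slot
    input.push_back(static_cast<int64_t>(tokens.size()));

    // One forward call per utterance, as in the patched warmup/main loops.
    auto output = module.forward(input);
    std::cout << line << " -> " << output << std::endl;
  }
  return 0;
}
```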