Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51342

There is a subtle bug in the MemoryPlanner's handling of view ops with out variants.

```
def forward(self, a: Tensor, shape: List[int]):
    b = a.reshape(shape)
    return b + b
```

In this case, if we replace reshape with its out variant, b would be managed by the MemoryPlanner, and its storage would be set to nullptr right after inference when opts.cleanup_activations is true. Because b is a view of a, the storage of a would also be set to nullptr, which violates the API's promise that a is const.

To fix this bug, the MemoryPlanner now puts b in the unmanaged part.

Test Plan: Added a unit test to enforce the constness of inputs:

```
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```

Reviewed By: ajyu

Differential Revision: D26144203

fbshipit-source-id: 2dbacccf7685d0fe0f0b1195166e0510b2069fe3
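For illustration, here is a minimal, self-contained sketch of the aliasing hazard described above, written in plain ATen rather than the Static Runtime (`is_alias_of` is a standard `at::Tensor` method):

```
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

int main() {
  at::Tensor a = at::randn({2, 3});
  // For a contiguous input, reshape avoids a copy and returns a view,
  // so `b` shares its underlying storage with `a`.
  at::Tensor b = a.reshape({3, 2});
  TORCH_CHECK(b.is_alias_of(a));
  // If a memory planner "managed" b and reset its storage after inference,
  // `a` would lose its data as well -- hence the fix keeps outputs of view
  // ops in the unmanaged set.
  return 0;
}
```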
320 lines · 10 KiB · C++
#include <gtest/gtest.h>

#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include "deep_wide_pt.h"
#include "test_scripts.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using c10::IValue;
namespace {

// Extract a single tensor from an IValue that may be a Tensor, a
// one-element TensorList, or a one-element Tuple.
static at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}
// Compare two lists of tensors wrapped in IValues, element by element.
void compareTensorLists(
    const std::vector<IValue>& l, /* values */
    const std::vector<IValue>& r /* expects */) {
  // Abort on size mismatch; indexing past the shorter list would be UB.
  ASSERT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    LOG(INFO) << "output " << i << ": \n" << l[i] << std::endl;
    LOG(INFO) << "expect " << i << ": \n" << r[i] << std::endl;
    EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
  }
}
// Compare two lists of tensors element by element.
void compareTensorLists(
    const std::vector<at::Tensor>& l, /* values */
    const std::vector<at::Tensor>& r /* expects */) {
  ASSERT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    LOG(INFO) << "output " << i << ": \n" << l[i] << std::endl;
    LOG(INFO) << "expect " << i << ": \n" << r[i] << std::endl;
    EXPECT_TRUE(l[i].equal(r[i]));
  }
}
// Given a model/function in jit script, run the model/function with the jit
// interpreter and with the static runtime, and compare the results. Also
// verify that the static runtime did not modify its input tensors (the
// constness check added for this diff).
void testStaticRuntime(
    const std::string& jit_script,
    const std::vector<IValue>& args) {
  script::Module module("module");
  module.define(jit_script);

  // Keep references to the input tensors plus deep copies of their initial
  // contents, so we can check afterwards that the inputs were left intact.
  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = module.forward(args);

  StaticRuntime runtime(module);
  auto actual = runtime.run(args, {});

  if (expect.isTuple()) {
    compareTensorLists(
        expect.toTuple()->elements(), actual.toTuple()->elements());
  } else if (expect.isList()) {
    compareTensorLists(expect.toTensorVector(), actual.toTensorVector());
  } else {
    EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
  }

  // Make sure the inputs were not modified.
  compareTensorLists(args_tensors, args_copy);
}

} // namespace
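// Usage sketch for testStaticRuntime, with a hypothetical inline script (the
// scripts actually used below, such as add_script, live in test_scripts.h):
//
//   const auto example_script = R"JIT(
//     def forward(self, a, b):
//         return a + b
//   )JIT";
//   testStaticRuntime(example_script, {at::randn({2, 3}), at::randn({2, 3})});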
TEST(StaticRuntime, IndividualOps_Binary) {
  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});

  std::vector<IValue> args{a, b};

  testStaticRuntime(add_script, args);
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(tuple_construct_script, args);
}
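// The reshape tests below exercise the MemoryPlanner fix from this diff: when
// reshape can avoid a copy, its output is a view of its input, and that
// output must stay unmanaged so cleanup does not release the input's storage.
// reshape_script_1 (defined in test_scripts.h) presumably matches the snippet
// from the summary above:
//
//   def forward(self, a: Tensor, shape: List[int]):
//       b = a.reshape(shape)
//       return b + b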
TEST(StaticRuntime, IndividualOps_Reshape) {
  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
}
TEST(StaticRuntime, IndividualOps_flatten) {
  auto test_flatten =
      [](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        auto a = at::randn(shape);
        std::vector<IValue> args{a, start_dim, end_dim};
        testStaticRuntime(flatten_script_1, args);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}
TEST(StaticRuntime, LongModel) {
  torch::jit::Module mod = getLongScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}
TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}
TEST(StaticRuntime, LeakyReLU) {
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  auto inputs = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({inputs});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({inputs});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}
TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(mod.forward(inputs));

      // run static runtime
      std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
      at::Tensor output_2 = runtime.run(input_tensors)[0];
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticRuntime runtime(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      at::Tensor output_1 = getTensor(module.forward(inputs));

      // run static runtime
      at::Tensor output_2 = getTensor(runtime.run(inputs, {}));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(module);
  torch::jit::StaticRuntime runtime(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
      at::Tensor output_1 = getTensor(module.forward(args));

      std::unordered_map<std::string, c10::IValue> kwargs(
          {{"ad_emb_packed", ad_emb_packed},
           {"user_emb", user_emb},
           {"wide", wide}});

      // run static runtime, passing all inputs as kwargs
      at::Tensor output_2 = getTensor(runtime.run({}, kwargs));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
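// The combination cleanup_memory = true with enable_out_variant = true is the
// configuration in which the view-op bug described in the summary could have
// released an input tensor's storage, so the sweep below over all four option
// combinations guards the fix.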
TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);

  for (auto cleanup_memory : {true, false}) {
    for (auto enable_out_variant : {true, false}) {
      VLOG(1) << "cleanup_memory: " << cleanup_memory
              << ", enable_out_variant: " << enable_out_variant;
      torch::jit::StaticRuntimeOptions opts{cleanup_memory, enable_out_variant};
      torch::jit::StaticRuntime runtime(g, opts);

      for (int batch_size : {1, 8, 32}) {
        for (int i = 0; i < 2; ++i) {
          auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
          auto user_emb = torch::randn({batch_size, 1, embedding_size});
          auto wide = torch::randn({batch_size, num_features});

          // run jit graph executor
          std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
          auto output_1 = getTensor(mod.forward(inputs));

          // run static runtime
          std::vector<at::Tensor> input_tensors(
              {ad_emb_packed, user_emb, wide});
          at::Tensor output_2 = runtime.run(input_tensors)[0];
          EXPECT_TRUE(output_1.equal(output_2));
        }
      }
    }
  }
}
TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph);
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}