mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: fbcode/sigmoid/core/common -> fbcode/caffe2/torch/nativert/common
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72
Test Plan: `buck run fbcode//mode/dev-nosan //caffe2/test/cpp/nativert:pytree_test`; OSS CI
Rollback Plan: (not specified)
Differential Revision: D75965059
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155136
Approved by: https://github.com/zhxchen17, https://github.com/XuehaiPan, https://github.com/zou3519
1151 lines · 30 KiB · C++
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <fmt/format.h>
|
|
|
|
#include <c10/util/Enumerate.h>
|
|
#include <torch/nativert/detail/ITree.h>
|
|
|
|
namespace torch::nativert::detail {
|
|
|
|
using torch::nativert::Graph;
|
|
using torch::nativert::stringToGraph;
|
|
using torch::nativert::Type;
|
|
using torch::nativert::Value;
|
|
|
|
// Creates a fresh graph holding `count` placeholder values named
// "v0" .. "v{count-1}" (all of Type::Kind::None, with no producer node).
//
// Returns the owning graph together with non-owning pointers to the created
// values, in creation order. The pointers refer to values added to the
// returned graph, so callers must keep the first element of the pair alive for
// as long as they use the second.
//
// `static` gives this test helper internal linkage so it cannot collide with a
// same-named helper defined in another test translation unit of this
// namespace.
static std::pair<std::unique_ptr<Graph>, std::vector<const Value*>> makeValues(
    int count) {
  auto graph = Graph::createGraph();
  std::vector<const Value*> values;
  if (count > 0) {
    values.reserve(static_cast<size_t>(count));
  }

  for (int i = 0; i < count; i++) {
    std::string name = fmt::format("v{}", i);
    Value* value = graph->addValue(name, Type::Kind::None, nullptr);
    values.push_back(value);
  }

  // Move both halves into the result; the vector does not need to be copied.
  return std::make_pair(std::move(graph), std::move(values));
}
|
|
|
|
// Unflattens eight leaves into the nested structure described by the spec,
// checks every node of the result, then flattens it back and checks the
// round trip reproduces the original leaves in order.
TEST(ITreeTest, Unflatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, (10,), {"11": 12}]
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "torch.fx.immutable_collections.immutable_list",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "torch.fx.immutable_collections.immutable_dict",
          "context": "[\"11\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )";

  // `graph` owns the values referenced by `valuePtrs`; keep it alive.
  auto [graph, valuePtrs] = makeValues(8);

  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  std::vector<c10::IValue> flats = {
      c10::IValue(0),
      c10::IValue(1),
      c10::IValue(2),
      c10::IValue(7),
      c10::IValue(8),
      c10::IValue(9),
      c10::IValue(10),
      c10::IValue(12),
  };
  auto itree = itreeUnflatten(flats, spec);

  // Structural checks are fatal (ASSERT) so that a malformed result stops the
  // test instead of reaching the unchecked element accesses below.
  ASSERT_TRUE(itree.isList());
  ASSERT_EQ(itree.toListRef().size(), 5u);

  ASSERT_TRUE(itree.toListRef().at(0).isTuple());
  ASSERT_EQ(itree.toListRef().at(0).toTupleRef().elements().size(), 2u);
  EXPECT_EQ(itree.toListRef().at(0).toTupleRef().elements()[0], c10::IValue(0));
  EXPECT_EQ(itree.toListRef().at(0).toTupleRef().elements()[1], c10::IValue(1));

  EXPECT_TRUE(itree.toListRef().at(1).isInt());
  EXPECT_EQ(itree.toListRef().at(1), c10::IValue(2));

  ASSERT_TRUE(itree.toListRef().at(2).isGenericDict());
  EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("4"), c10::IValue(7));
  EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("5"), c10::IValue(8));
  EXPECT_EQ(itree.toListRef().at(2).toGenericDict().at("6"), c10::IValue(9));

  ASSERT_TRUE(itree.toListRef().at(3).isList());
  EXPECT_EQ(itree.toListRef().at(3).toListRef().at(0), c10::IValue(10));

  ASSERT_TRUE(itree.toListRef().at(4).isGenericDict());
  EXPECT_EQ(itree.toListRef().at(4).toGenericDict().at("11"), c10::IValue(12));

  const auto flattened = itreeFlatten(itree, spec);
  // Fatal: the loop below indexes both vectors with the same index.
  ASSERT_EQ(flattened.size(), flats.size());
  for (size_t i = 0; i < flattened.size(); i++) {
    EXPECT_EQ(flattened[i], flats[i]) << "mismatch at leaf " << i;
  }
}
|
|
|
|
// The other specs in this file are wrapped as `[1, {...}]`; this one is the
// bare tree object without that leading version element, and itreeSpecLoads
// is expected to reject it.
TEST(ITreeTest, NoVersion) {
  constexpr const char* kUnversionedSpec = R"JSON(
  {
    "type": "builtins.list",
    "context": "null",
    "children_spec": [
      {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          },
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      }
    ]
  }
  )JSON";

  // `owningGraph` keeps the values in `leafValues` alive.
  auto [owningGraph, leafValues] = makeValues(2);
  EXPECT_THROW(itreeSpecLoads(kUnversionedSpec, leafValues), std::exception);
}
|
|
|
|
// The first child of the list node only carries "children_spec" (no "type" and
// no "context"); itreeSpecLoads is expected to reject such a spec.
TEST(ITreeTest, NoField) {
  constexpr const char* kMissingFieldsSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";

  auto [owningGraph, leafValues] = makeValues(3);
  EXPECT_THROW(itreeSpecLoads(kMissingFieldsSpec, leafValues), std::exception);
}
|
|
|
|
// The dict node's context lists no keys ("[]") although the node has three
// children. Loading the spec succeeds, but unflattening three leaves into it
// is expected to abort with a failed check.
TEST(ITreeTest, NoContext) {
  constexpr const char* kEmptyContextSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.dict",
          "context": "[]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";
  auto [owningGraph, leafValues] = makeValues(3);
  auto treeSpec = itreeSpecLoads(kEmptyContextSpec, leafValues);

  std::vector<c10::IValue> leaves;
  for (const int payload : {7, 8, 9}) {
    leaves.emplace_back(payload);
  }
  ASSERT_DEATH({ itreeUnflatten(leaves, treeSpec); }, "Check failed");
}
|
|
|
|
// The dict node's context lists four keys ("4", "5", "6", "10") but the node
// has only three children. Loading the spec succeeds, but unflattening is
// expected to abort with a failed check.
TEST(ITreeTest, TooManyContext) {
  constexpr const char* kExtraContextSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\", \"10\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";

  auto [owningGraph, leafValues] = makeValues(3);
  auto treeSpec = itreeSpecLoads(kExtraContextSpec, leafValues);

  std::vector<c10::IValue> leaves;
  for (const int payload : {7, 8, 9}) {
    leaves.emplace_back(payload);
  }
  ASSERT_DEATH({ itreeUnflatten(leaves, treeSpec); }, "Check failed");
}
|
|
|
|
// Registering a node definition under "builtins.dict" (a type name the other
// specs in this file already load successfully) must throw.
TEST(ITreeTest, DoubleRegister) {
  const auto registerDictAgain = [] {
    registerPytreeNode("builtins.dict", NodeDef{});
  };
  EXPECT_THROW(registerDictAgain(), std::exception);
}
|
|
|
|
// The spec has six leaves but only four values are supplied; unflattening is
// expected to abort with a failed check.
TEST(ITreeTest, NotEnoughUnflatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  constexpr const char* kSixLeafSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";
  auto [owningGraph, leafValues] = makeValues(6);
  const auto treeSpec = itreeSpecLoads(kSixLeafSpec, leafValues);

  std::vector<c10::IValue> tooFewLeaves;
  for (const int payload : {0, 1, 2, 7}) {
    tooFewLeaves.emplace_back(payload);
  }
  ASSERT_DEATH({ itreeUnflatten(tooFewLeaves, treeSpec); }, "Check failed");
}
|
|
|
|
// The spec has six leaves but twelve values are supplied (the sequence
// 0, 1, 2, 7 repeated three times); unflattening is expected to abort with a
// failed check.
TEST(ITreeTest, TooManyUnflatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  constexpr const char* kSixLeafSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";
  auto [owningGraph, leafValues] = makeValues(6);
  const auto treeSpec = itreeSpecLoads(kSixLeafSpec, leafValues);

  std::vector<c10::IValue> tooManyLeaves;
  for (int round = 0; round < 3; ++round) {
    for (const int payload : {0, 1, 2, 7}) {
      tooManyLeaves.emplace_back(payload);
    }
  }
  ASSERT_DEATH({ itreeUnflatten(tooManyLeaves, treeSpec); }, "Check failed");
}
|
|
|
|
// Builds the nested IValue described in the comment below and checks that
// itreeFlatten yields exactly the eight expected leaves, in order.
TEST(ITreeTest, Flatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, (10,), {"11": 12}]
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "torch.fx.immutable_collections.immutable_list",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "torch.fx.immutable_collections.immutable_dict",
          "context": "[\"11\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )";
  auto [graph, valuePtrs] = makeValues(8);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);

  auto tup = c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)});
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  dict.insert("6", c10::IValue(9));
  c10::List<c10::IValue> ilist(c10::AnyType::get());
  ilist.push_back(c10::IValue(10));
  c10::Dict<c10::IValue, c10::IValue> idict(
      c10::StringType::get(), c10::AnyType::get());
  idict.insert("11", c10::IValue(12));
  c10::List<c10::IValue> list(c10::AnyType::get());
  list.push_back(std::move(tup));
  list.push_back(c10::IValue(2));
  list.push_back(std::move(dict));
  list.push_back(std::move(ilist));
  list.push_back(std::move(idict));

  auto flats = itreeFlatten(c10::IValue{list}, spec);
  std::vector<c10::IValue> expected = {
      c10::IValue(0),
      c10::IValue(1),
      c10::IValue(2),
      c10::IValue(7),
      c10::IValue(8),
      c10::IValue(9),
      c10::IValue(10),
      c10::IValue(12),
  };
  // Without this, a short (or empty) flatten result would make the loop below
  // pass vacuously.
  ASSERT_EQ(flats.size(), expected.size());
  for (const auto& [i, flat] : c10::enumerate(flats)) {
    EXPECT_EQ(flat, expected.at(i)) << "mismatch at leaf " << i;
  }
}
|
|
|
|
// For each of the ten user inputs, builds a graph in which only that input is
// consumed by a node, then checks that ivalueApplyFromArgs invokes the
// callback exactly once, for exactly that input.
TEST(ITreeTest, IValueApplyFromArgs) {
  // inputSpec for testing is generated from E2ETestModelWithNestedDictInput
  /*
  args = (
      {
          "a": (
              torch.rand(4, 4),
              {
                  123: (torch.rand(4, 4), torch.rand(4, 4)),
                  234: (torch.rand(4, 4), torch.rand(4, 4)),
              },
          ),
          "b": (
              torch.rand(4, 4),
              {
                  345: (torch.rand(4, 4), torch.rand(4, 4)),
                  456: (torch.rand(4, 4), torch.rand(4, 4)),
              },
          ),
      },
  )*/
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.tuple",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": "builtins.dict",
              "context": "[\"a\", \"b\"]",
              "children_spec": [
                {
                  "type": "builtins.tuple",
                  "context": "null",
                  "children_spec": [
                    {
                      "type": null,
                      "context": null,
                      "children_spec": []
                    },
                    {
                      "type": "builtins.dict",
                      "context": "[123, 234]",
                      "children_spec": [
                        {
                          "type": "builtins.tuple",
                          "context": "null",
                          "children_spec": [
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            },
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            }
                          ]
                        },
                        {
                          "type": "builtins.tuple",
                          "context": "null",
                          "children_spec": [
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            },
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            }
                          ]
                        }
                      ]
                    }
                  ]
                },
                {
                  "type": "builtins.tuple",
                  "context": "null",
                  "children_spec": [
                    {
                      "type": null,
                      "context": null,
                      "children_spec": []
                    },
                    {
                      "type": "builtins.dict",
                      "context": "[345, 456]",
                      "children_spec": [
                        {
                          "type": "builtins.tuple",
                          "context": "null",
                          "children_spec": [
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            },
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            }
                          ]
                        },
                        {
                          "type": "builtins.tuple",
                          "context": "null",
                          "children_spec": [
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            },
                            {
                              "type": null,
                              "context": null,
                              "children_spec": []
                            }
                          ]
                        }
                      ]
                    }
                  ]
                }
              ]
            }
          ]
        },
        {
          "type": "builtins.dict",
          "context": "[]",
          "children_spec": []
        }
      ]
    }
  ]
  )";

  // The inner dicts are keyed by ints (contexts "[123, 234]" and
  // "[345, 456]"), so they are declared with an int key type.
  auto tup_a1_123 =
      c10::ivalue::Tuple::create({c10::IValue(1), c10::IValue(2)});
  auto tup_a1_234 =
      c10::ivalue::Tuple::create({c10::IValue(3), c10::IValue(4)});
  c10::Dict<c10::IValue, c10::IValue> dict_a1(
      c10::IntType::get(), c10::AnyType::get());
  dict_a1.insert(123, tup_a1_123);
  dict_a1.insert(234, tup_a1_234);
  auto tup_a =
      c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(dict_a1)});

  auto tup_b1_345 =
      c10::ivalue::Tuple::create({c10::IValue(6), c10::IValue(7)});
  auto tup_b1_456 =
      c10::ivalue::Tuple::create({c10::IValue(8), c10::IValue(9)});
  c10::Dict<c10::IValue, c10::IValue> dict_b1(
      c10::IntType::get(), c10::AnyType::get());
  dict_b1.insert(345, tup_b1_345);
  dict_b1.insert(456, tup_b1_456);
  auto tup_b =
      c10::ivalue::Tuple::create({c10::IValue(5), c10::IValue(dict_b1)});

  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("a", tup_a);
  dict.insert("b", tup_b);
  std::vector<c10::IValue> args = {c10::IValue(dict)};

  for (int usedIdx = 0; usedIdx < 10; usedIdx++) {
    // Only %a{usedIdx} is consumed by a node; the other nine inputs are unused.
    const std::string body =
        fmt::format("%o1 = aten.foo(a=%a{})\n", usedIdx);
    std::string source = fmt::format(
        R"(graph(%a0, %a1, %a2, %a3, %a4, %a5, %a6, %a7, %a8, %a9):
{}
return(%o1)
)",
        body);

    auto graph = stringToGraph(source);
    std::vector<const Value*> userInputs(
        graph->userInputs().begin(), graph->userInputs().end());

    const auto spec = itreeSpecLoads(jsonSpec, userInputs);

    std::vector<int> visited;
    auto fn = [&](const c10::IValue& /*leaf*/, const Value* value) {
      visited.push_back(value->id());
    };
    ivalueApplyFromArgs(fn, args, {}, spec);

    // Fatal: visited[0] is read right after.
    ASSERT_EQ(visited.size(), 1u);
    EXPECT_EQ(visited[0], usedIdx);
  }
}
|
|
|
|
// The spec's root is a list, but a dict is passed to itreeFlatten; the type
// mismatch at the root is expected to throw.
TEST(ITreeTest, UnmatchedFlattenType) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )";
  auto [graph, valuePtrs] = makeValues(6);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);

  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  dict.insert("6", c10::IValue(9));
  EXPECT_THROW(
      { itreeFlatten(c10::IValue{std::move(dict)}, spec); }, std::exception);
}
|
|
|
|
// The dict in the input has keys "4", "5", "100", "101" while the spec's
// context lists "4", "5", "6"; flattening is expected to abort with a failed
// check.
TEST(ITreeTest, UnmatchedDictFlatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  constexpr const char* kListSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";
  auto [owningGraph, leafValues] = makeValues(6);
  const auto treeSpec = itreeSpecLoads(kListSpec, leafValues);

  c10::Dict<c10::IValue, c10::IValue> mismatchedDict(
      c10::StringType::get(), c10::AnyType::get());
  mismatchedDict.insert("4", c10::IValue(7));
  mismatchedDict.insert("5", c10::IValue(8));
  mismatchedDict.insert("100", c10::IValue(8));
  mismatchedDict.insert("101", c10::IValue(8));

  c10::List<c10::IValue> root(c10::AnyType::get());
  root.push_back(c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)}));
  root.push_back(c10::IValue(2));
  root.push_back(std::move(mismatchedDict));

  ASSERT_DEATH(
      { itreeFlatten(c10::IValue{std::move(root)}, treeSpec); },
      "Check failed");
}
|
|
|
|
// A dict with fewer entries than its spec context is accepted by itreeFlatten:
// the context lists "4", "5", "6" but the dict only holds "4" and "5". Only
// the absence of an exception is checked; the flattened leaves are not
// inspected here.
TEST(ITreeTest, DictFlattenTest) {
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )";
  auto [graph, valuePtrs] = makeValues(3);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);

  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  // Allow dict.size() < context size: here the dict has 2 entries while the
  // context has 3.
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  c10::List<c10::IValue> list(c10::AnyType::get());
  list.push_back(std::move(dict));
  EXPECT_NO_THROW(itreeFlatten(c10::IValue{std::move(list)}, spec));
}
|
|
|
|
// The spec's first child is a two-element tuple, but the input supplies a
// one-element tuple there; flattening is expected to abort with a failed
// check.
TEST(ITreeTest, UnmatchedTupleFlatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  constexpr const char* kListSpec = R"JSON(
  [
    1,
    {
      "type": "builtins.list",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\", \"6\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )JSON";
  auto [owningGraph, leafValues] = makeValues(6);
  const auto treeSpec = itreeSpecLoads(kListSpec, leafValues);

  c10::Dict<c10::IValue, c10::IValue> entries(
      c10::StringType::get(), c10::AnyType::get());
  entries.insert("4", c10::IValue(7));
  entries.insert("5", c10::IValue(8));
  entries.insert("6", c10::IValue(8));

  c10::List<c10::IValue> root(c10::AnyType::get());
  root.push_back(c10::ivalue::Tuple::create({c10::IValue(0)}));
  root.push_back(c10::IValue(2));
  root.push_back(std::move(entries));

  ASSERT_DEATH(
      { itreeFlatten(c10::IValue{std::move(root)}, treeSpec); },
      "Check failed");
}
|
|
|
|
// Checks the c10 type produced by toAtenType() for each spec node:
// - tuple nodes map to TupleType;
// - list nodes map to ListType;
// - leaves map to AnyType;
// - dict nodes map to DictType, with a StringType key for the string context
//   "[\"4\", \"5\"]" and an IntType key for the int context "[6]".
TEST(ITreeTest, ToAtenType) {
  // Original data: ((0, 1), 2, {"4": 7, "5": 8}, [10], {6: 9})
  auto jsonSpec = R"(
  [
    1,
    {
      "type": "builtins.tuple",
      "context": "null",
      "children_spec": [
        {
          "type": "builtins.tuple",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": null,
          "context": null,
          "children_spec": []
        },
        {
          "type": "builtins.dict",
          "context": "[\"4\", \"5\"]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            },
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "builtins.list",
          "context": "null",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        },
        {
          "type": "builtins.dict",
          "context": "[6]",
          "children_spec": [
            {
              "type": null,
              "context": null,
              "children_spec": []
            }
          ]
        }
      ]
    }
  ]
  )";
  auto [graph, valuePtrs] = makeValues(7);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  auto atenType = spec.toAtenType();

  // Root level is tuple. Kind and size checks are fatal (ASSERT): the
  // expectRef<> casts and unchecked elements()[i] accesses below are only
  // valid once they hold.
  ASSERT_EQ(atenType->kind(), c10::TypeKind::TupleType);
  const c10::TupleType& rootType = atenType->expectRef<c10::TupleType>();
  ASSERT_EQ(rootType.elements().size(), 5u);

  at::TypePtr elementType = rootType.elements()[0];
  ASSERT_EQ(elementType->kind(), c10::TypeKind::TupleType);
  const auto& innerTuple = elementType->expectRef<c10::TupleType>();
  ASSERT_EQ(innerTuple.elements().size(), 2u);
  EXPECT_EQ(innerTuple.elements()[0]->kind(), c10::TypeKind::AnyType);
  EXPECT_EQ(innerTuple.elements()[1]->kind(), c10::TypeKind::AnyType);

  elementType = rootType.elements()[1];
  EXPECT_EQ(elementType->kind(), c10::TypeKind::AnyType);

  elementType = rootType.elements()[2];
  ASSERT_EQ(elementType->kind(), c10::TypeKind::DictType);
  EXPECT_EQ(
      elementType->expectRef<c10::DictType>().getKeyType()->kind(),
      c10::TypeKind::StringType);
  EXPECT_EQ(
      elementType->expectRef<c10::DictType>().getValueType()->kind(),
      c10::TypeKind::AnyType);

  elementType = rootType.elements()[3];
  ASSERT_EQ(elementType->kind(), c10::TypeKind::ListType);
  EXPECT_EQ(
      elementType->expectRef<c10::ListType>().getElementType()->kind(),
      c10::TypeKind::AnyType);

  elementType = rootType.elements()[4];
  ASSERT_EQ(elementType->kind(), c10::TypeKind::DictType);
  EXPECT_EQ(
      elementType->expectRef<c10::DictType>().getKeyType()->kind(),
      c10::TypeKind::IntType);
  EXPECT_EQ(
      elementType->expectRef<c10::DictType>().getValueType()->kind(),
      c10::TypeKind::AnyType);
}
|
|
|
|
} // namespace torch::nativert::detail
|