pytorch/test/cpp/nativert/test_itree.cpp
Shangdi Yu 4e19477196 [nativert] Move Pytree (#155136)
Summary: fbcode/sigmoid/core/common -> fbcode/caffe2/torch/nativert/common

Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

Test Plan:
```
buck run fbcode//mode/dev-nosan  //caffe2/test/cpp/nativert:pytree_test
```

OSS CI

Rollback Plan:

Differential Revision: D75965059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155136
Approved by: https://github.com/zhxchen17, https://github.com/XuehaiPan, https://github.com/zou3519
2025-06-12 01:10:34 +00:00

1151 lines
30 KiB
C++

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <fmt/format.h>
#include <c10/util/Enumerate.h>
#include <torch/nativert/detail/ITree.h>
namespace torch::nativert::detail {
using torch::nativert::Graph;
using torch::nativert::stringToGraph;
using torch::nativert::Type;
using torch::nativert::Value;
// Creates a fresh Graph holding `count` values named "v0" .. "v{count-1}",
// each added with Type::Kind::None, and returns the graph together with
// const pointers to those values in creation order.
//
// The values are added to (and presumably owned by) the returned graph, so
// callers should keep the unique_ptr alive while they use the pointers; the
// tests do this via `auto [graph, valuePtrs] = makeValues(n);`.
// A non-positive `count` yields an empty pointer vector.
std::pair<std::unique_ptr<Graph>, std::vector<const Value*>> makeValues(
    int count) {
  auto graph = Graph::createGraph();
  std::vector<const Value*> values;
  if (count > 0) {
    // Guarded: converting a negative int to size_t would request a huge
    // reservation.
    values.reserve(static_cast<size_t>(count));
  }
  for (int i = 0; i < count; ++i) {
    std::string name = fmt::format("v{}", i);
    Value* value = graph->addValue(name, Type::Kind::None, nullptr);
    values.push_back(value);
  }
  // Move both members: the unique_ptr has to be moved, and moving the vector
  // avoids copying the pointer list.
  return std::make_pair(std::move(graph), std::move(values));
}
// Unflattens 8 leaves through a nested list/tuple/dict spec, checks every
// position of the resulting tree, then flattens it back and checks the
// round trip reproduces the original leaves.
TEST(ITreeTest, Unflatten) {
  // Original data:
  //   [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, immutable_list([10]),
  //    immutable_dict({"11": 12})]
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "torch.fx.immutable_collections.immutable_list",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "torch.fx.immutable_collections.immutable_dict",
"context": "[\"11\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [graph, valuePtrs] = makeValues(8);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  std::vector<c10::IValue> flats = {
      c10::IValue(0),
      c10::IValue(1),
      c10::IValue(2),
      c10::IValue(7),
      c10::IValue(8),
      c10::IValue(9),
      c10::IValue(10),
      c10::IValue(12),
  };
  auto itree = itreeUnflatten(flats, spec);

  // Root: a 5-element list. ASSERT (not EXPECT) so nothing below indexes a
  // malformed result.
  ASSERT_TRUE(itree.isList());
  const auto root = itree.toListRef();
  ASSERT_EQ(root.size(), 5u);

  // [0]: tuple (0, 1). elements()[i] is unchecked, so pin the arity first.
  ASSERT_TRUE(root.at(0).isTuple());
  const auto& tupleElems = root.at(0).toTupleRef().elements();
  ASSERT_EQ(tupleElems.size(), 2u);
  EXPECT_EQ(tupleElems[0], c10::IValue(0));
  EXPECT_EQ(tupleElems[1], c10::IValue(1));

  // [1]: plain int 2.
  EXPECT_TRUE(root.at(1).isInt());
  EXPECT_EQ(root.at(1), c10::IValue(2));

  // [2]: dict {"4": 7, "5": 8, "6": 9}.
  ASSERT_TRUE(root.at(2).isGenericDict());
  const auto dict = root.at(2).toGenericDict();
  EXPECT_EQ(dict.at("4"), c10::IValue(7));
  EXPECT_EQ(dict.at("5"), c10::IValue(8));
  EXPECT_EQ(dict.at("6"), c10::IValue(9));

  // [3]: immutable_list, unflattened as a list [10].
  ASSERT_TRUE(root.at(3).isList());
  EXPECT_EQ(root.at(3).toListRef().at(0), c10::IValue(10));

  // [4]: immutable_dict, unflattened as a dict {"11": 12}.
  ASSERT_TRUE(root.at(4).isGenericDict());
  EXPECT_EQ(root.at(4).toGenericDict().at("11"), c10::IValue(12));

  // Round trip: flattening the tree must give back the original leaves.
  const auto flattened = itreeFlatten(itree, spec);
  // ASSERT, not EXPECT: the loop indexes `flats` with positions taken from
  // `flattened`. That would read out of bounds if `flattened` were longer.
  ASSERT_EQ(flattened.size(), flats.size());
  for (size_t i = 0; i < flattened.size(); ++i) {
    EXPECT_EQ(flattened[i], flats[i]);
  }
}
// A spec that is a bare tree node is rejected at load time. Such a spec lacks
// the outer `[1, {...}]` array whose first element the other tests use as
// the format version.
TEST(ITreeTest, NoVersion) {
  const char* const unversionedSpec = R"(
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
)";
  auto [ownerGraph, leafValues] = makeValues(2);
  EXPECT_THROW(itreeSpecLoads(unversionedSpec, leafValues), std::exception);
}
// The first child of the list below has only "children_spec" and no "type" or
// "context". Loading such a spec is expected to throw.
TEST(ITreeTest, NoField) {
  const char* const specMissingFields = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(3);
  EXPECT_THROW(
      { itreeSpecLoads(specMissingFields, leafValues); }, std::exception);
}
// The dict node names no keys in its context ("[]") but has three children.
// The spec still loads, but unflattening three leaves through it is expected
// to die on a check failure.
TEST(ITreeTest, NoContext) {
  const char* const emptyContextSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.dict",
"context": "[]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(3);
  auto loadedSpec = itreeSpecLoads(emptyContextSpec, leafValues);
  std::vector<c10::IValue> leaves{
      c10::IValue(7), c10::IValue(8), c10::IValue(9)};
  ASSERT_DEATH({ itreeUnflatten(leaves, loadedSpec); }, "Check failed");
}
// The dict node's context lists four keys ("4", "5", "6", "10") for only three
// children. The spec loads, but unflattening is expected to die on a check
// failure.
TEST(ITreeTest, TooManyContext) {
  const char* const extraKeySpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\", \"10\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(3);
  auto loadedSpec = itreeSpecLoads(extraKeySpec, leafValues);
  std::vector<c10::IValue> leaves;
  for (const int leaf : {7, 8, 9}) {
    leaves.emplace_back(leaf);
  }
  ASSERT_DEATH({ itreeUnflatten(leaves, loadedSpec); }, "Check failed");
}
// Registering a node handler under "builtins.dict", a type name the spec
// loader already accepts in the other tests, is expected to throw.
TEST(ITreeTest, DoubleRegister) {
  const auto registerDictAgain = [] {
    registerPytreeNode("builtins.dict", NodeDef{});
  };
  EXPECT_THROW(registerDictAgain(), std::exception);
}
// The spec describes 6 leaves. Unflattening only 4 is expected to die on a
// check failure.
TEST(ITreeTest, NotEnoughUnflatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  const char* const sixLeafSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(6);
  const auto loadedSpec = itreeSpecLoads(sixLeafSpec, leafValues);
  std::vector<c10::IValue> tooFewLeaves;
  for (const int leaf : {0, 1, 2, 7}) {
    tooFewLeaves.emplace_back(leaf);
  }
  ASSERT_DEATH({ itreeUnflatten(tooFewLeaves, loadedSpec); }, "Check failed");
}
// The spec describes 6 leaves. Unflattening 12 leaves (the sequence 0, 1, 2, 7
// repeated three times) is expected to die on a check failure.
TEST(ITreeTest, TooManyUnflatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  const char* const sixLeafSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(6);
  const auto loadedSpec = itreeSpecLoads(sixLeafSpec, leafValues);
  std::vector<c10::IValue> tooManyLeaves;
  for (int repeat = 0; repeat < 3; ++repeat) {
    for (const int leaf : {0, 1, 2, 7}) {
      tooManyLeaves.emplace_back(leaf);
    }
  }
  ASSERT_DEATH({ itreeUnflatten(tooManyLeaves, loadedSpec); }, "Check failed");
}
// Builds the nested IValue matching the spec by hand, flattens it and checks
// that the leaves come out in spec order.
TEST(ITreeTest, Flatten) {
  // Original data:
  //   [(0, 1), 2, {"4": 7, "5": 8, "6": 9}, immutable_list([10]),
  //    immutable_dict({"11": 12})]
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "torch.fx.immutable_collections.immutable_list",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "torch.fx.immutable_collections.immutable_dict",
"context": "[\"11\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [graph, valuePtrs] = makeValues(8);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  auto tup = c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)});
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  dict.insert("6", c10::IValue(9));
  // The immutable_list / immutable_dict positions are fed plain list/dict.
  c10::List<c10::IValue> ilist(c10::AnyType::get());
  ilist.push_back(c10::IValue(10));
  c10::Dict<c10::IValue, c10::IValue> idict(
      c10::StringType::get(), c10::AnyType::get());
  idict.insert("11", c10::IValue(12));
  c10::List<c10::IValue> list(c10::AnyType::get());
  list.push_back(std::move(tup));
  list.push_back(c10::IValue(2));
  list.push_back(std::move(dict));
  list.push_back(std::move(ilist));
  list.push_back(std::move(idict));
  auto flats = itreeFlatten(c10::IValue{list}, spec);
  const std::vector<c10::IValue> expected = {
      c10::IValue(0),
      c10::IValue(1),
      c10::IValue(2),
      c10::IValue(7),
      c10::IValue(8),
      c10::IValue(9),
      c10::IValue(10),
      c10::IValue(12),
  };
  // Without this check, a result with too few leaves would pass silently.
  ASSERT_EQ(flats.size(), expected.size());
  for (const auto& [i, flat] : c10::enumerate(flats)) {
    EXPECT_EQ(flat, expected.at(i));
  }
}
// For each of the ten user inputs %a0..%a9, builds a graph in which only that
// input is consumed. It then checks that ivalueApplyFromArgs invokes the
// callback exactly once, for that input's Value.
TEST(ITreeTest, IValueApplyFromArgs) {
  // inputSpec for testing is generated from E2ETestModelWithNestedDictInput
  /*
  args = (
      {
          "a": (
              torch.rand(4, 4),
              {
                  123: (torch.rand(4, 4), torch.rand(4, 4)),
                  234: (torch.rand(4, 4), torch.rand(4, 4)),
              },
          ),
          "b": (
              torch.rand(4, 4),
              {
                  345: (torch.rand(4, 4), torch.rand(4, 4)),
                  456: (torch.rand(4, 4), torch.rand(4, 4)),
              },
          ),
      },
  )*/
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": "builtins.dict",
"context": "[\"a\", \"b\"]",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[123, 234]",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
},
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[345, 456]",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
}
]
}
]
},
{
"type": "builtins.dict",
"context": "[]",
"children_spec": []
}
]
}
]
)";
  // Ints stand in for the tensors; the leaves are numbered 0..9 in flatten
  // order.
  auto tup_a1_123 =
      c10::ivalue::Tuple::create({c10::IValue(1), c10::IValue(2)});
  auto tup_a1_234 =
      c10::ivalue::Tuple::create({c10::IValue(3), c10::IValue(4)});
  // Keys 123/234 are ints (matching the "[123, 234]" context), so the dict
  // is declared with an Int key type.
  c10::Dict<c10::IValue, c10::IValue> dict_a1(
      c10::IntType::get(), c10::AnyType::get());
  dict_a1.insert(123, tup_a1_123);
  dict_a1.insert(234, tup_a1_234);
  auto tup_a =
      c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(dict_a1)});
  auto tup_b1_345 =
      c10::ivalue::Tuple::create({c10::IValue(6), c10::IValue(7)});
  auto tup_b1_456 =
      c10::ivalue::Tuple::create({c10::IValue(8), c10::IValue(9)});
  c10::Dict<c10::IValue, c10::IValue> dict_b1(
      c10::IntType::get(), c10::AnyType::get());
  dict_b1.insert(345, tup_b1_345);
  dict_b1.insert(456, tup_b1_456);
  auto tup_b =
      c10::ivalue::Tuple::create({c10::IValue(5), c10::IValue(dict_b1)});
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("a", tup_a);
  dict.insert("b", tup_b);
  std::vector<c10::IValue> args = {c10::IValue(dict)};

  // The graph header below declares exactly 10 user inputs (%a0..%a9).
  for (int usedIdx = 0; usedIdx < 10; usedIdx++) {
    // Single node that consumes only %a{usedIdx}.
    const std::string body =
        fmt::format("%o1 = aten.foo(a=%a{})\n", usedIdx);
    std::string source = fmt::format(
        R"(graph(%a0, %a1, %a2, %a3, %a4, %a5, %a6, %a7, %a8, %a9):
{}
return(%o1)
)",
        body);
    auto graph = stringToGraph(source);
    // Bind once so begin()/end() are guaranteed to come from the same range,
    // whatever userInputs() returns.
    const auto& inputs = graph->userInputs();
    std::vector<const Value*> userInputs(inputs.begin(), inputs.end());
    const auto spec = itreeSpecLoads(jsonSpec, userInputs);
    std::vector<int> visited;
    auto fn = [&](const c10::IValue& /*leaf*/, const Value* value) {
      visited.push_back(value->id());
    };
    ivalueApplyFromArgs(fn, args, {}, spec);
    // ASSERT before indexing: visited[0] on an empty vector would be UB.
    ASSERT_EQ(visited.size(), 1u);
    EXPECT_EQ(visited[0], usedIdx);
  }
}
// The spec's root is a list. Flattening a bare dict against it must throw.
TEST(ITreeTest, UnmatchedFlattenType) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [graph, valuePtrs] = makeValues(6);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  // Only the inner dict is built. Passing it where the spec expects the root
  // list is the type mismatch under test.
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  dict.insert("6", c10::IValue(9));
  EXPECT_THROW(
      { itreeFlatten(c10::IValue{std::move(dict)}, spec); }, std::exception);
}
// The dict in the input has keys "4", "5", "100" and "101", while the spec's
// context names "4", "5" and "6". Flattening is expected to die on a check
// failure.
TEST(ITreeTest, UnmatchedDictFlatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  const char* const listSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(6);
  const auto loadedSpec = itreeSpecLoads(listSpec, leafValues);

  c10::Dict<c10::IValue, c10::IValue> mismatchedDict(
      c10::StringType::get(), c10::AnyType::get());
  mismatchedDict.insert("4", c10::IValue(7));
  mismatchedDict.insert("5", c10::IValue(8));
  mismatchedDict.insert("100", c10::IValue(8));
  mismatchedDict.insert("101", c10::IValue(8));

  c10::List<c10::IValue> root(c10::AnyType::get());
  root.push_back(
      c10::ivalue::Tuple::create({c10::IValue(0), c10::IValue(1)}));
  root.push_back(c10::IValue(2));
  root.push_back(std::move(mismatchedDict));

  ASSERT_DEATH(
      { itreeFlatten(c10::IValue{std::move(root)}, loadedSpec); },
      "Check failed");
}
// Flattening accepts a dict with fewer entries than the spec's context lists.
TEST(ITreeTest, DictFlattenTest) {
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [graph, valuePtrs] = makeValues(3);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  // Allow dict.size() < context.size():
  // here dict.size() == 2 while context.size() == 3 (key "6" is missing).
  dict.insert("4", c10::IValue(7));
  dict.insert("5", c10::IValue(8));
  c10::List<c10::IValue> list(c10::AnyType::get());
  list.push_back(std::move(dict));
  // State the expectation explicitly instead of relying on an uncaught
  // exception to fail the test.
  EXPECT_NO_THROW(itreeFlatten(c10::IValue{std::move(list)}, spec));
}
// The spec expects a 2-element tuple in the first list slot. Flattening input
// that has a 1-element tuple there is expected to die on a check failure.
TEST(ITreeTest, UnmatchedTupleFlatten) {
  // Original data: [(0, 1), 2, {"4": 7, "5": 8, "6": 9}]
  const char* const listSpec = R"(
[
1,
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\", \"6\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [ownerGraph, leafValues] = makeValues(6);
  const auto loadedSpec = itreeSpecLoads(listSpec, leafValues);

  c10::Dict<c10::IValue, c10::IValue> innerDict(
      c10::StringType::get(), c10::AnyType::get());
  innerDict.insert("4", c10::IValue(7));
  innerDict.insert("5", c10::IValue(8));
  innerDict.insert("6", c10::IValue(8));

  c10::List<c10::IValue> root(c10::AnyType::get());
  // Deliberately one element short of the spec's 2-element tuple.
  root.push_back(c10::ivalue::Tuple::create({c10::IValue(0)}));
  root.push_back(c10::IValue(2));
  root.push_back(std::move(innerDict));

  ASSERT_DEATH(
      { itreeFlatten(c10::IValue{std::move(root)}, loadedSpec); },
      "Check failed");
}
// Converts a spec into an ATen type and checks the shape of the result. Leaves
// map to Any; string-keyed and int-keyed dicts map to Dict[str, Any] and
// Dict[int, Any]; list maps to List[Any]; tuple maps to Tuple.
TEST(ITreeTest, ToAtenType) {
  // Original data: ((0, 1), 2, {"4": 7, "5": 8}, [10], {6: 9})
  auto jsonSpec = R"(
[
1,
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": "builtins.tuple",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": "builtins.dict",
"context": "[\"4\", \"5\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "builtins.list",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
},
{
"type": "builtins.dict",
"context": "[6]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
}
]
)";
  auto [graph, valuePtrs] = makeValues(7);
  const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
  auto atenType = spec.toAtenType();

  // Root level is a 5-element tuple. ASSERT the kind and size: expectRef and
  // the unchecked elements()[i] indexing below rely on them.
  ASSERT_EQ(atenType->kind(), c10::TypeKind::TupleType);
  const c10::TupleType& rootType = atenType->expectRef<c10::TupleType>();
  ASSERT_EQ(rootType.elements().size(), 5u);

  // [0]: (0, 1) -> Tuple[Any, Any]
  const at::TypePtr& tupleElem = rootType.elements()[0];
  ASSERT_EQ(tupleElem->kind(), c10::TypeKind::TupleType);
  const c10::TupleType& innerTuple = tupleElem->expectRef<c10::TupleType>();
  ASSERT_EQ(innerTuple.elements().size(), 2u);
  EXPECT_EQ(innerTuple.elements()[0]->kind(), c10::TypeKind::AnyType);
  EXPECT_EQ(innerTuple.elements()[1]->kind(), c10::TypeKind::AnyType);

  // [1]: leaf -> Any
  EXPECT_EQ(rootType.elements()[1]->kind(), c10::TypeKind::AnyType);

  // [2]: {"4": .., "5": ..} -> Dict[str, Any]
  const at::TypePtr& strDictElem = rootType.elements()[2];
  ASSERT_EQ(strDictElem->kind(), c10::TypeKind::DictType);
  EXPECT_EQ(
      strDictElem->expectRef<c10::DictType>().getKeyType()->kind(),
      c10::TypeKind::StringType);
  EXPECT_EQ(
      strDictElem->expectRef<c10::DictType>().getValueType()->kind(),
      c10::TypeKind::AnyType);

  // [3]: [10] -> List[Any]
  const at::TypePtr& listElem = rootType.elements()[3];
  ASSERT_EQ(listElem->kind(), c10::TypeKind::ListType);
  EXPECT_EQ(
      listElem->expectRef<c10::ListType>().getElementType()->kind(),
      c10::TypeKind::AnyType);

  // [4]: {6: ..} -> Dict[int, Any]
  const at::TypePtr& intDictElem = rootType.elements()[4];
  ASSERT_EQ(intDictElem->kind(), c10::TypeKind::DictType);
  EXPECT_EQ(
      intDictElem->expectRef<c10::DictType>().getKeyType()->kind(),
      c10::TypeKind::IntType);
  EXPECT_EQ(
      intDictElem->expectRef<c10::DictType>().getValueType()->kind(),
      c10::TypeKind::AnyType);
}
} // namespace torch::nativert::detail