mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Update xnnpack graph schema to use xnode and xvalue (#89036)
There are different nodes definition like [Node in autograd](https://www.internalfb.com/code/fbsource/fbcode/caffe2/torch/csrc/autograd/function.h?lines=108-609&reveal=108-609) and onnxnodes and etc. Understand namespace can be used where nodes from definition are used together, however it's still better to slightly differentiate the name. Differential Revision: [D41002324](https://our.internmc.facebook.com/intern/diff/D41002324/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89036 Approved by: https://github.com/mcr229
This commit is contained in:
parent
8c46a5de3a
commit
2452e3f99a
|
|
@ -46,10 +46,10 @@ XNNExecutor XNNCompiler::compileModel(std::string ser_model) {
|
|||
// a new mapping from the old ids to the newly created ones
|
||||
std::unordered_map<uint32_t, uint32_t> remapped_ids;
|
||||
|
||||
for (auto value : *flatbuffer_graph->values()) {
|
||||
switch (value->value_type()) {
|
||||
case fb_xnnpack::ValueUnion::XNNTensorValue: {
|
||||
auto tensor_value = value->value_as_XNNTensorValue();
|
||||
for (auto value : *flatbuffer_graph->xvalues()) {
|
||||
switch (value->xvalue_type()) {
|
||||
case fb_xnnpack::XValueUnion::XNNTensorValue: {
|
||||
auto tensor_value = value->xvalue_as_XNNTensorValue();
|
||||
|
||||
const void* data_ptr = nullptr;
|
||||
auto buffer_idx = tensor_value->constant_buffer_idx();
|
||||
|
|
@ -85,10 +85,10 @@ XNNExecutor XNNCompiler::compileModel(std::string ser_model) {
|
|||
}
|
||||
}
|
||||
|
||||
for (auto node : *flatbuffer_graph->nodes()) {
|
||||
switch (node->node_type()) {
|
||||
case fb_xnnpack::NodeUnion::XNNAdd: {
|
||||
auto graph_node = node->node_as_XNNAdd();
|
||||
for (auto node : *flatbuffer_graph->xnodes()) {
|
||||
switch (node->xnode_type()) {
|
||||
case fb_xnnpack::XNodeUnion::XNNAdd: {
|
||||
auto graph_node = node->xnode_as_XNNAdd();
|
||||
status = xnn_define_add2(
|
||||
subgraph_ptr,
|
||||
output_min,
|
||||
|
|
|
|||
|
|
@ -44,22 +44,22 @@ table XNNTensorValue {
|
|||
id_out:uint;
|
||||
}
|
||||
|
||||
union NodeUnion {
|
||||
union XNodeUnion {
|
||||
XNNAdd,
|
||||
}
|
||||
|
||||
union ValueUnion {
|
||||
union XValueUnion {
|
||||
XNNTensorValue,
|
||||
}
|
||||
|
||||
table Node {
|
||||
node:NodeUnion;
|
||||
table XNode {
|
||||
xnode:XNodeUnion;
|
||||
// An int which can be linked back to the node in the origin graph
|
||||
debug_handle:uint;
|
||||
}
|
||||
|
||||
table Value {
|
||||
value:ValueUnion;
|
||||
table XValue {
|
||||
xvalue:XValueUnion;
|
||||
}
|
||||
|
||||
table XNNAdd {
|
||||
|
|
@ -72,8 +72,8 @@ table XNNAdd {
|
|||
table XNNGraph {
|
||||
// Schema version.
|
||||
version:string;
|
||||
nodes:[Node];
|
||||
values:[Value];
|
||||
xnodes:[XNode];
|
||||
xvalues:[XValue];
|
||||
|
||||
// Ids of external inputs
|
||||
input_ids:[uint];
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ void XNNSerializer::serializeAddNode(
|
|||
const auto addNode =
|
||||
CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags);
|
||||
const auto flatbufferNode =
|
||||
CreateNode(_builder, NodeUnion::XNNAdd, addNode.Union());
|
||||
CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union());
|
||||
_nodes.push_back(flatbufferNode);
|
||||
}
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ void XNNSerializer::serializeTensorValue(
|
|||
id_out);
|
||||
|
||||
const auto flatbufferValue =
|
||||
CreateValue(_builder, ValueUnion::XNNTensorValue, tensorValue.Union());
|
||||
CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union());
|
||||
_values.push_back(flatbufferValue);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,10 +61,10 @@ class XNNSerializer {
|
|||
flatbuffers_fbsource::FlatBufferBuilder _builder;
|
||||
|
||||
// Vector of the serialized xnnpack nodes
|
||||
std::vector<flatbuffers_fbsource::Offset<Node>> _nodes;
|
||||
std::vector<flatbuffers_fbsource::Offset<XNode>> _nodes;
|
||||
|
||||
// Vector of the serialized xnnpack values
|
||||
std::vector<flatbuffers_fbsource::Offset<Value>> _values;
|
||||
std::vector<flatbuffers_fbsource::Offset<XValue>> _values;
|
||||
|
||||
std::vector<flatbuffers_fbsource::Offset<Buffer>> _constantBuffer;
|
||||
std::vector<uint32_t> _bufferSizes;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user