mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Profile memory usage in VirtualScheduler and report peak memory usage.
To do so, NodeState now handles different output ports of a node (in case a node has multiple outputs). Also, VirtualScheduler code is cleaned up with more comments. PiperOrigin-RevId: 158209068
This commit is contained in:
parent
0ea0bf5aae
commit
8f89b654f4
|
|
@ -176,6 +176,7 @@ cc_library(
|
|||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//tensorflow/core/grappler/costs:cost_estimator",
|
||||
|
|
@ -192,6 +193,10 @@ cc_test(
|
|||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp";
|
|||
constexpr char kReshape[] = "Reshape";
|
||||
constexpr char kRecv[] = "_Recv";
|
||||
constexpr char kBatchMatMul[] = "BatchMatMul";
|
||||
constexpr char kVariable[] = "Variable";
|
||||
constexpr char kVariableV2[] = "VariableV2";
|
||||
|
||||
OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
// Syntactic sugar to build and return a lambda that takes an OpInfo and
|
||||
|
|
@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
|||
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||
{kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||
{kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
|
||||
}
|
||||
|
||||
|
|
@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize(
|
|||
for (const auto& dim : input_shape.dim()) {
|
||||
input_size *= dim.size();
|
||||
}
|
||||
return input_size * DataTypeSize(input.dtype());
|
||||
return input_size * DataTypeSize(BaseType(input.dtype()));
|
||||
}
|
||||
|
||||
int64 OpLevelCostEstimator::CalculateInputSize(
|
||||
|
|
@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
|
|||
for (const auto& output : op_features.outputs()) {
|
||||
DataType dt = output.dtype();
|
||||
const auto& original_output_shape = output.shape();
|
||||
int64 output_size = DataTypeSize(dt);
|
||||
int64 output_size = DataTypeSize(BaseType(dt));
|
||||
int num_dims = std::max(1, original_output_shape.dim_size());
|
||||
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
|
||||
found_unknown_shapes);
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
|
|
@ -55,10 +56,10 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
|
|||
const bool use_static_shapes,
|
||||
const string& default_device_type,
|
||||
Cluster* cluster, VirtualPlacer* placer)
|
||||
: graph_properties_(*grappler_item),
|
||||
graph_costs_(Costs::ZeroCosts()),
|
||||
// TODO(dyoon): Use a better way than FIFO.
|
||||
: // TODO(dyoon): Use a better way than FIFO.
|
||||
ready_nodes_(new FIFOManager()),
|
||||
graph_costs_(Costs::ZeroCosts()),
|
||||
graph_properties_(*grappler_item),
|
||||
cluster_(cluster),
|
||||
grappler_item_(grappler_item),
|
||||
use_static_shapes_(use_static_shapes),
|
||||
|
|
@ -68,6 +69,11 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
|
|||
}
|
||||
|
||||
Status VirtualScheduler::Init() {
|
||||
// Init() preprocesses the input grappler_item and graph_properties to extract
|
||||
// necessary information for emulating tensorflow op scheduling and
|
||||
// construct internal data structures (NodeState and DeviceState) for virtual
|
||||
// scheduling.
|
||||
|
||||
// Construct graph properties.
|
||||
Status status;
|
||||
if (use_static_shapes_) {
|
||||
|
|
@ -82,13 +88,12 @@ Status VirtualScheduler::Init() {
|
|||
const auto& graph = grappler_item_->graph;
|
||||
const auto& fetch_nodes = grappler_item_->fetch;
|
||||
|
||||
// First, get the nodes that would run to output fetch_nodes.
|
||||
// Get the nodes that would run to output fetch_nodes.
|
||||
std::vector<const NodeDef*> nodes =
|
||||
ComputeTransitiveFanin(graph, fetch_nodes);
|
||||
|
||||
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
|
||||
// ComputeTransitiveFanin().
|
||||
//
|
||||
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
|
||||
// from the fetch nodes are scheduled. So the scheduled nodes should be
|
||||
// exactly the same as those executed for real. One possible discrepancy could
|
||||
|
|
@ -98,61 +103,72 @@ Status VirtualScheduler::Init() {
|
|||
name_to_node[node->name()] = node;
|
||||
}
|
||||
|
||||
// Build node_map.
|
||||
// Build node_map; for each node, create its NodeState and connect its inputs
|
||||
// and outputs.
|
||||
for (const auto* curr_node : nodes) {
|
||||
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
|
||||
const string curr_node_device = DeviceName(curr_node);
|
||||
for (const string& input_node_name : curr_node->input()) {
|
||||
// Note that input_node_name may be in <node_name>:<output_number> format,
|
||||
// where ":<output_number>" may be omitted. NodeName() extracts only the
|
||||
// node_name (prefeix "^", if there was for control input, is also
|
||||
// deleted).
|
||||
// Note that input_node_name may be in <prefix><node_name>:<port_num>
|
||||
// format, where <prefix> (e.g., "^" for control dependency) and
|
||||
// ":<port_num>" may be omitted. NodeName() extracts only the node_name.
|
||||
const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
|
||||
|
||||
CHECK(input_node);
|
||||
// Add input_to_curr_node to curr_node's input, and
|
||||
// add output_to_input_node to input_source_node's output.
|
||||
// Default values for when input_node and curr_node on the same device.
|
||||
const NodeDef* input_to_curr_node = input_node;
|
||||
const NodeDef* input_source_node = input_node;
|
||||
const NodeDef* output_to_input_node = curr_node;
|
||||
const string in_device = DeviceName(input_node);
|
||||
if (curr_node_device != in_device) {
|
||||
if (cached_ops_.count(input_node) > 0 &&
|
||||
cached_ops_[input_node].count(curr_node_device) > 0) {
|
||||
// Different device, but found an already-transferred copy; connect
|
||||
// the cached node to curr_node.
|
||||
input_to_curr_node = cached_ops_[input_node][curr_node_device];
|
||||
input_source_node = input_to_curr_node;
|
||||
output_to_input_node = curr_node;
|
||||
const auto input_node_port_num = NodePosition(input_node_name);
|
||||
|
||||
if (curr_node_device == in_device) {
|
||||
// Same device: connect input_node and curr_node directly.
|
||||
curr_node_state.inputs.push_back(
|
||||
std::make_pair(input_node, input_node_port_num));
|
||||
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
|
||||
input_node_state.outputs[input_node_port_num].push_back(curr_node);
|
||||
} else {
|
||||
if (cached_recv_nodes_.count(input_node) > 0 &&
|
||||
cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
|
||||
// Different device, but found an already-cached copy (a _Recv op);
|
||||
// connect the _Recv to curr_node.
|
||||
const auto* recv_op =
|
||||
cached_recv_nodes_[input_node][curr_node_device];
|
||||
// recv_op's output port is hard-coded to zero.
|
||||
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
|
||||
auto& input_node_state = node_map_.at(recv_op);
|
||||
input_node_state.outputs[0].push_back(curr_node);
|
||||
} else {
|
||||
// Different device, no cached copy; transfer input_node to the
|
||||
// curr_node's device.
|
||||
auto sendrecv_and_identity =
|
||||
TransferNode(input_node, curr_node, input_node_name);
|
||||
const auto* sendrecv = sendrecv_and_identity.first;
|
||||
const auto* identity = sendrecv_and_identity.second;
|
||||
input_to_curr_node = identity;
|
||||
input_source_node = input_node;
|
||||
output_to_input_node = sendrecv;
|
||||
auto send_and_recv =
|
||||
CreateSendRecv(input_node, curr_node, input_node_name);
|
||||
// Note that CreateSendRecv() already connected input/output between
|
||||
// _Send and _Recv ops.
|
||||
const auto* send = send_and_recv.first;
|
||||
const auto* recv = send_and_recv.second;
|
||||
// recv_op's output port is hard-coded to zero.
|
||||
curr_node_state.inputs.push_back(std::make_pair(recv, 0));
|
||||
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
|
||||
input_node_state.outputs[input_node_port_num].push_back(send);
|
||||
|
||||
// Cache the identity op for future use.
|
||||
cached_ops_[input_node][curr_node_device] = identity;
|
||||
// Cache the _Recv op for future use.
|
||||
cached_recv_nodes_[input_node][curr_node_device] = recv;
|
||||
}
|
||||
}
|
||||
curr_node_state.inputs.push_back(input_to_curr_node);
|
||||
|
||||
// Note that we do not care output number (in case a tf op has multiple
|
||||
// outputs), as VirtualScheduler only cares which nodes become ready as
|
||||
// a node is executed.
|
||||
auto& input_node_state = GetNodeStateOrCreateIt(input_source_node);
|
||||
input_node_state.outputs.push_back(output_to_input_node);
|
||||
}
|
||||
|
||||
if (curr_node->input().empty()) {
|
||||
curr_node_state.time_ready =
|
||||
Costs::Duration(); // Node without input: ready at time 0.
|
||||
// Node without input: ready at time 0.
|
||||
curr_node_state.time_ready = Costs::Duration();
|
||||
ready_nodes_->AddNode(curr_node);
|
||||
}
|
||||
|
||||
if (IsPersistentNode(curr_node)) {
|
||||
auto& device_state = device_[curr_node_device];
|
||||
for (int port_num = 0;
|
||||
port_num < curr_node_state.output_properties.size(); ++port_num) {
|
||||
device_state.persistent_nodes.insert(
|
||||
std::make_pair(curr_node, port_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ready_nodes_->Empty()) {
|
||||
|
|
@ -163,18 +179,26 @@ Status VirtualScheduler::Init() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void VirtualScheduler::MaybeUpdateInputProperties(
|
||||
const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const {
|
||||
if (IsSendOp(node) || IsRecvOp(node)) {
|
||||
// _Send and _Recv ops are inserted from VirtualScheduler, so
|
||||
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
|
||||
// This method is called when NodeState is created and adds input and output
|
||||
// properties for a few exceptional cases that GraphProperties cannot provide
|
||||
// input/output properties.
|
||||
if (IsSend(*node) || IsRecv(*node)) {
|
||||
auto& node_state = node_map_[node];
|
||||
auto& inputs = node_state.input_properties;
|
||||
auto& outputs = node_state.output_properties;
|
||||
|
||||
// _Send and _Recv ops are created from VirtualScheduler, so
|
||||
// there should be no inputs TensorProperties.
|
||||
CHECK_EQ(inputs->size(), 0);
|
||||
CHECK(inputs.empty());
|
||||
CHECK(outputs.empty());
|
||||
const auto& attr = node->attr();
|
||||
// This is the original input source to the _Send and _Recv, and this
|
||||
// string includes "^" if it was control dependency, and output port
|
||||
/// (e.g., ":2") if the input source had multiple outputs.
|
||||
const auto& input_source_name = attr.at(kAttrInputSrc).s();
|
||||
if (input_source_name[0] == '^') {
|
||||
if (IsControlInput(input_source_name)) {
|
||||
// Control dependency; regardless of the input source tensor size,
|
||||
// send 4B.
|
||||
OpInfo::TensorProperties control_message;
|
||||
|
|
@ -182,51 +206,53 @@ void VirtualScheduler::MaybeUpdateInputProperties(
|
|||
control_message.mutable_shape()->add_dim()->set_size(1);
|
||||
auto* value = control_message.mutable_value();
|
||||
value->add_float_val(1);
|
||||
inputs->push_back(control_message);
|
||||
inputs.push_back(control_message);
|
||||
outputs.push_back(control_message);
|
||||
} else {
|
||||
auto output_properties =
|
||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||
// Like with HasInputProperties, if a node does not have output
|
||||
// properties, it's likely it was pruned during the shape inference run.
|
||||
if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
|
||||
const auto input_position = NodePosition(input_source_name);
|
||||
if (!output_properties.empty()) {
|
||||
const auto input_node_port_num = NodePosition(input_source_name);
|
||||
// Use the input source's output property as _Send and _Recv's input
|
||||
// property.
|
||||
auto outputs =
|
||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||
CHECK_GT(outputs.size(), input_position);
|
||||
inputs->push_back(outputs[input_position]);
|
||||
CHECK_GT(output_properties.size(), input_node_port_num);
|
||||
inputs.push_back(output_properties[input_node_port_num]);
|
||||
outputs.push_back(output_properties[input_node_port_num]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool VirtualScheduler::IsSendOp(const NodeDef* node) const {
|
||||
return node->op() == kSend;
|
||||
float VirtualScheduler::Round2(const float x) const {
|
||||
return std::round(100.0 * x) / 100.0;
|
||||
}
|
||||
|
||||
bool VirtualScheduler::IsRecvOp(const NodeDef* node) const {
|
||||
return node->op() == kRecv;
|
||||
bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
|
||||
// Variables are persistent nodes.
|
||||
return IsVariable(*node);
|
||||
}
|
||||
|
||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||
CHECK(!initialized_) << "DeviceName is called after Init().";
|
||||
|
||||
// TODO(dyoon): integrate this part with VirtualPlacer.
|
||||
if (IsSendOp(node)) {
|
||||
const auto& node_state = node_map_.at(node);
|
||||
const auto* from = node_state.inputs[0];
|
||||
const auto* to = node_state.outputs[0];
|
||||
return ChannelDeviceName(from, to);
|
||||
} else {
|
||||
return node->device().empty() ? "/" + default_device_type_ + ":0"
|
||||
return node->device().empty() ? "/device:" + default_device_type_ + ":0"
|
||||
: node->device();
|
||||
}
|
||||
}
|
||||
|
||||
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
||||
const NodeDef* to) const {
|
||||
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
||||
|
||||
return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
|
||||
}
|
||||
|
||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||
const NodeDef* from, const NodeDef* to, const string& input_name) {
|
||||
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
|
||||
|
||||
// Connect "from" node to "to" node with _Send and _Recv such that
|
||||
// from -> _Send -> _Recv -> to.
|
||||
// _Send is placed on "Channel" device, and _Recv is on the same device
|
||||
|
|
@ -238,11 +264,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
|||
// NodeDefs created here need not be correct: in terms of name,
|
||||
// input names, attrs, etc.
|
||||
|
||||
auto input_node_port_num = NodePosition(input_name);
|
||||
|
||||
// _Send op.
|
||||
auto* send = new NodeDef();
|
||||
send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
|
||||
DeviceName(to));
|
||||
send->set_op(kSend);
|
||||
send->set_op("_Send");
|
||||
send->add_input(from->name());
|
||||
send->set_device(ChannelDeviceName(from, to));
|
||||
auto& send_attr = *(send->mutable_attr());
|
||||
|
|
@ -253,19 +281,22 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
|||
// _Recv op.
|
||||
auto* recv = new NodeDef();
|
||||
recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
|
||||
recv->set_op(kRecv);
|
||||
recv->set_op("_Recv");
|
||||
recv->add_input(send->name());
|
||||
recv->set_device(DeviceName(to));
|
||||
auto& recv_attr = *(recv->mutable_attr());
|
||||
recv_attr[kAttrInputSrc].set_s(input_name);
|
||||
|
||||
// Update NodeState for _Send and _Recv ops.
|
||||
// NodeState for _Send op.
|
||||
auto& send_node_state = GetNodeStateOrCreateIt(send);
|
||||
send_node_state.inputs.push_back(from);
|
||||
send_node_state.outputs.push_back(recv);
|
||||
send_node_state.device_name = send->device(); // Set Channel device.
|
||||
send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
|
||||
send_node_state.outputs[0].push_back(recv);
|
||||
|
||||
// NodeState for _Recv op.
|
||||
auto& recv_node_state = GetNodeStateOrCreateIt(recv);
|
||||
recv_node_state.inputs.push_back(send);
|
||||
recv_node_state.outputs.push_back(to);
|
||||
recv_node_state.inputs.push_back(std::make_pair(send, 0));
|
||||
recv_node_state.outputs[0].push_back(to);
|
||||
|
||||
// Keep the created nodes.
|
||||
additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
|
||||
|
|
@ -277,13 +308,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
|||
|
||||
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||
std::vector<OpInfo::TensorProperties> inputs =
|
||||
graph_properties_.GetInputProperties(node->name());
|
||||
// Some ops created within VirtualScheduler may need further processing to
|
||||
// the input properties.
|
||||
MaybeUpdateInputProperties(node, &inputs);
|
||||
|
||||
// This is for compatibility; we can just use palcer_->get_device() for all
|
||||
// This is for compatibility; we can just use placer_->get_device() for all
|
||||
// cases, once VirtualCluster is properly set up.
|
||||
DeviceProperties device;
|
||||
if (placer_) {
|
||||
|
|
@ -294,7 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
|||
int device_id;
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!node->device().empty() &&
|
||||
DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) {
|
||||
DeviceNameUtils::ParseFullName(node_map_.at(node).device_name,
|
||||
&parsed)) {
|
||||
device_type = parsed.type;
|
||||
device_id = parsed.id;
|
||||
} else {
|
||||
|
|
@ -309,81 +336,111 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
|||
}
|
||||
|
||||
// Special case for _Send op.
|
||||
if (IsSendOp(node)) {
|
||||
if (IsSend(*node)) {
|
||||
device.set_type(kChannelDevice);
|
||||
}
|
||||
|
||||
// Construct NodeInfo.
|
||||
const auto& node_state = node_map_.at(node);
|
||||
NodeInfo node_info;
|
||||
node_info.name = node->name();
|
||||
node_info.device_name = graph_properties_.GetDeviceName(node->name());
|
||||
std::vector<OpInfo::TensorProperties> outputs =
|
||||
graph_properties_.GetOutputProperties(node->name());
|
||||
node_info.device_name = node_state.device_name;
|
||||
auto& op_info = node_info.op_info;
|
||||
op_info.set_op(node->op());
|
||||
*op_info.mutable_attr() = node->attr();
|
||||
for (auto& input : inputs) {
|
||||
op_info.add_inputs()->Swap(&input);
|
||||
for (auto& input : node_state.input_properties) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
for (auto& output : outputs) {
|
||||
op_info.add_outputs()->Swap(&output);
|
||||
for (auto& output : node_state.output_properties) {
|
||||
*op_info.add_outputs() = output;
|
||||
}
|
||||
op_info.mutable_device()->Swap(&device);
|
||||
// add some more to the node_info.
|
||||
return node_info;
|
||||
}
|
||||
|
||||
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
|
||||
|
||||
auto it = node_map_.find(node);
|
||||
if (it == node_map_.end()) {
|
||||
// Not found; create a NodeState for this node.
|
||||
it = node_map_.emplace(node, NodeState()).first;
|
||||
auto& node_state = it->second;
|
||||
node_state.input_properties =
|
||||
graph_properties_.GetInputProperties(node->name());
|
||||
node_state.output_properties =
|
||||
graph_properties_.GetOutputProperties(node->name());
|
||||
|
||||
// Some ops may need further processing to the input / output properties:
|
||||
// _Send and _Recv.
|
||||
MaybeUpdateInputOutput(node);
|
||||
|
||||
if (!IsSend(*node)) {
|
||||
node_state.device_name = DeviceName(node);
|
||||
// For _Send op, device_name will be set to Channel in CreateSendRecv().
|
||||
}
|
||||
|
||||
// Initialize output port related data:
|
||||
// Assume the size of OutputProperties represents the number of output ports
|
||||
// of this node.
|
||||
for (int i = 0; i < node_state.output_properties.size(); ++i) {
|
||||
node_state.time_no_references[i] = Costs::Duration::max();
|
||||
node_state.num_outputs_executed[i] = 0;
|
||||
// Populate an empty vector for each port. The caller will add nodes
|
||||
// that use this port as input.
|
||||
node_state.outputs[i] = {};
|
||||
}
|
||||
// Port_num -1 is for control dependency.
|
||||
node_state.time_no_references[-1] = Costs::Duration::max();
|
||||
node_state.num_outputs_executed[-1] = 0;
|
||||
node_state.outputs[-1] = {};
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
int64 VirtualScheduler::CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
const int port_num) const {
|
||||
if (port_num < 0) {
|
||||
return 4; // 4B for control dependency.
|
||||
}
|
||||
|
||||
if (port_num >= output_properties.size()) {
|
||||
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||
<< "port_num: " << port_num
|
||||
<< " >= output_properties.size(): " << output_properties.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto& output = output_properties[port_num];
|
||||
int64 output_size = DataTypeSize(BaseType(output.dtype()));
|
||||
|
||||
for (const auto& dim : output.shape().dim()) {
|
||||
auto dim_size = dim.size();
|
||||
if (dim_size < 0) {
|
||||
// Zero output size if there's any unknown dim.
|
||||
output_size = 0;
|
||||
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||
<< "unknown dim: " << output_size;
|
||||
break;
|
||||
}
|
||||
output_size *= dim_size;
|
||||
}
|
||||
|
||||
return output_size;
|
||||
}
|
||||
|
||||
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
|
||||
std::map<string, Costs>* op_cost) {
|
||||
auto it = op_cost->find(op_name);
|
||||
if (it == op_cost->end()) {
|
||||
// Note that default constructor of Costs sets some memory related fields
|
||||
// to unknown values so we should explicitly initialize it with ZeroCosts.
|
||||
it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool VirtualScheduler::PopCurrNode() {
|
||||
const auto* node = ready_nodes_->GetCurrNode();
|
||||
auto& node_state = node_map_[node];
|
||||
auto& device = device_[DeviceName(node)];
|
||||
auto curr_time = device.GetCurrTime();
|
||||
|
||||
// Increment num_inputs_ready of the output nodes.
|
||||
for (auto* output : node_state.outputs) {
|
||||
auto& output_state = node_map_[output];
|
||||
output_state.num_inputs_ready++;
|
||||
if (output_state.num_inputs_ready == output_state.inputs.size()) {
|
||||
// This output node is now ready.
|
||||
output_state.time_ready = curr_time;
|
||||
ready_nodes_->AddNode(output);
|
||||
}
|
||||
}
|
||||
|
||||
// Increment num_outputs_executed of the input nodes.
|
||||
for (auto* input : node_state.inputs) {
|
||||
auto& input_state = node_map_[input];
|
||||
input_state.num_outputs_executed++;
|
||||
if (input_state.num_outputs_executed == input_state.outputs.size()) {
|
||||
// All the outputs are executed; no reference to this input nodel
|
||||
input_state.time_no_reference = curr_time;
|
||||
// TODO(dyoon): collect device memory usage; note that this input node
|
||||
// use device memory between time_scheduled and time_no_reference.
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the current node; assume FIFO.
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
|
||||
return !ready_nodes_->Empty();
|
||||
}
|
||||
|
||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||
// Update graph_costs_ and per-op costs.
|
||||
graph_costs_ = CombineCosts(graph_costs_, node_costs);
|
||||
|
|
@ -402,7 +459,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
|
||||
// Update node and device states.
|
||||
auto& node_state = node_map_[node];
|
||||
auto& device = device_[DeviceName(node)];
|
||||
auto& device = device_[node_state.device_name];
|
||||
device.nodes_executed.push_back(node);
|
||||
// Node is scheduled when the device is available AND all the inputs are
|
||||
// ready; hence, time_scheduled is time_ready if time_ready > device curr
|
||||
|
|
@ -415,6 +472,21 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
auto curr_time = device.GetCurrTime();
|
||||
node_state.time_finished = curr_time;
|
||||
|
||||
// Update device memory usage.
|
||||
if (!IsPersistentNode(node)) {
|
||||
for (const auto& port_num_output_pair : node_state.outputs) {
|
||||
int port_num = port_num_output_pair.first;
|
||||
// There's a chance that a specific output is not used at all.
|
||||
if (node_state.outputs[port_num].empty()) {
|
||||
node_state.time_no_references[port_num] = curr_time;
|
||||
} else {
|
||||
device.memory_usage +=
|
||||
CalculateOutputSize(node_state.output_properties, port_num);
|
||||
device.nodes_in_memory.insert(std::make_pair(node, port_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update device's per-op cost.
|
||||
auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
|
||||
device_op_cost = CombineCosts(device_op_cost, node_costs);
|
||||
|
|
@ -425,7 +497,52 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
|||
<< ", scheduled: " << node_state.time_scheduled.count()
|
||||
<< ", finished: " << node_state.time_finished.count();
|
||||
|
||||
return PopCurrNode();
|
||||
// Increment num_inputs_ready of the output nodes
|
||||
for (const auto& port_num_output_pair : node_state.outputs) {
|
||||
for (auto* output_node : port_num_output_pair.second) {
|
||||
auto& output_state = node_map_[output_node];
|
||||
output_state.num_inputs_ready++;
|
||||
if (output_state.num_inputs_ready == output_state.inputs.size()) {
|
||||
// This output node is now ready.
|
||||
output_state.time_ready = curr_time;
|
||||
ready_nodes_->AddNode(output_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment num_outputs_executed of the input nodes.
|
||||
for (const auto& input_port : node_state.inputs) {
|
||||
auto* input = input_port.first;
|
||||
auto port = input_port.second;
|
||||
auto& input_state = node_map_[input];
|
||||
input_state.num_outputs_executed[port]++;
|
||||
if (input_state.num_outputs_executed[port] ==
|
||||
input_state.outputs[port].size() &&
|
||||
!IsPersistentNode(input)) {
|
||||
// All the outputs are executed; no reference to this output port of
|
||||
// input node.
|
||||
input_state.time_no_references[port] = curr_time;
|
||||
auto& input_device = device_[input_state.device_name];
|
||||
input_device.memory_usage -=
|
||||
CalculateOutputSize(input_state.output_properties, port);
|
||||
|
||||
input_device.nodes_in_memory.erase(std::make_pair(input, port));
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsPersistentNode(node)) {
|
||||
// Now that output memory is added and used up nodes are deallocated,
|
||||
// check max memory usage.
|
||||
if (device.memory_usage > device.max_memory_usage) {
|
||||
device.max_memory_usage = device.memory_usage;
|
||||
device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the current node; assume FIFO.
|
||||
ready_nodes_->RemoveCurrNode();
|
||||
|
||||
return !ready_nodes_->Empty();
|
||||
}
|
||||
|
||||
Costs VirtualScheduler::Summary() const {
|
||||
|
|
@ -452,17 +569,59 @@ Costs VirtualScheduler::Summary() const {
|
|||
for (const auto& device : device_) {
|
||||
const auto& name = device.first;
|
||||
const auto& state = device.second;
|
||||
|
||||
std::map<string, int64> op_to_memory;
|
||||
// First profile only persistent memory usage.
|
||||
int64 persistent_memory_usage = 0;
|
||||
std::set<string> persisent_ops;
|
||||
for (const auto& node_port : state.persistent_nodes) {
|
||||
const auto* node = node_port.first;
|
||||
const auto port = node_port.second;
|
||||
const auto output_size =
|
||||
CalculateOutputSize(node_map_.at(node).output_properties, port);
|
||||
persistent_memory_usage += output_size;
|
||||
op_to_memory[node->op()] += output_size;
|
||||
persisent_ops.insert(node->op());
|
||||
}
|
||||
int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
|
||||
|
||||
VLOG(1) << "Device = " << name
|
||||
<< ", num_nodes = " << state.nodes_executed.size()
|
||||
<< ", execution_time = " << state.GetCurrTime().count();
|
||||
VLOG(1) << "Per-op execution time:";
|
||||
<< ", execution_time = " << state.GetCurrTime().count()
|
||||
<< ", memory usage: "
|
||||
<< "persistenst = "
|
||||
<< Round2(persistent_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||
<< " GB, peak = "
|
||||
<< Round2(state.max_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||
<< " GB, total = "
|
||||
<< Round2(max_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||
<< " GB, at the end: " << state.memory_usage << " B";
|
||||
|
||||
VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
|
||||
// Profile non-persistent op memory usage.
|
||||
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
|
||||
const auto* node = node_port.first;
|
||||
const auto port = node_port.second;
|
||||
op_to_memory[node->op()] +=
|
||||
CalculateOutputSize(node_map_.at(node).output_properties, port);
|
||||
}
|
||||
for (const auto& op_cost_pair : state.op_to_cost) {
|
||||
const auto& op = op_cost_pair.first;
|
||||
const auto& cost = op_cost_pair.second.execution_time.count();
|
||||
if (cost) { // Skip printing out zero-cost ops.
|
||||
VLOG(1) << " + " << op << " : " << cost;
|
||||
const float mem_usage_gb =
|
||||
Round2(op_to_memory[op] / 1024.0 / 1024.0 / 1024.0);
|
||||
int64 op_mem_usage = op_to_memory.at(op);
|
||||
const float mem_usage_percent =
|
||||
max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
|
||||
: 0.0;
|
||||
if (cost || mem_usage_percent > 1.0) {
|
||||
// Print out only non-zero cost ops or ops with > 1% memory usage.
|
||||
VLOG(1) << " + " << op << " : " << cost << " (" << mem_usage_gb
|
||||
<< " GB [" << mem_usage_percent << "%] "
|
||||
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
|
||||
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
|
||||
critical_path_costs = state.device_costs;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,36 +29,79 @@ namespace tensorflow {
|
|||
namespace grappler {
|
||||
|
||||
struct NodeState {
|
||||
std::vector<const NodeDef*> inputs;
|
||||
std::vector<const NodeDef*> outputs;
|
||||
// A node (i.e., an op) takes a set of input:port pairs and produces
|
||||
// a set of output ports.
|
||||
|
||||
// Cross references to input and output nodes from graphdef.
|
||||
std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs.
|
||||
// List of output nodes (a list of nodes that takes this output port as input)
|
||||
// keyed by port_num. Note that port_num -1 is used for control dependency.
|
||||
std::unordered_map<int, std::vector<const NodeDef*>> outputs;
|
||||
|
||||
// Info from GraphProperties.
|
||||
std::vector<OpInfo::TensorProperties> input_properties;
|
||||
std::vector<OpInfo::TensorProperties> output_properties;
|
||||
|
||||
// Canonical device name used within VirtualScheduler.
|
||||
string device_name;
|
||||
|
||||
// States updated as scheduling nodes.
|
||||
int num_inputs_ready;
|
||||
int num_outputs_executed;
|
||||
std::unordered_map<int, int> num_outputs_executed;
|
||||
Costs::Duration time_ready;
|
||||
Costs::Duration time_scheduled;
|
||||
Costs::Duration time_finished;
|
||||
Costs::Duration time_no_reference;
|
||||
// Time that all the consumers are executed (hence, no need to keep this
|
||||
// output in memory), keyed by port_num.
|
||||
std::unordered_map<int, Costs::Duration> time_no_references;
|
||||
|
||||
// Note that a node may have multiple output ports. The length of outputs,
|
||||
// num_outputs_executed, and time_no_references should be
|
||||
// identical when a NodeState is fully initialized.
|
||||
// They should be 1 + output_properties.size() as we add [-1] for control
|
||||
// dependency.
|
||||
|
||||
// Node will be ready to be executed at time_ready, scheduled at
|
||||
// time_scheduled, and finishes execution at time_finished.
|
||||
// Between time_scheduled and time_no_reference, the node's output tensor
|
||||
// needs to be on the device, using up device memory.
|
||||
// Each output port uses up memory space from time_scheduled to its
|
||||
// time_no_references.
|
||||
|
||||
NodeState() {
|
||||
num_inputs_ready = 0;
|
||||
num_outputs_executed = 0;
|
||||
time_ready = Costs::Duration::max();
|
||||
time_scheduled = Costs::Duration::max();
|
||||
time_finished = Costs::Duration::max();
|
||||
time_no_reference = Costs::Duration::max();
|
||||
// Note that num_outputs_executed and time_no_references are not initialized
|
||||
// here, since we don't know the size (i.e., # outputs for this node).
|
||||
}
|
||||
};
|
||||
|
||||
struct DeviceState {
|
||||
// Nodes executed on this device in execution order.
|
||||
std::vector<const NodeDef*> nodes_executed;
|
||||
|
||||
// Nodes currently allocated in memory: set of NodeDef* and port_num pairs
|
||||
// so that we can track which output of the node is in memory.
|
||||
std::set<std::pair<const NodeDef*, int>> nodes_in_memory;
|
||||
|
||||
// Nodes allocated in memory persistently: e.g., Variables.
|
||||
std::set<std::pair<const NodeDef*, int>> persistent_nodes;
|
||||
|
||||
// Snapshot of nodes_in_memory, when memory usage is at peak.
|
||||
// Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
|
||||
std::set<std::pair<const NodeDef*, int>> mem_usage_snapshot_at_peak;
|
||||
|
||||
Costs device_costs;
|
||||
std::map<string, Costs> op_to_cost; // Per-op cost.
|
||||
std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage.
|
||||
int64 memory_usage;
|
||||
int64 max_memory_usage;
|
||||
|
||||
DeviceState() { device_costs = Costs::ZeroCosts(); }
|
||||
DeviceState() {
|
||||
device_costs = Costs::ZeroCosts();
|
||||
memory_usage = 0;
|
||||
max_memory_usage = 0;
|
||||
}
|
||||
|
||||
Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
|
||||
};
|
||||
|
|
@ -106,48 +149,74 @@ class VirtualScheduler {
|
|||
const string& default_device_type, Cluster* cluster,
|
||||
VirtualPlacer* placer);
|
||||
|
||||
// Initializes NodeState and DeviceState from grappler_item_ and
|
||||
// graph_properties_.
|
||||
Status Init();
|
||||
|
||||
NodeInfo GetCurrNodeInfo() const;
|
||||
|
||||
// Returns true if there is any node to be scheduled.
|
||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||
|
||||
// Prints out summary of execution (timing, memory usage, etc.)
|
||||
Costs Summary() const;
|
||||
|
||||
protected:
|
||||
// GetDeviceStates and GetNodeStates are currently for testing purpuse only.
|
||||
// Retrieves detailed scheduling results.
|
||||
const std::unordered_map<string, DeviceState>& GetDeviceStates() const {
|
||||
return device_;
|
||||
}
|
||||
const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const {
|
||||
return node_map_;
|
||||
}
|
||||
|
||||
// Returns the size of output at port_num (unit: bytes). A special case is
|
||||
// port_num -1, which is for control dependency and assumed to be 4 bytes.
|
||||
int64 CalculateOutputSize(
|
||||
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||
const int port_num) const;
|
||||
|
||||
private:
|
||||
const string kSend = "_Send";
|
||||
const string kRecv = "_Recv";
|
||||
// Constants.
|
||||
const string kAttrInputSrc = "input_source_";
|
||||
const string kAttrSrcDevice = "src_device_";
|
||||
const string kAttrDstDevice = "dst_device_";
|
||||
const string kChannelDevice = "Channel";
|
||||
|
||||
void MaybeUpdateInputProperties(
|
||||
const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const;
|
||||
// Methods called from Init(). Fails if initialize_ is set.
|
||||
void MaybeUpdateInputOutput(const NodeDef* node);
|
||||
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
||||
std::pair<const NodeDef*, const NodeDef*> TransferNode(
|
||||
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
|
||||
const NodeDef* from, const NodeDef* to, const string& input_name);
|
||||
string DeviceName(const NodeDef* node) const;
|
||||
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
||||
|
||||
// Helper methods.
|
||||
Costs& FindOrCreateZero(const string& op_name,
|
||||
std::map<string, Costs>* op_cost);
|
||||
float Round2(const float x) const;
|
||||
bool IsPersistentNode(const NodeDef* node) const;
|
||||
|
||||
bool PopCurrNode();
|
||||
bool IsSendOp(const NodeDef* node) const;
|
||||
bool IsRecvOp(const NodeDef* node) const;
|
||||
// Scheduler states:
|
||||
std::unique_ptr<ReadyNodeManager> ready_nodes_;
|
||||
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
||||
std::unordered_map<string, DeviceState> device_;
|
||||
|
||||
GraphProperties graph_properties_;
|
||||
// Pool of NodeDefs for SendRecv and Identity ops created.
|
||||
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
|
||||
// Cache of nodes transferred to another device.
|
||||
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
|
||||
cached_recv_nodes_;
|
||||
|
||||
// Stats:
|
||||
std::map<string, int> op_counts_; // Op counts with key with input shape.
|
||||
std::map<string, int> op_costs_; // Individual op costs (with input shapes).
|
||||
Costs graph_costs_; // Graph cost.
|
||||
std::map<string, Costs> op_to_cost_; // Per-op cost.
|
||||
std::unique_ptr<ReadyNodeManager> ready_nodes_;
|
||||
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
||||
std::unordered_map<string, DeviceState> device_;
|
||||
// Pool of NodeDefs for SendRecv and Identity ops created.
|
||||
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
|
||||
// Cache of ops transferred to another device.
|
||||
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
|
||||
cached_ops_;
|
||||
|
||||
// Auxilliary data structures for constructing NodeState and DeviceState.
|
||||
GraphProperties graph_properties_;
|
||||
Cluster* cluster_; // Not owned.
|
||||
const GrapplerItem* grappler_item_; // Not owned.
|
||||
bool use_static_shapes_;
|
||||
|
|
|
|||
|
|
@ -23,42 +23,49 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
// Class for testing virtual scheduler.
|
||||
class TestVirtualScheduler : public VirtualScheduler {
|
||||
public:
|
||||
TestVirtualScheduler(const GrapplerItem* grappler_item,
|
||||
const bool use_static_shapes,
|
||||
const string& default_device_type, Cluster* cluster,
|
||||
VirtualPlacer* placer)
|
||||
: VirtualScheduler(grappler_item, use_static_shapes, default_device_type,
|
||||
cluster, placer) {}
|
||||
|
||||
FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
|
||||
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
|
||||
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
|
||||
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
|
||||
FRIEND_TEST(VirtualSchedulerTest, Variable);
|
||||
};
|
||||
|
||||
class VirtualSchedulerTest : public ::testing::Test {
|
||||
protected:
|
||||
const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
|
||||
|
||||
void SetUp() override {
|
||||
// Initializes cluster_ and placer_.
|
||||
std::unordered_map<string, DeviceProperties> devices;
|
||||
DeviceProperties cpu_device;
|
||||
cpu_device.set_type("CPU");
|
||||
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
|
||||
DeviceProperties gpu_device;
|
||||
gpu_device.set_type("GPU");
|
||||
devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
|
||||
devices[kCPU0] = cpu_device;
|
||||
|
||||
cluster_.reset(new VirtualCluster(devices));
|
||||
placer_.reset(new VirtualPlacer(cluster_.get()));
|
||||
}
|
||||
|
||||
void CreateSchedulerWithConv2Ds() {
|
||||
// Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in
|
||||
// fetch nodes.
|
||||
const int bs = 4;
|
||||
const int width = 10;
|
||||
const int height = 10;
|
||||
const int depth_in = 8;
|
||||
const int kernel = 3;
|
||||
const int depth_out = 16;
|
||||
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
// Three Conv2Ds with only two in fetch nodes.
|
||||
void CreateGrapplerItemWithConv2Ds() {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||
auto x = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT);
|
||||
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||
auto y = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT);
|
||||
s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||
auto z = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT);
|
||||
s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||
auto f = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT);
|
||||
s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
|
||||
std::vector<int> strides = {1, 1, 1, 1};
|
||||
auto c0 =
|
||||
tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
|
||||
|
|
@ -68,47 +75,253 @@ class VirtualSchedulerTest : public ::testing::Test {
|
|||
tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||
LOG(INFO) << def.DebugString();
|
||||
|
||||
grappler_item_.reset(new GrapplerItem);
|
||||
grappler_item_->id = "test_conv2d_graph";
|
||||
grappler_item_->graph = def;
|
||||
grappler_item_->fetch = {"c0", "c1"};
|
||||
|
||||
scheduler_.reset(new VirtualScheduler(
|
||||
dependency_["c0"] = {"x", "f"};
|
||||
dependency_["c1"] = {"y", "f"};
|
||||
}
|
||||
|
||||
// A Conv2D with a variable.
|
||||
void CreateGrapplerItemWithConv2DAndVariable() {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||
auto x = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||
auto f = tensorflow::ops::Variable(
|
||||
s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
|
||||
std::vector<int> strides = {1, 1, 1, 1};
|
||||
auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||
|
||||
grappler_item_.reset(new GrapplerItem);
|
||||
grappler_item_->id = "test_conv2d_var_graph";
|
||||
grappler_item_->graph = def;
|
||||
grappler_item_->fetch = {"y"};
|
||||
|
||||
dependency_["y"] = {"x", "f"};
|
||||
}
|
||||
|
||||
// AddN that takes 4 tensors with 10x10x10x10.
|
||||
void CreateGrapplerItemWithAddN() {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||
auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10},
|
||||
DT_FLOAT);
|
||||
auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10},
|
||||
DT_FLOAT);
|
||||
auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10},
|
||||
DT_FLOAT);
|
||||
auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10},
|
||||
DT_FLOAT);
|
||||
tensorflow::OutputList input_tensors = {x, y, z, w};
|
||||
auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors);
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||
|
||||
grappler_item_.reset(new GrapplerItem);
|
||||
grappler_item_->id = "test_addn_graph";
|
||||
grappler_item_->graph = def;
|
||||
grappler_item_->fetch = {"out"};
|
||||
|
||||
dependency_["out"] = {"x", "y", "z", "w"};
|
||||
}
|
||||
|
||||
// NoOp that takes 7 NoOps as control dependency.
|
||||
void CreateGrapplerItemWithControlDependency() {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||
std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
|
||||
std::vector<tensorflow::Operation> input_tensors;
|
||||
for (const auto& input : input_noop_names) {
|
||||
auto x = tensorflow::ops::NoOp(s.WithOpName(input));
|
||||
input_tensors.push_back(x.operation);
|
||||
}
|
||||
auto out = tensorflow::ops::NoOp(
|
||||
s.WithControlDependencies(input_tensors).WithOpName("out"));
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||
|
||||
grappler_item_.reset(new GrapplerItem);
|
||||
grappler_item_->id = "test_control_dependency_graph";
|
||||
grappler_item_->graph = def;
|
||||
grappler_item_->fetch = {"out"};
|
||||
|
||||
dependency_["out"] = input_noop_names;
|
||||
}
|
||||
|
||||
// FusedBN [an op with multiple outputs] with multiple consumers (including
|
||||
// control dependency).
|
||||
void CreateGrapplerItemWithBatchNorm() {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||
auto x = tensorflow::ops::RandomUniform(
|
||||
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||
auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
|
||||
{depth_in_}, DT_FLOAT);
|
||||
auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
|
||||
{depth_in_}, DT_FLOAT);
|
||||
auto mean =
|
||||
tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
|
||||
auto var =
|
||||
tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
|
||||
|
||||
auto batch_norm = tensorflow::ops::FusedBatchNorm(
|
||||
s.WithOpName("bn"), x, scale, offset, mean, var,
|
||||
ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
|
||||
auto y = batch_norm.y;
|
||||
auto batch_mean = batch_norm.batch_mean;
|
||||
auto batch_var = batch_norm.batch_variance;
|
||||
|
||||
auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y);
|
||||
auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var);
|
||||
auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var);
|
||||
std::vector<tensorflow::Operation> input_tensors = {
|
||||
batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(),
|
||||
};
|
||||
auto z4 = tensorflow::ops::NoOp(
|
||||
s.WithControlDependencies(batch_var).WithOpName("z4"));
|
||||
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||
|
||||
grappler_item_.reset(new GrapplerItem);
|
||||
grappler_item_->id = "test_complex_dependency_graph";
|
||||
grappler_item_->graph = def;
|
||||
grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
|
||||
|
||||
dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
|
||||
dependency_["z1"] = {"x", "bn"};
|
||||
dependency_["z2"] = {"bn"};
|
||||
dependency_["z3"] = {"bn"};
|
||||
dependency_["z4"] = {"bn"};
|
||||
}
|
||||
|
||||
// Call this after creating grappler_item_ and setting up dependency_.
|
||||
void InitScheduler() {
|
||||
scheduler_.reset(new TestVirtualScheduler(
|
||||
grappler_item_.get(), true /* use_static_shapes */,
|
||||
"CPU" /* default_device_type */, cluster_.get(), placer_.get()));
|
||||
TF_CHECK_OK(scheduler_->Init());
|
||||
}
|
||||
|
||||
// Call this after init scheduler_. Scheduler stops after executing
|
||||
// target_node.
|
||||
std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) {
|
||||
Costs zero_costs = Costs::ZeroCosts();
|
||||
std::unordered_map<string, NodeInfo> ops_executed;
|
||||
bool more_nodes = true;
|
||||
do {
|
||||
NodeInfo node_info = scheduler_->GetCurrNodeInfo();
|
||||
ops_executed[node_info.name] = node_info;
|
||||
|
||||
// Check scheduling order.
|
||||
auto it = dependency_.find(node_info.name);
|
||||
if (it != dependency_.end()) {
|
||||
for (const auto& preceding_node : it->second) {
|
||||
EXPECT_GT(ops_executed.count(preceding_node), 0);
|
||||
}
|
||||
}
|
||||
more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs);
|
||||
|
||||
if (node_info.name == target_node) {
|
||||
// Scheduler has the state after executing the target node.
|
||||
break;
|
||||
}
|
||||
} while (more_nodes);
|
||||
return ops_executed;
|
||||
}
|
||||
|
||||
// Helper method for validating a vector.
|
||||
template <typename T>
|
||||
void ExpectVectorEq(const std::vector<T>& expected,
|
||||
const std::vector<T>& test_elements) {
|
||||
// Set of expected elements for an easy comparison.
|
||||
std::set<T> expected_set(expected.begin(), expected.end());
|
||||
for (const auto& element : test_elements) {
|
||||
EXPECT_GT(expected_set.count(element), 0);
|
||||
}
|
||||
EXPECT_EQ(expected.size(), test_elements.size());
|
||||
}
|
||||
|
||||
// Helper method that checks the name of nodes.
|
||||
void ValidateNodeDefs(const std::vector<string>& expected,
|
||||
const std::vector<const NodeDef*>& node_defs) {
|
||||
std::vector<string> node_names;
|
||||
std::transform(node_defs.begin(), node_defs.end(),
|
||||
std::back_inserter(node_names),
|
||||
[](const NodeDef* node) { return node->name(); });
|
||||
ExpectVectorEq(expected, node_names);
|
||||
}
|
||||
|
||||
// Helper method for validating a set.
|
||||
template <typename T>
|
||||
void ExpectSetEq(const std::set<T>& expected,
|
||||
const std::set<T>& test_elements) {
|
||||
for (const auto& element : test_elements) {
|
||||
EXPECT_GT(expected.count(element), 0);
|
||||
}
|
||||
EXPECT_EQ(expected.size(), test_elements.size());
|
||||
}
|
||||
|
||||
// Helper method tthat checks name - port pairs.
|
||||
void ValidateMemoryUsageSnapshot(
|
||||
const std::vector<string>& expected_names, const int port_num_expected,
|
||||
const std::set<std::pair<const NodeDef*, int>>& mem_usage_snapshot) {
|
||||
std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
|
||||
std::transform(
|
||||
mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
|
||||
std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
|
||||
[](const std::pair<const NodeDef*, int>& node_port) {
|
||||
return std::make_pair(node_port.first->name(), node_port.second);
|
||||
});
|
||||
std::set<std::pair<string, int>> expected;
|
||||
std::transform(expected_names.begin(), expected_names.end(),
|
||||
std::inserter(expected, expected.begin()),
|
||||
[port_num_expected](const string& name) {
|
||||
return std::make_pair(name, port_num_expected);
|
||||
});
|
||||
ExpectSetEq(expected, nodes_at_peak_mem_usage);
|
||||
}
|
||||
|
||||
// Helper method for converting shape vector to TensorProperty.
|
||||
OpInfo::TensorProperties ShapeToTensorProperty(
|
||||
const std::vector<int> shape, const DataType& data_type) const {
|
||||
OpInfo::TensorProperties tensor_property;
|
||||
tensor_property.set_dtype(data_type);
|
||||
for (const auto& x : shape) {
|
||||
tensor_property.mutable_shape()->add_dim()->set_size(x);
|
||||
}
|
||||
return tensor_property;
|
||||
}
|
||||
|
||||
// SetUp() inits cluster_ and placer_.
|
||||
std::unique_ptr<VirtualCluster> cluster_;
|
||||
std::unique_ptr<VirtualPlacer> placer_;
|
||||
|
||||
// grappler_item_ and scheduler_ will be initialized differently for each test
|
||||
// case
|
||||
// case.
|
||||
std::unique_ptr<GrapplerItem> grappler_item_;
|
||||
std::unique_ptr<VirtualScheduler> scheduler_;
|
||||
std::unique_ptr<TestVirtualScheduler> scheduler_;
|
||||
// Node name -> its preceding nodes map for testing scheduling order.
|
||||
std::unordered_map<string, std::vector<string>> dependency_;
|
||||
|
||||
// Shared params for Conv2D related graphs:
|
||||
const int batch_size_ = 4;
|
||||
const int width_ = 10;
|
||||
const int height_ = 10;
|
||||
const int depth_in_ = 8;
|
||||
const int kernel_ = 3;
|
||||
const int depth_out_ = 16;
|
||||
};
|
||||
|
||||
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
||||
CreateSchedulerWithConv2Ds(); // init scheduler_.
|
||||
// Init.
|
||||
CreateGrapplerItemWithConv2Ds();
|
||||
InitScheduler();
|
||||
|
||||
Costs zero_costs = Costs::ZeroCosts();
|
||||
std::unordered_map<string, NodeInfo> ops_executed;
|
||||
do {
|
||||
NodeInfo node_info = scheduler_->GetCurrNodeInfo();
|
||||
ops_executed[node_info.name] = node_info;
|
||||
|
||||
// Check scheduling order: x and f before c0, and y and f before c1.
|
||||
if (node_info.name == "c0") {
|
||||
EXPECT_GT(ops_executed.count("x"), 0);
|
||||
EXPECT_GT(ops_executed.count("f"), 0);
|
||||
} else if (node_info.name == "c1") {
|
||||
EXPECT_GT(ops_executed.count("y"), 0);
|
||||
EXPECT_GT(ops_executed.count("f"), 0);
|
||||
}
|
||||
} while (scheduler_->MarkCurrNodeExecuted(zero_costs));
|
||||
// Run the scheduler.
|
||||
auto ops_executed = RunScheduler(""); // Run all the nodes.
|
||||
|
||||
// [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
|
||||
// executed.
|
||||
|
|
@ -132,5 +345,162 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
|||
EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
|
||||
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithAddN();
|
||||
InitScheduler();
|
||||
|
||||
// Create a set of tensor properties.
|
||||
std::vector<OpInfo::TensorProperties> output;
|
||||
output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
|
||||
output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
|
||||
output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
|
||||
output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
|
||||
output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
|
||||
output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
|
||||
|
||||
// port_num -1 is for control dependency: hard coded 4B.
|
||||
EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
|
||||
|
||||
// Test valid outputs.
|
||||
EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
|
||||
EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
|
||||
EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
|
||||
EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
|
||||
|
||||
// Any uknown shape (-1) shall yield zero output size.
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
|
||||
|
||||
// Invalid port_num (though it may be an error) shall yield zero
|
||||
// output size.
|
||||
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, MemoryUsage) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithAddN();
|
||||
InitScheduler();
|
||||
|
||||
// Run the scheduler.
|
||||
RunScheduler("");
|
||||
|
||||
const auto& device_states = scheduler_->GetDeviceStates();
|
||||
const auto& cpu_state = device_states.at(kCPU0);
|
||||
|
||||
// out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
|
||||
// is 4 x the input tensor size while executing the out node.
|
||||
int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
|
||||
const std::vector<string> expected_names = {"x", "y", "z", "w"};
|
||||
EXPECT_EQ(expected_names.size() * one_input_node_size,
|
||||
cpu_state.max_memory_usage);
|
||||
ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
|
||||
cpu_state.mem_usage_snapshot_at_peak);
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, ControlDependency) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithControlDependency();
|
||||
InitScheduler();
|
||||
|
||||
// Run the scheduler.
|
||||
RunScheduler("");
|
||||
|
||||
const auto& device_states = scheduler_->GetDeviceStates();
|
||||
const auto& cpu_state = device_states.at(kCPU0);
|
||||
|
||||
// The graph has a NoOp that takes control dependency from 7 NoOps. The peak
|
||||
// memory usage is when executing the final NoOp.
|
||||
int64 one_input_node_size = 4; // control dependency
|
||||
const std::vector<string> expected_names = {"x", "y", "z", "w",
|
||||
"u", "v", "t"};
|
||||
EXPECT_EQ(expected_names.size() * one_input_node_size,
|
||||
cpu_state.max_memory_usage);
|
||||
ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
|
||||
cpu_state.mem_usage_snapshot_at_peak);
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, ComplexDependency) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithBatchNorm();
|
||||
InitScheduler();
|
||||
|
||||
// Run the scheduler.
|
||||
RunScheduler("bn");
|
||||
|
||||
const auto& device_states = scheduler_->GetDeviceStates();
|
||||
const auto& cpu_state = device_states.at(kCPU0);
|
||||
|
||||
// The graph is
|
||||
// bn = FusedBatchNorm(x, scale, offset, mean, var)
|
||||
// z1 = bn.y + x
|
||||
// z2 = bn.var + bn.var
|
||||
// z3 = bn.var + bn.var
|
||||
// z4 = control dependency from bn.
|
||||
// Note that bn.mean doesn't have any consumer.
|
||||
const int x_size = batch_size_ * width_ * height_ * depth_in_;
|
||||
int64 expected_size =
|
||||
4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
|
||||
1 /* control dependency */);
|
||||
EXPECT_EQ(expected_size, cpu_state.memory_usage);
|
||||
|
||||
// Nodes currrently in memory: bn's port -1, 0, and 2, and x's port 0.
|
||||
std::set<std::pair<string, int>> nodes_in_memory;
|
||||
std::transform(
|
||||
cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
|
||||
std::inserter(nodes_in_memory, nodes_in_memory.begin()),
|
||||
[](const std::pair<const NodeDef*, int>& node_port) {
|
||||
return std::make_pair(node_port.first->name(), node_port.second);
|
||||
});
|
||||
std::set<std::pair<string, int>> expected = {
|
||||
std::make_pair("bn", -1), std::make_pair("bn", 0),
|
||||
std::make_pair("bn", 2), std::make_pair("x", 0),
|
||||
};
|
||||
ExpectSetEq(expected, nodes_in_memory);
|
||||
|
||||
const auto& node_states = scheduler_->GetNodeStates();
|
||||
const NodeState* bn_node = nullptr;
|
||||
const NodeState* x_node = nullptr;
|
||||
for (const auto& nodedef_node_state : node_states) {
|
||||
const NodeDef* node = nodedef_node_state.first;
|
||||
const NodeState& node_state = nodedef_node_state.second;
|
||||
if (node->name() == "bn") {
|
||||
bn_node = &node_state;
|
||||
}
|
||||
if (node->name() == "x") {
|
||||
x_node = &node_state;
|
||||
}
|
||||
}
|
||||
CHECK_NOTNULL(bn_node);
|
||||
CHECK_NOTNULL(x_node);
|
||||
|
||||
ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
|
||||
ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
|
||||
ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
|
||||
// z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
|
||||
ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
|
||||
}
|
||||
|
||||
TEST_F(VirtualSchedulerTest, Variable) {
|
||||
// Init.
|
||||
CreateGrapplerItemWithConv2DAndVariable();
|
||||
InitScheduler();
|
||||
|
||||
// Run the scheduler.
|
||||
RunScheduler("");
|
||||
|
||||
const auto& device_states = scheduler_->GetDeviceStates();
|
||||
const auto& cpu_state = device_states.at(kCPU0);
|
||||
|
||||
// There is one Conv2D that takes x and f, but f is variable, so it should be
|
||||
// in persistent nodes.
|
||||
// f is variable.
|
||||
ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
|
||||
cpu_state.persistent_nodes);
|
||||
// Only x in peak memory usage snapshot.
|
||||
ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
|
||||
cpu_state.mem_usage_snapshot_at_peak);
|
||||
}
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -45,18 +45,33 @@ bool IsMerge(const NodeDef& node) {
|
|||
return op == "Merge";
|
||||
}
|
||||
|
||||
bool IsNoOp(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "NoOp";
|
||||
}
|
||||
|
||||
bool IsPlaceholder(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Placeholder" || op == "PlaceholderV2" ||
|
||||
op == "PlaceholderWithDefault";
|
||||
}
|
||||
|
||||
bool IsRecv(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "_Recv";
|
||||
}
|
||||
|
||||
bool IsReduction(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
|
||||
op == "Mean" || op == "Any" || op == "All";
|
||||
}
|
||||
|
||||
bool IsSend(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "_Send";
|
||||
}
|
||||
|
||||
bool IsSwitch(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Switch";
|
||||
|
|
|
|||
|
|
@ -26,8 +26,11 @@ bool IsConstant(const NodeDef& node);
|
|||
bool IsDequeueOp(const NodeDef& node);
|
||||
bool IsIdentity(const NodeDef& node);
|
||||
bool IsMerge(const NodeDef& node);
|
||||
bool IsNoOp(const NodeDef& node);
|
||||
bool IsPlaceholder(const NodeDef& node);
|
||||
bool IsRecv(const NodeDef& node);
|
||||
bool IsReduction(const NodeDef& node);
|
||||
bool IsSend(const NodeDef& node);
|
||||
bool IsSwitch(const NodeDef& node);
|
||||
bool IsTranspose(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user