Profile memory usage in VirtualScheduler and report peak memory usage.

To do so, NodeState now handles different output ports of a node (in case
a node has multiple outputs).

Also, VirtualScheduler code is cleaned up with more comments.

PiperOrigin-RevId: 158209068
A. Unique TensorFlower 2017-06-06 16:45:02 -07:00 committed by TensorFlower Gardener
parent 0ea0bf5aae
commit 8f89b654f4
7 changed files with 826 additions and 201 deletions
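The bookkeeping idea behind the change is small: each output port of a node occupies device memory from the moment the node executes until the last consumer of that port has executed, and the maximum of the running byte count is the reported peak. Below is a minimal standalone sketch of that accounting, not the committed code; all names and sizes are illustrative only.

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>

struct Node {
  std::string name;
  std::map<int, int64_t> port_size_bytes;  // Output size in bytes, per port.
  std::map<int, int> consumers_left;       // Remaining consumers, per port.
};

struct Device {
  int64_t memory_usage = 0;
  int64_t max_memory_usage = 0;
  std::set<std::pair<const Node*, int>> nodes_in_memory;

  // Called when `n` executes: its used output ports start occupying memory.
  void Produce(const Node* n) {
    for (const auto& port_size : n->port_size_bytes) {
      if (n->consumers_left.at(port_size.first) == 0) continue;  // Unused port.
      memory_usage += port_size.second;
      nodes_in_memory.insert({n, port_size.first});
    }
    if (memory_usage > max_memory_usage) max_memory_usage = memory_usage;
  }

  // Called once per consumer of n:port; frees the port when none remain.
  void Consume(Node* n, int port) {
    if (--n->consumers_left[port] == 0) {
      memory_usage -= n->port_size_bytes.at(port);
      nodes_in_memory.erase({n, port});
    }
  }
};

int main() {
  // x:0 and y:0 (400 B each) feed add; add:0 (400 B) is the fetched output.
  Node x{"x", {{0, 400}}, {{0, 1}}};
  Node y{"y", {{0, 400}}, {{0, 1}}};
  Node add{"add", {{0, 400}}, {{0, 1}}};
  Device dev;
  dev.Produce(&x);
  dev.Produce(&y);
  dev.Produce(&add);  // Peak: x:0, y:0, and add:0 are live together.
  dev.Consume(&x, 0);
  dev.Consume(&y, 0);
  std::cout << "peak = " << dev.max_memory_usage << " B\n";  // peak = 1200 B
}

Compiled with any C++11 compiler, this prints peak = 1200 B: the peak is reached while `add` runs, when both of its inputs and its own output are live, which is the same rule the scheduler applies per device below.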

diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD

@@ -176,6 +176,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:utils",
         "//tensorflow/core/grappler/costs:cost_estimator",
@@ -192,6 +193,10 @@ cc_test(
         "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:utils",
         "//tensorflow/core/grappler/clusters:virtual_cluster",
     ],
 )

diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc

@@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp";
 constexpr char kReshape[] = "Reshape";
 constexpr char kRecv[] = "_Recv";
 constexpr char kBatchMatMul[] = "BatchMatMul";
+constexpr char kVariable[] = "Variable";
+constexpr char kVariableV2[] = "VariableV2";
 
 OpLevelCostEstimator::OpLevelCostEstimator() {
   // Syntactic sugar to build and return a lambda that takes an OpInfo and
@@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
       {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
       {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
       {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
+      {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
+      {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
       {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
 }
 
@@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize(
   for (const auto& dim : input_shape.dim()) {
     input_size *= dim.size();
   }
-  return input_size * DataTypeSize(input.dtype());
+  return input_size * DataTypeSize(BaseType(input.dtype()));
 }
 
 int64 OpLevelCostEstimator::CalculateInputSize(
@@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
   for (const auto& output : op_features.outputs()) {
     DataType dt = output.dtype();
     const auto& original_output_shape = output.shape();
-    int64 output_size = DataTypeSize(dt);
+    int64 output_size = DataTypeSize(BaseType(dt));
     int num_dims = std::max(1, original_output_shape.dim_size());
     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
                                              found_unknown_shapes);
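A note on the BaseType() changes above: once Variable ops enter the cost model, tensors can carry reference dtypes (e.g., DT_FLOAT_REF for a Variable output), which DataTypeSize() alone does not resolve to an element size; BaseType() strips the reference first. A hedged sketch of the intended size computation, assuming the TensorFlow framework headers at this revision; TensorBytes is an illustrative helper, not part of the commit:

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"

namespace tensorflow {

// Illustrative helper: bytes occupied by one tensor described by
// OpInfo::TensorProperties, mirroring CalculateSingleInputSize() above.
int64 TensorBytes(const OpInfo::TensorProperties& t) {
  // BaseType() maps reference dtypes to their value dtypes, so a Variable
  // output (DT_FLOAT_REF) sizes as DT_FLOAT (4 bytes per element).
  int64 bytes = DataTypeSize(BaseType(t.dtype()));
  for (const auto& dim : t.shape().dim()) {
    bytes *= dim.size();
  }
  return bytes;
}

}  // namespace tensorflow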

diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc

@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/clusters/utils.h"
 #include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -55,10 +56,10 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
                                    const bool use_static_shapes,
                                    const string& default_device_type,
                                    Cluster* cluster, VirtualPlacer* placer)
-    : graph_properties_(*grappler_item),
-      graph_costs_(Costs::ZeroCosts()),
-      // TODO(dyoon): Use a better way than FIFO.
+    : // TODO(dyoon): Use a better way than FIFO.
       ready_nodes_(new FIFOManager()),
+      graph_costs_(Costs::ZeroCosts()),
+      graph_properties_(*grappler_item),
       cluster_(cluster),
       grappler_item_(grappler_item),
       use_static_shapes_(use_static_shapes),
@@ -68,6 +69,11 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
 }
 
 Status VirtualScheduler::Init() {
+  // Init() preprocesses the input grappler_item and graph_properties to
+  // extract necessary information for emulating tensorflow op scheduling and
+  // construct internal data structures (NodeState and DeviceState) for
+  // virtual scheduling.
+
   // Construct graph properties.
   Status status;
   if (use_static_shapes_) {
@@ -82,13 +88,12 @@ Status VirtualScheduler::Init() {
   const auto& graph = grappler_item_->graph;
   const auto& fetch_nodes = grappler_item_->fetch;
 
-  // First, get the nodes that would run to output fetch_nodes.
+  // Get the nodes that would run to output fetch_nodes.
   std::vector<const NodeDef*> nodes =
       ComputeTransitiveFanin(graph, fetch_nodes);
 
   // TODO(dyoon): this is a bit inefficient as name_to_node is already built in
   // ComputeTransitiveFanin().
-  //
   // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
   // from the fetch nodes are scheduled. So the scheduled nodes should be
   // exactly the same as those executed for real. One possible discrepancy could
@@ -98,61 +103,72 @@ Status VirtualScheduler::Init() {
     name_to_node[node->name()] = node;
   }
 
-  // Build node_map.
+  // Build node_map; for each node, create its NodeState and connect its inputs
+  // and outputs.
   for (const auto* curr_node : nodes) {
     auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
     const string curr_node_device = DeviceName(curr_node);
     for (const string& input_node_name : curr_node->input()) {
-      // Note that input_node_name may be in <node_name>:<output_number> format,
-      // where ":<output_number>" may be omitted. NodeName() extracts only the
-      // node_name (prefeix "^", if there was for control input, is also
-      // deleted).
+      // Note that input_node_name may be in <prefix><node_name>:<port_num>
+      // format, where <prefix> (e.g., "^" for control dependency) and
+      // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
       const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
       CHECK(input_node);
-      // Add input_to_curr_node to curr_node's input, and
-      // add output_to_input_node to input_source_node's output.
-      // Default values for when input_node and curr_node on the same device.
-      const NodeDef* input_to_curr_node = input_node;
-      const NodeDef* input_source_node = input_node;
-      const NodeDef* output_to_input_node = curr_node;
       const string in_device = DeviceName(input_node);
-      if (curr_node_device != in_device) {
-        if (cached_ops_.count(input_node) > 0 &&
-            cached_ops_[input_node].count(curr_node_device) > 0) {
-          // Different device, but found an already-transferred copy; connect
-          // the cached node to curr_node.
-          input_to_curr_node = cached_ops_[input_node][curr_node_device];
-          input_source_node = input_to_curr_node;
-          output_to_input_node = curr_node;
+      const auto input_node_port_num = NodePosition(input_node_name);
+
+      if (curr_node_device == in_device) {
+        // Same device: connect input_node and curr_node directly.
+        curr_node_state.inputs.push_back(
+            std::make_pair(input_node, input_node_port_num));
+        auto& input_node_state = GetNodeStateOrCreateIt(input_node);
+        input_node_state.outputs[input_node_port_num].push_back(curr_node);
+      } else {
+        if (cached_recv_nodes_.count(input_node) > 0 &&
+            cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
+          // Different device, but found an already-cached copy (a _Recv op);
+          // connect the _Recv to curr_node.
+          const auto* recv_op =
+              cached_recv_nodes_[input_node][curr_node_device];
+          // recv_op's output port is hard-coded to zero.
+          curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
+          auto& input_node_state = node_map_.at(recv_op);
+          input_node_state.outputs[0].push_back(curr_node);
         } else {
           // Different device, no cached copy; transfer input_node to the
           // curr_node's device.
-          auto sendrecv_and_identity =
-              TransferNode(input_node, curr_node, input_node_name);
-          const auto* sendrecv = sendrecv_and_identity.first;
-          const auto* identity = sendrecv_and_identity.second;
-          input_to_curr_node = identity;
-          input_source_node = input_node;
-          output_to_input_node = sendrecv;
+          auto send_and_recv =
+              CreateSendRecv(input_node, curr_node, input_node_name);
+          // Note that CreateSendRecv() already connected input/output between
+          // _Send and _Recv ops.
+          const auto* send = send_and_recv.first;
+          const auto* recv = send_and_recv.second;
+          // recv_op's output port is hard-coded to zero.
+          curr_node_state.inputs.push_back(std::make_pair(recv, 0));
+          auto& input_node_state = GetNodeStateOrCreateIt(input_node);
+          input_node_state.outputs[input_node_port_num].push_back(send);
 
-          // Cache the identity op for future use.
-          cached_ops_[input_node][curr_node_device] = identity;
+          // Cache the _Recv op for future use.
+          cached_recv_nodes_[input_node][curr_node_device] = recv;
         }
       }
-      curr_node_state.inputs.push_back(input_to_curr_node);
-      // Note that we do not care output number (in case a tf op has multiple
-      // outputs), as VirtualScheduler only cares which nodes become ready as
-      // a node is executed.
-      auto& input_node_state = GetNodeStateOrCreateIt(input_source_node);
-      input_node_state.outputs.push_back(output_to_input_node);
     }
 
     if (curr_node->input().empty()) {
-      curr_node_state.time_ready =
-          Costs::Duration();  // Node without input: ready at time 0.
+      // Node without input: ready at time 0.
+      curr_node_state.time_ready = Costs::Duration();
       ready_nodes_->AddNode(curr_node);
     }
+
+    if (IsPersistentNode(curr_node)) {
+      auto& device_state = device_[curr_node_device];
+      for (int port_num = 0;
+           port_num < curr_node_state.output_properties.size(); ++port_num) {
+        device_state.persistent_nodes.insert(
+            std::make_pair(curr_node, port_num));
+      }
+    }
   }
 
   if (ready_nodes_->Empty()) {
@@ -163,18 +179,26 @@ Status VirtualScheduler::Init() {
   return Status::OK();
 }
 
-void VirtualScheduler::MaybeUpdateInputProperties(
-    const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const {
-  if (IsSendOp(node) || IsRecvOp(node)) {
-    // _Send and _Recv ops are inserted from VirtualScheduler, so
+void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
+  CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
+  // This method is called when NodeState is created and adds input and output
+  // properties for a few exceptional cases that GraphProperties cannot provide
+  // input/output properties.
+  if (IsSend(*node) || IsRecv(*node)) {
+    auto& node_state = node_map_[node];
+    auto& inputs = node_state.input_properties;
+    auto& outputs = node_state.output_properties;
+    // _Send and _Recv ops are created from VirtualScheduler, so
     // there should be no inputs TensorProperties.
-    CHECK_EQ(inputs->size(), 0);
+    CHECK(inputs.empty());
+    CHECK(outputs.empty());
     const auto& attr = node->attr();
     // This is the original input source to the _Send and _Recv, and this
     // string includes "^" if it was control dependency, and output port
     // (e.g., ":2") if the input source had multiple outputs.
     const auto& input_source_name = attr.at(kAttrInputSrc).s();
-    if (input_source_name[0] == '^') {
+    if (IsControlInput(input_source_name)) {
       // Control dependency; regardless of the input source tensor size,
       // send 4B.
       OpInfo::TensorProperties control_message;
@@ -182,51 +206,53 @@ void VirtualScheduler::MaybeUpdateInputProperties(
       control_message.mutable_shape()->add_dim()->set_size(1);
       auto* value = control_message.mutable_value();
       value->add_float_val(1);
-      inputs->push_back(control_message);
+      inputs.push_back(control_message);
+      outputs.push_back(control_message);
     } else {
+      auto output_properties =
+          graph_properties_.GetOutputProperties(NodeName(input_source_name));
       // Like with HasInputProperties, if a node does not have output
       // properties, it's likely it was pruned during the shape inference run.
-      if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
-        const auto input_position = NodePosition(input_source_name);
+      if (!output_properties.empty()) {
+        const auto input_node_port_num = NodePosition(input_source_name);
         // Use the input source's output property as _Send and _Recv's input
         // property.
-        auto outputs =
-            graph_properties_.GetOutputProperties(NodeName(input_source_name));
-        CHECK_GT(outputs.size(), input_position);
-        inputs->push_back(outputs[input_position]);
+        CHECK_GT(output_properties.size(), input_node_port_num);
+        inputs.push_back(output_properties[input_node_port_num]);
+        outputs.push_back(output_properties[input_node_port_num]);
       }
     }
   }
 }
 
-bool VirtualScheduler::IsSendOp(const NodeDef* node) const {
-  return node->op() == kSend;
+float VirtualScheduler::Round2(const float x) const {
+  return std::round(100.0 * x) / 100.0;
 }
 
-bool VirtualScheduler::IsRecvOp(const NodeDef* node) const {
-  return node->op() == kRecv;
+bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
+  // Variables are persistent nodes.
+  return IsVariable(*node);
 }
 
 string VirtualScheduler::DeviceName(const NodeDef* node) const {
+  CHECK(!initialized_) << "DeviceName is called after Init().";
   // TODO(dyoon): integrate this part with VirtualPlacer.
-  if (IsSendOp(node)) {
-    const auto& node_state = node_map_.at(node);
-    const auto* from = node_state.inputs[0];
-    const auto* to = node_state.outputs[0];
-    return ChannelDeviceName(from, to);
-  } else {
-    return node->device().empty() ? "/" + default_device_type_ + ":0"
-                                  : node->device();
-  }
+  return node->device().empty() ? "/device:" + default_device_type_ + ":0"
+                                : node->device();
 }
 
 string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
                                            const NodeDef* to) const {
+  CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
   return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
 }
 
-std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
+std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
     const NodeDef* from, const NodeDef* to, const string& input_name) {
+  CHECK(!initialized_) << "CreateSendRecv is called after Init().";
   // Connect "from" node to "to" node with _Send and _Recv such that
   // from -> _Send -> _Recv -> to.
   // _Send is placed on "Channel" device, and _Recv is on the same device
@@ -238,11 +264,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
   // NodeDefs created here need not be correct: in terms of name,
   // input names, attrs, etc.
 
+  auto input_node_port_num = NodePosition(input_name);
+
   // _Send op.
   auto* send = new NodeDef();
   send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
                  DeviceName(to));
-  send->set_op(kSend);
+  send->set_op("_Send");
   send->add_input(from->name());
   send->set_device(ChannelDeviceName(from, to));
   auto& send_attr = *(send->mutable_attr());
@@ -253,19 +281,22 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
   // _Recv op.
   auto* recv = new NodeDef();
   recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
-  recv->set_op(kRecv);
+  recv->set_op("_Recv");
   recv->add_input(send->name());
   recv->set_device(DeviceName(to));
   auto& recv_attr = *(recv->mutable_attr());
   recv_attr[kAttrInputSrc].set_s(input_name);
 
-  // Update NodeState for _Send and _Recv ops.
+  // NodeState for _Send op.
   auto& send_node_state = GetNodeStateOrCreateIt(send);
-  send_node_state.inputs.push_back(from);
-  send_node_state.outputs.push_back(recv);
+  send_node_state.device_name = send->device();  // Set Channel device.
+  send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
+  send_node_state.outputs[0].push_back(recv);
+
+  // NodeState for _Recv op.
   auto& recv_node_state = GetNodeStateOrCreateIt(recv);
-  recv_node_state.inputs.push_back(send);
-  recv_node_state.outputs.push_back(to);
+  recv_node_state.inputs.push_back(std::make_pair(send, 0));
+  recv_node_state.outputs[0].push_back(to);
 
   // Keep the created nodes.
   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
@@ -277,13 +308,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
 
 NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
   const NodeDef* node = ready_nodes_->GetCurrNode();
-  std::vector<OpInfo::TensorProperties> inputs =
-      graph_properties_.GetInputProperties(node->name());
-  // Some ops created within VirtualScheduler may need further processing to
-  // the input properties.
-  MaybeUpdateInputProperties(node, &inputs);
 
-  // This is for compatibility; we can just use palcer_->get_device() for all
+  // This is for compatibility; we can just use placer_->get_device() for all
   // cases, once VirtualCluster is properly set up.
   DeviceProperties device;
   if (placer_) {
@@ -294,7 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
     int device_id;
     DeviceNameUtils::ParsedName parsed;
     if (!node->device().empty() &&
-        DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) {
+        DeviceNameUtils::ParseFullName(node_map_.at(node).device_name,
+                                       &parsed)) {
       device_type = parsed.type;
       device_id = parsed.id;
     } else {
@@ -309,81 +336,111 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
   }
 
   // Special case for _Send op.
-  if (IsSendOp(node)) {
+  if (IsSend(*node)) {
     device.set_type(kChannelDevice);
   }
 
+  // Construct NodeInfo.
+  const auto& node_state = node_map_.at(node);
   NodeInfo node_info;
   node_info.name = node->name();
-  node_info.device_name = graph_properties_.GetDeviceName(node->name());
-  std::vector<OpInfo::TensorProperties> outputs =
-      graph_properties_.GetOutputProperties(node->name());
+  node_info.device_name = node_state.device_name;
   auto& op_info = node_info.op_info;
   op_info.set_op(node->op());
   *op_info.mutable_attr() = node->attr();
-  for (auto& input : inputs) {
-    op_info.add_inputs()->Swap(&input);
+  for (auto& input : node_state.input_properties) {
+    *op_info.add_inputs() = input;
   }
-  for (auto& output : outputs) {
-    op_info.add_outputs()->Swap(&output);
+  for (auto& output : node_state.output_properties) {
+    *op_info.add_outputs() = output;
   }
   op_info.mutable_device()->Swap(&device);
-  // add some more to the node_info.
 
   return node_info;
 }
 NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
+  CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
   auto it = node_map_.find(node);
   if (it == node_map_.end()) {
+    // Not found; create a NodeState for this node.
     it = node_map_.emplace(node, NodeState()).first;
+    auto& node_state = it->second;
+    node_state.input_properties =
+        graph_properties_.GetInputProperties(node->name());
+    node_state.output_properties =
+        graph_properties_.GetOutputProperties(node->name());
+
+    // Some ops may need further processing to the input / output properties:
+    // _Send and _Recv.
+    MaybeUpdateInputOutput(node);
+
+    if (!IsSend(*node)) {
+      node_state.device_name = DeviceName(node);
+      // For _Send op, device_name will be set to Channel in CreateSendRecv().
+    }
+
+    // Initialize output port related data:
+    // Assume the size of OutputProperties represents the number of output
+    // ports of this node.
+    for (int i = 0; i < node_state.output_properties.size(); ++i) {
+      node_state.time_no_references[i] = Costs::Duration::max();
+      node_state.num_outputs_executed[i] = 0;
+      // Populate an empty vector for each port. The caller will add nodes
+      // that use this port as input.
+      node_state.outputs[i] = {};
+    }
+    // Port_num -1 is for control dependency.
+    node_state.time_no_references[-1] = Costs::Duration::max();
+    node_state.num_outputs_executed[-1] = 0;
+    node_state.outputs[-1] = {};
   }
   return it->second;
 }
 
+int64 VirtualScheduler::CalculateOutputSize(
+    const std::vector<OpInfo::TensorProperties>& output_properties,
+    const int port_num) const {
+  if (port_num < 0) {
+    return 4;  // 4B for control dependency.
+  }
+
+  if (port_num >= output_properties.size()) {
+    VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
+            << "port_num: " << port_num
+            << " >= output_properties.size(): " << output_properties.size();
+    return 0;
+  }
+
+  const auto& output = output_properties[port_num];
+  int64 output_size = DataTypeSize(BaseType(output.dtype()));
+
+  for (const auto& dim : output.shape().dim()) {
+    auto dim_size = dim.size();
+    if (dim_size < 0) {
+      // Zero output size if there's any unknown dim.
+      output_size = 0;
+      VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
+              << "unknown dim: " << output_size;
+      break;
+    }
+    output_size *= dim_size;
+  }
+
+  return output_size;
+}
+
 Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
                                           std::map<string, Costs>* op_cost) {
   auto it = op_cost->find(op_name);
   if (it == op_cost->end()) {
+    // Note that default constructor of Costs sets some memory related fields
+    // to unknown values so we should explicitly initialize it with ZeroCosts.
     it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
   }
   return it->second;
 }
 
-bool VirtualScheduler::PopCurrNode() {
-  const auto* node = ready_nodes_->GetCurrNode();
-  auto& node_state = node_map_[node];
-  auto& device = device_[DeviceName(node)];
-  auto curr_time = device.GetCurrTime();
-
-  // Increment num_inputs_ready of the output nodes.
-  for (auto* output : node_state.outputs) {
-    auto& output_state = node_map_[output];
-    output_state.num_inputs_ready++;
-    if (output_state.num_inputs_ready == output_state.inputs.size()) {
-      // This output node is now ready.
-      output_state.time_ready = curr_time;
-      ready_nodes_->AddNode(output);
-    }
-  }
-
-  // Increment num_outputs_executed of the input nodes.
-  for (auto* input : node_state.inputs) {
-    auto& input_state = node_map_[input];
-    input_state.num_outputs_executed++;
-    if (input_state.num_outputs_executed == input_state.outputs.size()) {
-      // All the outputs are executed; no reference to this input nodel
-      input_state.time_no_reference = curr_time;
-      // TODO(dyoon): collect device memory usage; note that this input node
-      // use device memory between time_scheduled and time_no_reference.
-    }
-  }
-
-  // Remove the current node; assume FIFO.
-  ready_nodes_->RemoveCurrNode();
-  return !ready_nodes_->Empty();
-}
-
 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
   // Update graph_costs_ and per-op costs.
   graph_costs_ = CombineCosts(graph_costs_, node_costs);
@@ -402,7 +459,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
 
   // Update node and device states.
   auto& node_state = node_map_[node];
-  auto& device = device_[DeviceName(node)];
+  auto& device = device_[node_state.device_name];
   device.nodes_executed.push_back(node);
 
   // Node is scheduled when the device is available AND all the inputs are
   // ready; hence, time_scheduled is time_ready if time_ready > device curr
@@ -415,6 +472,21 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
   auto curr_time = device.GetCurrTime();
   node_state.time_finished = curr_time;
 
+  // Update device memory usage.
+  if (!IsPersistentNode(node)) {
+    for (const auto& port_num_output_pair : node_state.outputs) {
+      int port_num = port_num_output_pair.first;
+      // There's a chance that a specific output is not used at all.
+      if (node_state.outputs[port_num].empty()) {
+        node_state.time_no_references[port_num] = curr_time;
+      } else {
+        device.memory_usage +=
+            CalculateOutputSize(node_state.output_properties, port_num);
+        device.nodes_in_memory.insert(std::make_pair(node, port_num));
+      }
+    }
+  }
+
   // Update device's per-op cost.
   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
   device_op_cost = CombineCosts(device_op_cost, node_costs);
@@ -425,7 +497,52 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
           << ", scheduled: " << node_state.time_scheduled.count()
           << ", finished: " << node_state.time_finished.count();
 
-  return PopCurrNode();
+  // Increment num_inputs_ready of the output nodes.
+  for (const auto& port_num_output_pair : node_state.outputs) {
+    for (auto* output_node : port_num_output_pair.second) {
+      auto& output_state = node_map_[output_node];
+      output_state.num_inputs_ready++;
+      if (output_state.num_inputs_ready == output_state.inputs.size()) {
+        // This output node is now ready.
+        output_state.time_ready = curr_time;
+        ready_nodes_->AddNode(output_node);
+      }
+    }
+  }
+
+  // Increment num_outputs_executed of the input nodes.
+  for (const auto& input_port : node_state.inputs) {
+    auto* input = input_port.first;
+    auto port = input_port.second;
+    auto& input_state = node_map_[input];
+    input_state.num_outputs_executed[port]++;
+    if (input_state.num_outputs_executed[port] ==
+            input_state.outputs[port].size() &&
+        !IsPersistentNode(input)) {
+      // All the outputs are executed; no reference to this output port of
+      // input node.
+      input_state.time_no_references[port] = curr_time;
+      auto& input_device = device_[input_state.device_name];
+      input_device.memory_usage -=
+          CalculateOutputSize(input_state.output_properties, port);
+      input_device.nodes_in_memory.erase(std::make_pair(input, port));
+    }
+  }
+
+  if (!IsPersistentNode(node)) {
+    // Now that output memory is added and used up nodes are deallocated,
+    // check max memory usage.
+    if (device.memory_usage > device.max_memory_usage) {
+      device.max_memory_usage = device.memory_usage;
+      device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
+    }
+  }
+
+  // Remove the current node; assume FIFO.
+  ready_nodes_->RemoveCurrNode();
+  return !ready_nodes_->Empty();
 }
 
 Costs VirtualScheduler::Summary() const {
 
@@ -452,17 +569,59 @@ Costs VirtualScheduler::Summary() const {
   for (const auto& device : device_) {
     const auto& name = device.first;
     const auto& state = device.second;
+
+    std::map<string, int64> op_to_memory;
+    // First profile only persistent memory usage.
+    int64 persistent_memory_usage = 0;
+    std::set<string> persistent_ops;
+    for (const auto& node_port : state.persistent_nodes) {
+      const auto* node = node_port.first;
+      const auto port = node_port.second;
+      const auto output_size =
+          CalculateOutputSize(node_map_.at(node).output_properties, port);
+      persistent_memory_usage += output_size;
+      op_to_memory[node->op()] += output_size;
+      persistent_ops.insert(node->op());
+    }
+    int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
+
     VLOG(1) << "Device = " << name
             << ", num_nodes = " << state.nodes_executed.size()
-            << ", execution_time = " << state.GetCurrTime().count();
-    VLOG(1) << "Per-op execution time:";
+            << ", execution_time = " << state.GetCurrTime().count()
+            << ", memory usage: "
+            << "persistent = "
+            << Round2(persistent_memory_usage / 1024.0 / 1024.0 / 1024.0)
+            << " GB, peak = "
+            << Round2(state.max_memory_usage / 1024.0 / 1024.0 / 1024.0)
+            << " GB, total = "
+            << Round2(max_memory_usage / 1024.0 / 1024.0 / 1024.0)
+            << " GB, at the end: " << state.memory_usage << " B";
+    VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
+
+    // Profile non-persistent op memory usage.
+    for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
+      const auto* node = node_port.first;
+      const auto port = node_port.second;
+      op_to_memory[node->op()] +=
+          CalculateOutputSize(node_map_.at(node).output_properties, port);
+    }
     for (const auto& op_cost_pair : state.op_to_cost) {
       const auto& op = op_cost_pair.first;
       const auto& cost = op_cost_pair.second.execution_time.count();
-      if (cost) {  // Skip printing out zero-cost ops.
-        VLOG(1) << " + " << op << " : " << cost;
+      const float mem_usage_gb =
+          Round2(op_to_memory[op] / 1024.0 / 1024.0 / 1024.0);
+      int64 op_mem_usage = op_to_memory.at(op);
+      const float mem_usage_percent =
+          max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
+                               : 0.0;
+      if (cost || mem_usage_percent > 1.0) {
+        // Print out only non-zero cost ops or ops with > 1% memory usage.
+        VLOG(1) << " + " << op << " : " << cost << " (" << mem_usage_gb
+                << " GB [" << mem_usage_percent << "%] "
+                << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
      }
    }
 
     if (critical_path_costs.execution_time <= state.GetCurrTime()) {
       critical_path_costs = state.device_costs;
     }

diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h

@@ -29,36 +29,79 @@ namespace tensorflow {
 namespace grappler {
 
 struct NodeState {
-  std::vector<const NodeDef*> inputs;
-  std::vector<const NodeDef*> outputs;
+  // A node (i.e., an op) takes a set of input:port pairs and produces
+  // a set of output ports.
+
+  // Cross references to input and output nodes from graphdef.
+  std::vector<std::pair<const NodeDef*, int>> inputs;  // Input, port pairs.
+  // List of output nodes (a list of nodes that takes this output port as
+  // input) keyed by port_num. Note that port_num -1 is used for control
+  // dependency.
+  std::unordered_map<int, std::vector<const NodeDef*>> outputs;
+
+  // Info from GraphProperties.
+  std::vector<OpInfo::TensorProperties> input_properties;
+  std::vector<OpInfo::TensorProperties> output_properties;
+
+  // Canonical device name used within VirtualScheduler.
+  string device_name;
+
+  // States updated as scheduling nodes.
   int num_inputs_ready;
-  int num_outputs_executed;
+  std::unordered_map<int, int> num_outputs_executed;
   Costs::Duration time_ready;
   Costs::Duration time_scheduled;
   Costs::Duration time_finished;
-  Costs::Duration time_no_reference;
+  // Time that all the consumers are executed (hence, no need to keep this
+  // output in memory), keyed by port_num.
+  std::unordered_map<int, Costs::Duration> time_no_references;
+
+  // Note that a node may have multiple output ports. The lengths of outputs,
+  // num_outputs_executed, and time_no_references should be identical when a
+  // NodeState is fully initialized. They should be
+  // 1 + output_properties.size() as we add [-1] for control dependency.
 
   // Node will be ready to be executed at time_ready, scheduled at
   // time_scheduled, and finishes execution at time_finished.
-  // Between time_scheduled and time_no_reference, the node's output tensor
-  // needs to be on the device, using up device memory.
+  // Each output port uses up memory space from time_scheduled to its
+  // time_no_references.
 
   NodeState() {
     num_inputs_ready = 0;
-    num_outputs_executed = 0;
     time_ready = Costs::Duration::max();
     time_scheduled = Costs::Duration::max();
     time_finished = Costs::Duration::max();
-    time_no_reference = Costs::Duration::max();
+    // Note that num_outputs_executed and time_no_references are not
+    // initialized here, since we don't know the size (i.e., # outputs for
+    // this node).
   }
 };
 
 struct DeviceState {
+  // Nodes executed on this device in execution order.
   std::vector<const NodeDef*> nodes_executed;
-  Costs device_costs;
-  std::map<string, Costs> op_to_cost;  // Per-op cost.
 
-  DeviceState() { device_costs = Costs::ZeroCosts(); }
+  // Nodes currently allocated in memory: set of NodeDef* and port_num pairs
+  // so that we can track which output of the node is in memory.
+  std::set<std::pair<const NodeDef*, int>> nodes_in_memory;
+
+  // Nodes allocated in memory persistently: e.g., Variables.
+  std::set<std::pair<const NodeDef*, int>> persistent_nodes;
+
+  // Snapshot of nodes_in_memory, when memory usage is at peak.
+  // Same as nodes_in_memory, it's a set of NodeDef* and port_num pairs.
+  std::set<std::pair<const NodeDef*, int>> mem_usage_snapshot_at_peak;
+
+  Costs device_costs;
+  std::map<string, Costs> op_to_cost;    // Per-op cost.
+  std::map<string, int64> op_to_memory;  // Per-op memory usage at peak usage.
+  int64 memory_usage;
+  int64 max_memory_usage;
+
+  DeviceState() {
+    device_costs = Costs::ZeroCosts();
+    memory_usage = 0;
+    max_memory_usage = 0;
+  }
 
   Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
 };
@@ -106,48 +149,74 @@ class VirtualScheduler {
                    const string& default_device_type, Cluster* cluster,
                    VirtualPlacer* placer);
 
+  // Initializes NodeState and DeviceState from grappler_item_ and
+  // graph_properties_.
   Status Init();
 
   NodeInfo GetCurrNodeInfo() const;
+
+  // Returns true if there is any node to be scheduled.
   bool MarkCurrNodeExecuted(const Costs& node_costs);
+
+  // Prints out summary of execution (timing, memory usage, etc.)
   Costs Summary() const;
 
+ protected:
+  // GetDeviceStates and GetNodeStates are currently for testing purpose only.
+  // Retrieves detailed scheduling results.
+  const std::unordered_map<string, DeviceState>& GetDeviceStates() const {
+    return device_;
+  }
+  const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const {
+    return node_map_;
+  }
+
+  // Returns the size of output at port_num (unit: bytes). A special case is
+  // port_num -1, which is for control dependency and assumed to be 4 bytes.
+  int64 CalculateOutputSize(
+      const std::vector<OpInfo::TensorProperties>& output_properties,
+      const int port_num) const;
+
  private:
-  const string kSend = "_Send";
-  const string kRecv = "_Recv";
+  // Constants.
   const string kAttrInputSrc = "input_source_";
   const string kAttrSrcDevice = "src_device_";
   const string kAttrDstDevice = "dst_device_";
   const string kChannelDevice = "Channel";
 
-  void MaybeUpdateInputProperties(
-      const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const;
+  // Methods called from Init(). Fails if initialized_ is set.
+  void MaybeUpdateInputOutput(const NodeDef* node);
   NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
-  std::pair<const NodeDef*, const NodeDef*> TransferNode(
+  std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
       const NodeDef* from, const NodeDef* to, const string& input_name);
   string DeviceName(const NodeDef* node) const;
   string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
 
+  // Helper methods.
   Costs& FindOrCreateZero(const string& op_name,
                           std::map<string, Costs>* op_cost);
+  float Round2(const float x) const;
+  bool IsPersistentNode(const NodeDef* node) const;
 
-  bool PopCurrNode();
-  bool IsSendOp(const NodeDef* node) const;
-  bool IsRecvOp(const NodeDef* node) const;
-
-  GraphProperties graph_properties_;
-  std::map<string, int> op_counts_;  // Op counts with key with input shape.
-  std::map<string, int> op_costs_;  // Individual op costs (with input shapes).
-  Costs graph_costs_;  // Graph cost.
-  std::map<string, Costs> op_to_cost_;  // Per-op cost.
+  // Scheduler states:
   std::unique_ptr<ReadyNodeManager> ready_nodes_;
   std::unordered_map<const NodeDef*, NodeState> node_map_;
   std::unordered_map<string, DeviceState> device_;
 
   // Pool of NodeDefs for SendRecv and Identity ops created.
   std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
-  // Cache of ops transferred to another device.
+  // Cache of nodes transferred to another device.
   std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
-      cached_ops_;
+      cached_recv_nodes_;
+
+  // Stats:
+  std::map<string, int> op_counts_;  // Op counts with key with input shape.
+  std::map<string, int> op_costs_;   // Individual op costs (with input shapes).
+  Costs graph_costs_;                // Graph cost.
+  std::map<string, Costs> op_to_cost_;  // Per-op cost.
+
+  // Auxiliary data structures for constructing NodeState and DeviceState.
+  GraphProperties graph_properties_;
+
   Cluster* cluster_;                   // Not owned.
   const GrapplerItem* grappler_item_;  // Not owned.
   bool use_static_shapes_;
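For orientation, a client drives this interface in a simple loop: Init(), then alternating GetCurrNodeInfo() and MarkCurrNodeExecuted() until the ready queue drains, then Summary(). Below is a hedged sketch of such a driver built on the API above (mirroring RunScheduler() in the test file that follows); `RunVirtualSchedule` and `estimate_node_costs` are illustrative names, not part of the commit.

#include <functional>

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

namespace tensorflow {
namespace grappler {

// Sketch of a client-side scheduling loop. `estimate_node_costs` stands in
// for any per-op cost model plugged in by the caller.
Costs RunVirtualSchedule(
    const GrapplerItem* item, Cluster* cluster, VirtualPlacer* placer,
    const std::function<Costs(const NodeInfo&)>& estimate_node_costs) {
  VirtualScheduler scheduler(item, /*use_static_shapes=*/true,
                             "CPU" /*default_device_type*/, cluster, placer);
  TF_CHECK_OK(scheduler.Init());
  bool more_nodes = true;
  while (more_nodes) {
    NodeInfo node_info = scheduler.GetCurrNodeInfo();
    // MarkCurrNodeExecuted() advances virtual time on the node's device,
    // updates per-port memory usage, and returns false when no node is left.
    more_nodes = scheduler.MarkCurrNodeExecuted(estimate_node_costs(node_info));
  }
  return scheduler.Summary();  // Also VLOGs per-device peak memory.
}

}  // namespace grappler
}  // namespace tensorflow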

diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc

@@ -23,42 +23,49 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+// Class for testing virtual scheduler.
+class TestVirtualScheduler : public VirtualScheduler {
+ public:
+  TestVirtualScheduler(const GrapplerItem* grappler_item,
+                       const bool use_static_shapes,
+                       const string& default_device_type, Cluster* cluster,
+                       VirtualPlacer* placer)
+      : VirtualScheduler(grappler_item, use_static_shapes, default_device_type,
+                         cluster, placer) {}
+
+  FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
+  FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
+  FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
+  FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
+  FRIEND_TEST(VirtualSchedulerTest, Variable);
+};
+
 class VirtualSchedulerTest : public ::testing::Test {
  protected:
+  const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
+
   void SetUp() override {
     // Initializes cluster_ and placer_.
     std::unordered_map<string, DeviceProperties> devices;
     DeviceProperties cpu_device;
     cpu_device.set_type("CPU");
-    devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
-    DeviceProperties gpu_device;
-    gpu_device.set_type("GPU");
-    devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
+    devices[kCPU0] = cpu_device;
     cluster_.reset(new VirtualCluster(devices));
     placer_.reset(new VirtualPlacer(cluster_.get()));
   }
 
-  void CreateSchedulerWithConv2Ds() {
-    // Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in
-    // fetch nodes.
-    const int bs = 4;
-    const int width = 10;
-    const int height = 10;
-    const int depth_in = 8;
-    const int kernel = 3;
-    const int depth_out = 16;
-    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  // Three Conv2Ds with only two in fetch nodes.
+  void CreateGrapplerItemWithConv2Ds() {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
     auto x = tensorflow::ops::RandomUniform(
-        s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT);
+        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
     auto y = tensorflow::ops::RandomUniform(
-        s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT);
+        s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
     auto z = tensorflow::ops::RandomUniform(
-        s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT);
+        s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
     auto f = tensorflow::ops::RandomUniform(
-        s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT);
+        s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
     std::vector<int> strides = {1, 1, 1, 1};
     auto c0 =
         tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
@@ -68,47 +75,253 @@ class VirtualSchedulerTest : public ::testing::Test {
         tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
     GraphDef def;
     TF_CHECK_OK(s.ToGraphDef(&def));
-    LOG(INFO) << def.DebugString();
 
     grappler_item_.reset(new GrapplerItem);
     grappler_item_->id = "test_conv2d_graph";
     grappler_item_->graph = def;
     grappler_item_->fetch = {"c0", "c1"};
 
-    scheduler_.reset(new VirtualScheduler(
+    dependency_["c0"] = {"x", "f"};
+    dependency_["c1"] = {"y", "f"};
+  }
+
+  // A Conv2D with a variable.
+  void CreateGrapplerItemWithConv2DAndVariable() {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+    auto x = tensorflow::ops::RandomUniform(
+        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
+    auto f = tensorflow::ops::Variable(
+        s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
+    std::vector<int> strides = {1, 1, 1, 1};
+    auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
+    GraphDef def;
+    TF_CHECK_OK(s.ToGraphDef(&def));
+
+    grappler_item_.reset(new GrapplerItem);
+    grappler_item_->id = "test_conv2d_var_graph";
+    grappler_item_->graph = def;
+    grappler_item_->fetch = {"y"};
+
+    dependency_["y"] = {"x", "f"};
+  }
+
+  // AddN that takes 4 tensors with 10x10x10x10.
+  void CreateGrapplerItemWithAddN() {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+    auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"),
+                                            {10, 10, 10, 10}, DT_FLOAT);
+    auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"),
+                                            {10, 10, 10, 10}, DT_FLOAT);
+    auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"),
+                                            {10, 10, 10, 10}, DT_FLOAT);
+    auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"),
+                                            {10, 10, 10, 10}, DT_FLOAT);
+    tensorflow::OutputList input_tensors = {x, y, z, w};
+    auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors);
+    GraphDef def;
+    TF_CHECK_OK(s.ToGraphDef(&def));
+
+    grappler_item_.reset(new GrapplerItem);
+    grappler_item_->id = "test_addn_graph";
+    grappler_item_->graph = def;
+    grappler_item_->fetch = {"out"};
+
+    dependency_["out"] = {"x", "y", "z", "w"};
+  }
+
+  // NoOp that takes 7 NoOps as control dependency.
+  void CreateGrapplerItemWithControlDependency() {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+    std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
+    std::vector<tensorflow::Operation> input_tensors;
+    for (const auto& input : input_noop_names) {
+      auto x = tensorflow::ops::NoOp(s.WithOpName(input));
+      input_tensors.push_back(x.operation);
+    }
+    auto out = tensorflow::ops::NoOp(
+        s.WithControlDependencies(input_tensors).WithOpName("out"));
+    GraphDef def;
+    TF_CHECK_OK(s.ToGraphDef(&def));
+
+    grappler_item_.reset(new GrapplerItem);
+    grappler_item_->id = "test_control_dependency_graph";
+    grappler_item_->graph = def;
+    grappler_item_->fetch = {"out"};
+
+    dependency_["out"] = input_noop_names;
+  }
+
+  // FusedBN [an op with multiple outputs] with multiple consumers (including
+  // control dependency).
+  void CreateGrapplerItemWithBatchNorm() {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+    auto x = tensorflow::ops::RandomUniform(
+        s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
+    auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
+                                                {depth_in_}, DT_FLOAT);
+    auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
+                                                 {depth_in_}, DT_FLOAT);
+    auto mean =
+        tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+    auto var =
+        tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
+
+    auto batch_norm = tensorflow::ops::FusedBatchNorm(
+        s.WithOpName("bn"), x, scale, offset, mean, var,
+        ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
+    auto y = batch_norm.y;
+    auto batch_mean = batch_norm.batch_mean;
+    auto batch_var = batch_norm.batch_variance;
+
+    auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y);
+    auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var);
+    auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var);
+    std::vector<tensorflow::Operation> input_tensors = {
+        batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(),
+    };
+    auto z4 = tensorflow::ops::NoOp(
+        s.WithControlDependencies(batch_var).WithOpName("z4"));
+    GraphDef def;
+    TF_CHECK_OK(s.ToGraphDef(&def));
+
+    grappler_item_.reset(new GrapplerItem);
+    grappler_item_->id = "test_complex_dependency_graph";
+    grappler_item_->graph = def;
+    grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
+
+    dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
+    dependency_["z1"] = {"x", "bn"};
+    dependency_["z2"] = {"bn"};
+    dependency_["z3"] = {"bn"};
+    dependency_["z4"] = {"bn"};
+  }
+
+  // Call this after creating grappler_item_ and setting up dependency_.
+  void InitScheduler() {
+    scheduler_.reset(new TestVirtualScheduler(
         grappler_item_.get(), true /* use_static_shapes */,
         "CPU" /* default_device_type */, cluster_.get(), placer_.get()));
     TF_CHECK_OK(scheduler_->Init());
   }
+
+  // Call this after InitScheduler(). The scheduler stops after executing
+  // target_node.
+  std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) {
+    Costs zero_costs = Costs::ZeroCosts();
+    std::unordered_map<string, NodeInfo> ops_executed;
+    bool more_nodes = true;
+    do {
+      NodeInfo node_info = scheduler_->GetCurrNodeInfo();
+      ops_executed[node_info.name] = node_info;
+
+      // Check scheduling order.
+      auto it = dependency_.find(node_info.name);
+      if (it != dependency_.end()) {
+        for (const auto& preceding_node : it->second) {
+          EXPECT_GT(ops_executed.count(preceding_node), 0);
+        }
+      }
+      more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs);
+
+      if (node_info.name == target_node) {
+        // Scheduler has the state after executing the target node.
+        break;
+      }
+    } while (more_nodes);
+    return ops_executed;
+  }
+
+  // Helper method for validating a vector.
+  template <typename T>
+  void ExpectVectorEq(const std::vector<T>& expected,
+                      const std::vector<T>& test_elements) {
+    // Set of expected elements for an easy comparison.
+    std::set<T> expected_set(expected.begin(), expected.end());
+    for (const auto& element : test_elements) {
+      EXPECT_GT(expected_set.count(element), 0);
+    }
+    EXPECT_EQ(expected.size(), test_elements.size());
+  }
+
+  // Helper method that checks the name of nodes.
+  void ValidateNodeDefs(const std::vector<string>& expected,
+                        const std::vector<const NodeDef*>& node_defs) {
+    std::vector<string> node_names;
+    std::transform(node_defs.begin(), node_defs.end(),
+                   std::back_inserter(node_names),
+                   [](const NodeDef* node) { return node->name(); });
+    ExpectVectorEq(expected, node_names);
+  }
+
+  // Helper method for validating a set.
+  template <typename T>
+  void ExpectSetEq(const std::set<T>& expected,
+                   const std::set<T>& test_elements) {
+    for (const auto& element : test_elements) {
+      EXPECT_GT(expected.count(element), 0);
+    }
+    EXPECT_EQ(expected.size(), test_elements.size());
+  }
+
+  // Helper method that checks name - port pairs.
+  void ValidateMemoryUsageSnapshot(
+      const std::vector<string>& expected_names, const int port_num_expected,
+      const std::set<std::pair<const NodeDef*, int>>& mem_usage_snapshot) {
+    std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
+    std::transform(
+        mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
+        std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
+        [](const std::pair<const NodeDef*, int>& node_port) {
+          return std::make_pair(node_port.first->name(), node_port.second);
+        });
+    std::set<std::pair<string, int>> expected;
+    std::transform(expected_names.begin(), expected_names.end(),
+                   std::inserter(expected, expected.begin()),
+                   [port_num_expected](const string& name) {
+                     return std::make_pair(name, port_num_expected);
+                   });
+    ExpectSetEq(expected, nodes_at_peak_mem_usage);
+  }
+
+  // Helper method for converting shape vector to TensorProperty.
+  OpInfo::TensorProperties ShapeToTensorProperty(
+      const std::vector<int> shape, const DataType& data_type) const {
+    OpInfo::TensorProperties tensor_property;
+    tensor_property.set_dtype(data_type);
+    for (const auto& x : shape) {
+      tensor_property.mutable_shape()->add_dim()->set_size(x);
+    }
+    return tensor_property;
+  }
 
   // SetUp() inits cluster_ and placer_.
   std::unique_ptr<VirtualCluster> cluster_;
   std::unique_ptr<VirtualPlacer> placer_;
   // grappler_item_ and scheduler_ will be initialized differently for each test
-  // case
+  // case.
   std::unique_ptr<GrapplerItem> grappler_item_;
-  std::unique_ptr<VirtualScheduler> scheduler_;
+  std::unique_ptr<TestVirtualScheduler> scheduler_;
+  // Node name -> its preceding nodes map for testing scheduling order.
+  std::unordered_map<string, std::vector<string>> dependency_;
+
+  // Shared params for Conv2D related graphs:
+  const int batch_size_ = 4;
+  const int width_ = 10;
+  const int height_ = 10;
+  const int depth_in_ = 8;
+  const int kernel_ = 3;
+  const int depth_out_ = 16;
 };
 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
-  CreateSchedulerWithConv2Ds();  // init scheduler_.
+  // Init.
+  CreateGrapplerItemWithConv2Ds();
+  InitScheduler();
 
-  Costs zero_costs = Costs::ZeroCosts();
-  std::unordered_map<string, NodeInfo> ops_executed;
-  do {
-    NodeInfo node_info = scheduler_->GetCurrNodeInfo();
-    ops_executed[node_info.name] = node_info;
-
-    // Check scheduling order: x and f before c0, and y and f before c1.
-    if (node_info.name == "c0") {
-      EXPECT_GT(ops_executed.count("x"), 0);
-      EXPECT_GT(ops_executed.count("f"), 0);
-    } else if (node_info.name == "c1") {
-      EXPECT_GT(ops_executed.count("y"), 0);
-      EXPECT_GT(ops_executed.count("f"), 0);
-    }
-  } while (scheduler_->MarkCurrNodeExecuted(zero_costs));
+  // Run the scheduler.
+  auto ops_executed = RunScheduler("");  // Run all the nodes.
 
   // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
   // executed.
@@ -132,5 +345,162 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
   EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
   EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
 }
TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
// Init.
CreateGrapplerItemWithAddN();
InitScheduler();
// Create a set of tensor properties.
std::vector<OpInfo::TensorProperties> output;
output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT));  // 5
// port_num -1 is for control dependency: hard-coded 4B.
EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
// Test valid outputs.
EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
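// (Element sizes above: 4 bytes per DT_FLOAT element, 2 bytes per DT_HALF
// element.)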
// Any unknown shape (-1) yields zero output size.
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
// An out-of-range port_num (arguably an error) also yields zero output size.
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
}
TEST_F(VirtualSchedulerTest, MemoryUsage) {
// Init.
CreateGrapplerItemWithAddN();
InitScheduler();
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
// The out node adds 4 tensors, each of shape 10x10x10x10, so the peak memory
// usage is 4x the input tensor size while executing the out node.
int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
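// = 4 bytes per element * 10^4 elements = 40000 bytes per input tensor; the
// expected peak is then 4 inputs * 40000 bytes = 160000 bytes.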
const std::vector<string> expected_names = {"x", "y", "z", "w"};
EXPECT_EQ(expected_names.size() * one_input_node_size,
cpu_state.max_memory_usage);
ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
cpu_state.mem_usage_snapshot_at_peak);
}
TEST_F(VirtualSchedulerTest, ControlDependency) {
// Init.
CreateGrapplerItemWithControlDependency();
InitScheduler();
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
// The graph has a NoOp that takes control dependencies from 7 NoOps. The
// peak memory usage occurs while executing the final NoOp.
int64 one_input_node_size = 4; // control dependency
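// With 7 control inputs, the expected peak is 7 * 4 = 28 bytes.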
const std::vector<string> expected_names = {"x", "y", "z", "w",
"u", "v", "t"};
EXPECT_EQ(expected_names.size() * one_input_node_size,
cpu_state.max_memory_usage);
ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
cpu_state.mem_usage_snapshot_at_peak);
}
TEST_F(VirtualSchedulerTest, ComplexDependency) {
// Init.
CreateGrapplerItemWithBatchNorm();
InitScheduler();
// Run the scheduler.
RunScheduler("bn");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
// The graph is
// bn = FusedBatchNorm(x, scale, offset, mean, var)
// z1 = bn.y + x
// z2 = bn.var + bn.var
// z3 = bn.var + bn.var
// z4 = control dependency from bn.
// Note that bn.mean doesn't have any consumer.
const int x_size = batch_size_ * width_ * height_ * depth_in_;
int64 expected_size =
4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
1 /* control dependency */);
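// With batch_size_ = 4, width_ = height_ = 10, and depth_in_ = 8, x_size is
// 3200 elements, so expected_size = 4 * (2 * 3200 + 8 + 1) = 25636 bytes.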
EXPECT_EQ(expected_size, cpu_state.memory_usage);
// Nodes currently in memory: bn's ports -1, 0, and 2, and x's port 0.
std::set<std::pair<string, int>> nodes_in_memory;
std::transform(
cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
std::inserter(nodes_in_memory, nodes_in_memory.begin()),
[](const std::pair<const NodeDef*, int>& node_port) {
return std::make_pair(node_port.first->name(), node_port.second);
});
std::set<std::pair<string, int>> expected = {
std::make_pair("bn", -1), std::make_pair("bn", 0),
std::make_pair("bn", 2), std::make_pair("x", 0),
};
ExpectSetEq(expected, nodes_in_memory);
const auto& node_states = scheduler_->GetNodeStates();
const NodeState* bn_node = nullptr;
const NodeState* x_node = nullptr;
for (const auto& nodedef_node_state : node_states) {
const NodeDef* node = nodedef_node_state.first;
const NodeState& node_state = nodedef_node_state.second;
if (node->name() == "bn") {
bn_node = &node_state;
}
if (node->name() == "x") {
x_node = &node_state;
}
}
CHECK_NOTNULL(bn_node);
CHECK_NOTNULL(x_node);
ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
// z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
}
TEST_F(VirtualSchedulerTest, Variable) {
// Init.
CreateGrapplerItemWithConv2DAndVariable();
InitScheduler();
// Run the scheduler.
RunScheduler("");
const auto& device_states = scheduler_->GetDeviceStates();
const auto& cpu_state = device_states.at(kCPU0);
// There is one Conv2D that takes x and f, but f is a Variable, so it should
// be in persistent nodes.
ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
cpu_state.persistent_nodes);
// Only x is in the peak memory usage snapshot.
ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
cpu_state.mem_usage_snapshot_at_peak);
}
}  // end namespace grappler
}  // end namespace tensorflow

View File

@@ -45,18 +45,33 @@ bool IsMerge(const NodeDef& node) {
return op == "Merge"; return op == "Merge";
} }
bool IsNoOp(const NodeDef& node) {
const auto op = node.op();
return op == "NoOp";
}
bool IsPlaceholder(const NodeDef& node) {
  const auto op = node.op();
  return op == "Placeholder" || op == "PlaceholderV2" ||
         op == "PlaceholderWithDefault";
}
bool IsRecv(const NodeDef& node) {
const auto op = node.op();
return op == "_Recv";
}
bool IsReduction(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
         op == "Mean" || op == "Any" || op == "All";
}
bool IsSend(const NodeDef& node) {
const auto op = node.op();
return op == "_Send";
}
bool IsSwitch(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Switch";

View File

@@ -26,8 +26,11 @@ bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsIdentity(const NodeDef& node);
bool IsMerge(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsRecv(const NodeDef& node);
bool IsReduction(const NodeDef& node);
bool IsSend(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsVariable(const NodeDef& node);
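As a rough sketch (not part of this commit), these predicates compose naturally;
the helper below is hypothetical, with ProducesNoTrackedOutput an assumed name
rather than an existing API:

// Hypothetical composition of the predicates above: ops that a cost model
// might treat as producing no tensors worth charging against device memory.
bool ProducesNoTrackedOutput(const NodeDef& node) {
  return IsNoOp(node) || IsSend(node) || IsRecv(node);
}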