mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Profile memory usage in VirtualScheduler and report peak memory usage.
To do so, NodeState now handles different output ports of a node (in case a node has multiple outputs). Also, VirtualScheduler code is cleaned up with more comments. PiperOrigin-RevId: 158209068
This commit is contained in:
parent
0ea0bf5aae
commit
8f89b654f4
|
|
@ -176,6 +176,7 @@ cc_library(
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/clusters:utils",
|
"//tensorflow/core/grappler/clusters:utils",
|
||||||
"//tensorflow/core/grappler/costs:cost_estimator",
|
"//tensorflow/core/grappler/costs:cost_estimator",
|
||||||
|
|
@ -192,6 +193,10 @@ cc_test(
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:op_types",
|
||||||
|
"//tensorflow/core/grappler:utils",
|
||||||
|
"//tensorflow/core/grappler/clusters:utils",
|
||||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp";
|
||||||
constexpr char kReshape[] = "Reshape";
|
constexpr char kReshape[] = "Reshape";
|
||||||
constexpr char kRecv[] = "_Recv";
|
constexpr char kRecv[] = "_Recv";
|
||||||
constexpr char kBatchMatMul[] = "BatchMatMul";
|
constexpr char kBatchMatMul[] = "BatchMatMul";
|
||||||
|
constexpr char kVariable[] = "Variable";
|
||||||
|
constexpr char kVariableV2[] = "VariableV2";
|
||||||
|
|
||||||
OpLevelCostEstimator::OpLevelCostEstimator() {
|
OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||||
// Syntactic sugar to build and return a lambda that takes an OpInfo and
|
// Syntactic sugar to build and return a lambda that takes an OpInfo and
|
||||||
|
|
@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||||
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||||
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||||
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||||
|
{kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||||
|
{kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
|
||||||
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
|
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize(
|
||||||
for (const auto& dim : input_shape.dim()) {
|
for (const auto& dim : input_shape.dim()) {
|
||||||
input_size *= dim.size();
|
input_size *= dim.size();
|
||||||
}
|
}
|
||||||
return input_size * DataTypeSize(input.dtype());
|
return input_size * DataTypeSize(BaseType(input.dtype()));
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 OpLevelCostEstimator::CalculateInputSize(
|
int64 OpLevelCostEstimator::CalculateInputSize(
|
||||||
|
|
@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
|
||||||
for (const auto& output : op_features.outputs()) {
|
for (const auto& output : op_features.outputs()) {
|
||||||
DataType dt = output.dtype();
|
DataType dt = output.dtype();
|
||||||
const auto& original_output_shape = output.shape();
|
const auto& original_output_shape = output.shape();
|
||||||
int64 output_size = DataTypeSize(dt);
|
int64 output_size = DataTypeSize(BaseType(dt));
|
||||||
int num_dims = std::max(1, original_output_shape.dim_size());
|
int num_dims = std::max(1, original_output_shape.dim_size());
|
||||||
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
|
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
|
||||||
found_unknown_shapes);
|
found_unknown_shapes);
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||||
#include "tensorflow/core/grappler/costs/utils.h"
|
#include "tensorflow/core/grappler/costs/utils.h"
|
||||||
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
|
|
@ -55,10 +56,10 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
|
||||||
const bool use_static_shapes,
|
const bool use_static_shapes,
|
||||||
const string& default_device_type,
|
const string& default_device_type,
|
||||||
Cluster* cluster, VirtualPlacer* placer)
|
Cluster* cluster, VirtualPlacer* placer)
|
||||||
: graph_properties_(*grappler_item),
|
: // TODO(dyoon): Use a better way than FIFO.
|
||||||
graph_costs_(Costs::ZeroCosts()),
|
|
||||||
// TODO(dyoon): Use a better way than FIFO.
|
|
||||||
ready_nodes_(new FIFOManager()),
|
ready_nodes_(new FIFOManager()),
|
||||||
|
graph_costs_(Costs::ZeroCosts()),
|
||||||
|
graph_properties_(*grappler_item),
|
||||||
cluster_(cluster),
|
cluster_(cluster),
|
||||||
grappler_item_(grappler_item),
|
grappler_item_(grappler_item),
|
||||||
use_static_shapes_(use_static_shapes),
|
use_static_shapes_(use_static_shapes),
|
||||||
|
|
@ -68,6 +69,11 @@ VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
|
||||||
}
|
}
|
||||||
|
|
||||||
Status VirtualScheduler::Init() {
|
Status VirtualScheduler::Init() {
|
||||||
|
// Init() preprocesses the input grappler_item and graph_properties to extract
|
||||||
|
// necessary information for emulating tensorflow op scheduling and
|
||||||
|
// construct internal data structures (NodeState and DeviceState) for virtual
|
||||||
|
// scheduling.
|
||||||
|
|
||||||
// Construct graph properties.
|
// Construct graph properties.
|
||||||
Status status;
|
Status status;
|
||||||
if (use_static_shapes_) {
|
if (use_static_shapes_) {
|
||||||
|
|
@ -82,13 +88,12 @@ Status VirtualScheduler::Init() {
|
||||||
const auto& graph = grappler_item_->graph;
|
const auto& graph = grappler_item_->graph;
|
||||||
const auto& fetch_nodes = grappler_item_->fetch;
|
const auto& fetch_nodes = grappler_item_->fetch;
|
||||||
|
|
||||||
// First, get the nodes that would run to output fetch_nodes.
|
// Get the nodes that would run to output fetch_nodes.
|
||||||
std::vector<const NodeDef*> nodes =
|
std::vector<const NodeDef*> nodes =
|
||||||
ComputeTransitiveFanin(graph, fetch_nodes);
|
ComputeTransitiveFanin(graph, fetch_nodes);
|
||||||
|
|
||||||
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
|
// TODO(dyoon): this is a bit inefficient as name_to_node is already built in
|
||||||
// ComputeTransitiveFanin().
|
// ComputeTransitiveFanin().
|
||||||
//
|
|
||||||
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
|
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
|
||||||
// from the fetch nodes are scheduled. So the scheduled nodes should be
|
// from the fetch nodes are scheduled. So the scheduled nodes should be
|
||||||
// exactly the same as those executed for real. One possible discrepancy could
|
// exactly the same as those executed for real. One possible discrepancy could
|
||||||
|
|
@ -98,61 +103,72 @@ Status VirtualScheduler::Init() {
|
||||||
name_to_node[node->name()] = node;
|
name_to_node[node->name()] = node;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build node_map.
|
// Build node_map; for each node, create its NodeState and connect its inputs
|
||||||
|
// and outputs.
|
||||||
for (const auto* curr_node : nodes) {
|
for (const auto* curr_node : nodes) {
|
||||||
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
|
auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
|
||||||
const string curr_node_device = DeviceName(curr_node);
|
const string curr_node_device = DeviceName(curr_node);
|
||||||
for (const string& input_node_name : curr_node->input()) {
|
for (const string& input_node_name : curr_node->input()) {
|
||||||
// Note that input_node_name may be in <node_name>:<output_number> format,
|
// Note that input_node_name may be in <prefix><node_name>:<port_num>
|
||||||
// where ":<output_number>" may be omitted. NodeName() extracts only the
|
// format, where <prefix> (e.g., "^" for control dependency) and
|
||||||
// node_name (prefeix "^", if there was for control input, is also
|
// ":<port_num>" may be omitted. NodeName() extracts only the node_name.
|
||||||
// deleted).
|
|
||||||
const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
|
const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
|
||||||
|
|
||||||
CHECK(input_node);
|
CHECK(input_node);
|
||||||
// Add input_to_curr_node to curr_node's input, and
|
|
||||||
// add output_to_input_node to input_source_node's output.
|
|
||||||
// Default values for when input_node and curr_node on the same device.
|
|
||||||
const NodeDef* input_to_curr_node = input_node;
|
|
||||||
const NodeDef* input_source_node = input_node;
|
|
||||||
const NodeDef* output_to_input_node = curr_node;
|
|
||||||
const string in_device = DeviceName(input_node);
|
const string in_device = DeviceName(input_node);
|
||||||
if (curr_node_device != in_device) {
|
const auto input_node_port_num = NodePosition(input_node_name);
|
||||||
if (cached_ops_.count(input_node) > 0 &&
|
|
||||||
cached_ops_[input_node].count(curr_node_device) > 0) {
|
if (curr_node_device == in_device) {
|
||||||
// Different device, but found an already-transferred copy; connect
|
// Same device: connect input_node and curr_node directly.
|
||||||
// the cached node to curr_node.
|
curr_node_state.inputs.push_back(
|
||||||
input_to_curr_node = cached_ops_[input_node][curr_node_device];
|
std::make_pair(input_node, input_node_port_num));
|
||||||
input_source_node = input_to_curr_node;
|
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
|
||||||
output_to_input_node = curr_node;
|
input_node_state.outputs[input_node_port_num].push_back(curr_node);
|
||||||
|
} else {
|
||||||
|
if (cached_recv_nodes_.count(input_node) > 0 &&
|
||||||
|
cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
|
||||||
|
// Different device, but found an already-cached copy (a _Recv op);
|
||||||
|
// connect the _Recv to curr_node.
|
||||||
|
const auto* recv_op =
|
||||||
|
cached_recv_nodes_[input_node][curr_node_device];
|
||||||
|
// recv_op's output port is hard-coded to zero.
|
||||||
|
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
|
||||||
|
auto& input_node_state = node_map_.at(recv_op);
|
||||||
|
input_node_state.outputs[0].push_back(curr_node);
|
||||||
} else {
|
} else {
|
||||||
// Different device, no cached copy; transfer input_node to the
|
// Different device, no cached copy; transfer input_node to the
|
||||||
// curr_node's device.
|
// curr_node's device.
|
||||||
auto sendrecv_and_identity =
|
auto send_and_recv =
|
||||||
TransferNode(input_node, curr_node, input_node_name);
|
CreateSendRecv(input_node, curr_node, input_node_name);
|
||||||
const auto* sendrecv = sendrecv_and_identity.first;
|
// Note that CreateSendRecv() already connected input/output between
|
||||||
const auto* identity = sendrecv_and_identity.second;
|
// _Send and _Recv ops.
|
||||||
input_to_curr_node = identity;
|
const auto* send = send_and_recv.first;
|
||||||
input_source_node = input_node;
|
const auto* recv = send_and_recv.second;
|
||||||
output_to_input_node = sendrecv;
|
// recv_op's output port is hard-coded to zero.
|
||||||
|
curr_node_state.inputs.push_back(std::make_pair(recv, 0));
|
||||||
|
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
|
||||||
|
input_node_state.outputs[input_node_port_num].push_back(send);
|
||||||
|
|
||||||
// Cache the identity op for future use.
|
// Cache the _Recv op for future use.
|
||||||
cached_ops_[input_node][curr_node_device] = identity;
|
cached_recv_nodes_[input_node][curr_node_device] = recv;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
curr_node_state.inputs.push_back(input_to_curr_node);
|
|
||||||
|
|
||||||
// Note that we do not care output number (in case a tf op has multiple
|
|
||||||
// outputs), as VirtualScheduler only cares which nodes become ready as
|
|
||||||
// a node is executed.
|
|
||||||
auto& input_node_state = GetNodeStateOrCreateIt(input_source_node);
|
|
||||||
input_node_state.outputs.push_back(output_to_input_node);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (curr_node->input().empty()) {
|
if (curr_node->input().empty()) {
|
||||||
curr_node_state.time_ready =
|
// Node without input: ready at time 0.
|
||||||
Costs::Duration(); // Node without input: ready at time 0.
|
curr_node_state.time_ready = Costs::Duration();
|
||||||
ready_nodes_->AddNode(curr_node);
|
ready_nodes_->AddNode(curr_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (IsPersistentNode(curr_node)) {
|
||||||
|
auto& device_state = device_[curr_node_device];
|
||||||
|
for (int port_num = 0;
|
||||||
|
port_num < curr_node_state.output_properties.size(); ++port_num) {
|
||||||
|
device_state.persistent_nodes.insert(
|
||||||
|
std::make_pair(curr_node, port_num));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ready_nodes_->Empty()) {
|
if (ready_nodes_->Empty()) {
|
||||||
|
|
@ -163,18 +179,26 @@ Status VirtualScheduler::Init() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void VirtualScheduler::MaybeUpdateInputProperties(
|
void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
|
||||||
const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const {
|
CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
|
||||||
if (IsSendOp(node) || IsRecvOp(node)) {
|
// This method is called when NodeState is created and adds input and output
|
||||||
// _Send and _Recv ops are inserted from VirtualScheduler, so
|
// properties for a few exceptional cases that GraphProperties cannot provide
|
||||||
|
// input/output properties.
|
||||||
|
if (IsSend(*node) || IsRecv(*node)) {
|
||||||
|
auto& node_state = node_map_[node];
|
||||||
|
auto& inputs = node_state.input_properties;
|
||||||
|
auto& outputs = node_state.output_properties;
|
||||||
|
|
||||||
|
// _Send and _Recv ops are created from VirtualScheduler, so
|
||||||
// there should be no inputs TensorProperties.
|
// there should be no inputs TensorProperties.
|
||||||
CHECK_EQ(inputs->size(), 0);
|
CHECK(inputs.empty());
|
||||||
|
CHECK(outputs.empty());
|
||||||
const auto& attr = node->attr();
|
const auto& attr = node->attr();
|
||||||
// This is the original input source to the _Send and _Recv, and this
|
// This is the original input source to the _Send and _Recv, and this
|
||||||
// string includes "^" if it was control dependency, and output port
|
// string includes "^" if it was control dependency, and output port
|
||||||
/// (e.g., ":2") if the input source had multiple outputs.
|
/// (e.g., ":2") if the input source had multiple outputs.
|
||||||
const auto& input_source_name = attr.at(kAttrInputSrc).s();
|
const auto& input_source_name = attr.at(kAttrInputSrc).s();
|
||||||
if (input_source_name[0] == '^') {
|
if (IsControlInput(input_source_name)) {
|
||||||
// Control dependency; regardless of the input source tensor size,
|
// Control dependency; regardless of the input source tensor size,
|
||||||
// send 4B.
|
// send 4B.
|
||||||
OpInfo::TensorProperties control_message;
|
OpInfo::TensorProperties control_message;
|
||||||
|
|
@ -182,51 +206,53 @@ void VirtualScheduler::MaybeUpdateInputProperties(
|
||||||
control_message.mutable_shape()->add_dim()->set_size(1);
|
control_message.mutable_shape()->add_dim()->set_size(1);
|
||||||
auto* value = control_message.mutable_value();
|
auto* value = control_message.mutable_value();
|
||||||
value->add_float_val(1);
|
value->add_float_val(1);
|
||||||
inputs->push_back(control_message);
|
inputs.push_back(control_message);
|
||||||
|
outputs.push_back(control_message);
|
||||||
} else {
|
} else {
|
||||||
|
auto output_properties =
|
||||||
|
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
||||||
// Like with HasInputProperties, if a node does not have output
|
// Like with HasInputProperties, if a node does not have output
|
||||||
// properties, it's likely it was pruned during the shape inference run.
|
// properties, it's likely it was pruned during the shape inference run.
|
||||||
if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
|
if (!output_properties.empty()) {
|
||||||
const auto input_position = NodePosition(input_source_name);
|
const auto input_node_port_num = NodePosition(input_source_name);
|
||||||
// Use the input source's output property as _Send and _Recv's input
|
// Use the input source's output property as _Send and _Recv's input
|
||||||
// property.
|
// property.
|
||||||
auto outputs =
|
CHECK_GT(output_properties.size(), input_node_port_num);
|
||||||
graph_properties_.GetOutputProperties(NodeName(input_source_name));
|
inputs.push_back(output_properties[input_node_port_num]);
|
||||||
CHECK_GT(outputs.size(), input_position);
|
outputs.push_back(output_properties[input_node_port_num]);
|
||||||
inputs->push_back(outputs[input_position]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VirtualScheduler::IsSendOp(const NodeDef* node) const {
|
float VirtualScheduler::Round2(const float x) const {
|
||||||
return node->op() == kSend;
|
return std::round(100.0 * x) / 100.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VirtualScheduler::IsRecvOp(const NodeDef* node) const {
|
bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
|
||||||
return node->op() == kRecv;
|
// Variables are persistent nodes.
|
||||||
|
return IsVariable(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
string VirtualScheduler::DeviceName(const NodeDef* node) const {
|
||||||
|
CHECK(!initialized_) << "DeviceName is called after Init().";
|
||||||
|
|
||||||
// TODO(dyoon): integrate this part with VirtualPlacer.
|
// TODO(dyoon): integrate this part with VirtualPlacer.
|
||||||
if (IsSendOp(node)) {
|
return node->device().empty() ? "/device:" + default_device_type_ + ":0"
|
||||||
const auto& node_state = node_map_.at(node);
|
: node->device();
|
||||||
const auto* from = node_state.inputs[0];
|
|
||||||
const auto* to = node_state.outputs[0];
|
|
||||||
return ChannelDeviceName(from, to);
|
|
||||||
} else {
|
|
||||||
return node->device().empty() ? "/" + default_device_type_ + ":0"
|
|
||||||
: node->device();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
|
||||||
const NodeDef* to) const {
|
const NodeDef* to) const {
|
||||||
|
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
|
||||||
|
|
||||||
return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
|
return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
|
||||||
const NodeDef* from, const NodeDef* to, const string& input_name) {
|
const NodeDef* from, const NodeDef* to, const string& input_name) {
|
||||||
|
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
|
||||||
|
|
||||||
// Connect "from" node to "to" node with _Send and _Recv such that
|
// Connect "from" node to "to" node with _Send and _Recv such that
|
||||||
// from -> _Send -> _Recv -> to.
|
// from -> _Send -> _Recv -> to.
|
||||||
// _Send is placed on "Channel" device, and _Recv is on the same device
|
// _Send is placed on "Channel" device, and _Recv is on the same device
|
||||||
|
|
@ -238,11 +264,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
||||||
// NodeDefs created here need not be correct: in terms of name,
|
// NodeDefs created here need not be correct: in terms of name,
|
||||||
// input names, attrs, etc.
|
// input names, attrs, etc.
|
||||||
|
|
||||||
|
auto input_node_port_num = NodePosition(input_name);
|
||||||
|
|
||||||
// _Send op.
|
// _Send op.
|
||||||
auto* send = new NodeDef();
|
auto* send = new NodeDef();
|
||||||
send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
|
send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
|
||||||
DeviceName(to));
|
DeviceName(to));
|
||||||
send->set_op(kSend);
|
send->set_op("_Send");
|
||||||
send->add_input(from->name());
|
send->add_input(from->name());
|
||||||
send->set_device(ChannelDeviceName(from, to));
|
send->set_device(ChannelDeviceName(from, to));
|
||||||
auto& send_attr = *(send->mutable_attr());
|
auto& send_attr = *(send->mutable_attr());
|
||||||
|
|
@ -253,19 +281,22 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
||||||
// _Recv op.
|
// _Recv op.
|
||||||
auto* recv = new NodeDef();
|
auto* recv = new NodeDef();
|
||||||
recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
|
recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
|
||||||
recv->set_op(kRecv);
|
recv->set_op("_Recv");
|
||||||
recv->add_input(send->name());
|
recv->add_input(send->name());
|
||||||
recv->set_device(DeviceName(to));
|
recv->set_device(DeviceName(to));
|
||||||
auto& recv_attr = *(recv->mutable_attr());
|
auto& recv_attr = *(recv->mutable_attr());
|
||||||
recv_attr[kAttrInputSrc].set_s(input_name);
|
recv_attr[kAttrInputSrc].set_s(input_name);
|
||||||
|
|
||||||
// Update NodeState for _Send and _Recv ops.
|
// NodeState for _Send op.
|
||||||
auto& send_node_state = GetNodeStateOrCreateIt(send);
|
auto& send_node_state = GetNodeStateOrCreateIt(send);
|
||||||
send_node_state.inputs.push_back(from);
|
send_node_state.device_name = send->device(); // Set Channel device.
|
||||||
send_node_state.outputs.push_back(recv);
|
send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
|
||||||
|
send_node_state.outputs[0].push_back(recv);
|
||||||
|
|
||||||
|
// NodeState for _Recv op.
|
||||||
auto& recv_node_state = GetNodeStateOrCreateIt(recv);
|
auto& recv_node_state = GetNodeStateOrCreateIt(recv);
|
||||||
recv_node_state.inputs.push_back(send);
|
recv_node_state.inputs.push_back(std::make_pair(send, 0));
|
||||||
recv_node_state.outputs.push_back(to);
|
recv_node_state.outputs[0].push_back(to);
|
||||||
|
|
||||||
// Keep the created nodes.
|
// Keep the created nodes.
|
||||||
additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
|
additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
|
||||||
|
|
@ -277,13 +308,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::TransferNode(
|
||||||
|
|
||||||
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
||||||
const NodeDef* node = ready_nodes_->GetCurrNode();
|
const NodeDef* node = ready_nodes_->GetCurrNode();
|
||||||
std::vector<OpInfo::TensorProperties> inputs =
|
|
||||||
graph_properties_.GetInputProperties(node->name());
|
|
||||||
// Some ops created within VirtualScheduler may need further processing to
|
|
||||||
// the input properties.
|
|
||||||
MaybeUpdateInputProperties(node, &inputs);
|
|
||||||
|
|
||||||
// This is for compatibility; we can just use palcer_->get_device() for all
|
// This is for compatibility; we can just use placer_->get_device() for all
|
||||||
// cases, once VirtualCluster is properly set up.
|
// cases, once VirtualCluster is properly set up.
|
||||||
DeviceProperties device;
|
DeviceProperties device;
|
||||||
if (placer_) {
|
if (placer_) {
|
||||||
|
|
@ -294,7 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
||||||
int device_id;
|
int device_id;
|
||||||
DeviceNameUtils::ParsedName parsed;
|
DeviceNameUtils::ParsedName parsed;
|
||||||
if (!node->device().empty() &&
|
if (!node->device().empty() &&
|
||||||
DeviceNameUtils::ParseFullName(DeviceName(node), &parsed)) {
|
DeviceNameUtils::ParseFullName(node_map_.at(node).device_name,
|
||||||
|
&parsed)) {
|
||||||
device_type = parsed.type;
|
device_type = parsed.type;
|
||||||
device_id = parsed.id;
|
device_id = parsed.id;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -309,81 +336,111 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for _Send op.
|
// Special case for _Send op.
|
||||||
if (IsSendOp(node)) {
|
if (IsSend(*node)) {
|
||||||
device.set_type(kChannelDevice);
|
device.set_type(kChannelDevice);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Construct NodeInfo.
|
||||||
|
const auto& node_state = node_map_.at(node);
|
||||||
NodeInfo node_info;
|
NodeInfo node_info;
|
||||||
node_info.name = node->name();
|
node_info.name = node->name();
|
||||||
node_info.device_name = graph_properties_.GetDeviceName(node->name());
|
node_info.device_name = node_state.device_name;
|
||||||
std::vector<OpInfo::TensorProperties> outputs =
|
|
||||||
graph_properties_.GetOutputProperties(node->name());
|
|
||||||
auto& op_info = node_info.op_info;
|
auto& op_info = node_info.op_info;
|
||||||
op_info.set_op(node->op());
|
op_info.set_op(node->op());
|
||||||
*op_info.mutable_attr() = node->attr();
|
*op_info.mutable_attr() = node->attr();
|
||||||
for (auto& input : inputs) {
|
for (auto& input : node_state.input_properties) {
|
||||||
op_info.add_inputs()->Swap(&input);
|
*op_info.add_inputs() = input;
|
||||||
}
|
}
|
||||||
for (auto& output : outputs) {
|
for (auto& output : node_state.output_properties) {
|
||||||
op_info.add_outputs()->Swap(&output);
|
*op_info.add_outputs() = output;
|
||||||
}
|
}
|
||||||
op_info.mutable_device()->Swap(&device);
|
op_info.mutable_device()->Swap(&device);
|
||||||
// add some more to the node_info.
|
|
||||||
return node_info;
|
return node_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
|
||||||
|
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
|
||||||
|
|
||||||
auto it = node_map_.find(node);
|
auto it = node_map_.find(node);
|
||||||
if (it == node_map_.end()) {
|
if (it == node_map_.end()) {
|
||||||
|
// Not found; create a NodeState for this node.
|
||||||
it = node_map_.emplace(node, NodeState()).first;
|
it = node_map_.emplace(node, NodeState()).first;
|
||||||
|
auto& node_state = it->second;
|
||||||
|
node_state.input_properties =
|
||||||
|
graph_properties_.GetInputProperties(node->name());
|
||||||
|
node_state.output_properties =
|
||||||
|
graph_properties_.GetOutputProperties(node->name());
|
||||||
|
|
||||||
|
// Some ops may need further processing to the input / output properties:
|
||||||
|
// _Send and _Recv.
|
||||||
|
MaybeUpdateInputOutput(node);
|
||||||
|
|
||||||
|
if (!IsSend(*node)) {
|
||||||
|
node_state.device_name = DeviceName(node);
|
||||||
|
// For _Send op, device_name will be set to Channel in CreateSendRecv().
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize output port related data:
|
||||||
|
// Assume the size of OutputProperties represents the number of output ports
|
||||||
|
// of this node.
|
||||||
|
for (int i = 0; i < node_state.output_properties.size(); ++i) {
|
||||||
|
node_state.time_no_references[i] = Costs::Duration::max();
|
||||||
|
node_state.num_outputs_executed[i] = 0;
|
||||||
|
// Populate an empty vector for each port. The caller will add nodes
|
||||||
|
// that use this port as input.
|
||||||
|
node_state.outputs[i] = {};
|
||||||
|
}
|
||||||
|
// Port_num -1 is for control dependency.
|
||||||
|
node_state.time_no_references[-1] = Costs::Duration::max();
|
||||||
|
node_state.num_outputs_executed[-1] = 0;
|
||||||
|
node_state.outputs[-1] = {};
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 VirtualScheduler::CalculateOutputSize(
|
||||||
|
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||||
|
const int port_num) const {
|
||||||
|
if (port_num < 0) {
|
||||||
|
return 4; // 4B for control dependency.
|
||||||
|
}
|
||||||
|
|
||||||
|
if (port_num >= output_properties.size()) {
|
||||||
|
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||||
|
<< "port_num: " << port_num
|
||||||
|
<< " >= output_properties.size(): " << output_properties.size();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& output = output_properties[port_num];
|
||||||
|
int64 output_size = DataTypeSize(BaseType(output.dtype()));
|
||||||
|
|
||||||
|
for (const auto& dim : output.shape().dim()) {
|
||||||
|
auto dim_size = dim.size();
|
||||||
|
if (dim_size < 0) {
|
||||||
|
// Zero output size if there's any unknown dim.
|
||||||
|
output_size = 0;
|
||||||
|
VLOG(3) << "VirtualScheduler::CalculateOutputSize() -- "
|
||||||
|
<< "unknown dim: " << output_size;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
output_size *= dim_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
return output_size;
|
||||||
|
}
|
||||||
|
|
||||||
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
|
Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
|
||||||
std::map<string, Costs>* op_cost) {
|
std::map<string, Costs>* op_cost) {
|
||||||
auto it = op_cost->find(op_name);
|
auto it = op_cost->find(op_name);
|
||||||
if (it == op_cost->end()) {
|
if (it == op_cost->end()) {
|
||||||
|
// Note that default constructor of Costs sets some memory related fields
|
||||||
|
// to unknown values so we should explicitly initialize it with ZeroCosts.
|
||||||
it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
|
it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VirtualScheduler::PopCurrNode() {
|
|
||||||
const auto* node = ready_nodes_->GetCurrNode();
|
|
||||||
auto& node_state = node_map_[node];
|
|
||||||
auto& device = device_[DeviceName(node)];
|
|
||||||
auto curr_time = device.GetCurrTime();
|
|
||||||
|
|
||||||
// Increment num_inputs_ready of the output nodes.
|
|
||||||
for (auto* output : node_state.outputs) {
|
|
||||||
auto& output_state = node_map_[output];
|
|
||||||
output_state.num_inputs_ready++;
|
|
||||||
if (output_state.num_inputs_ready == output_state.inputs.size()) {
|
|
||||||
// This output node is now ready.
|
|
||||||
output_state.time_ready = curr_time;
|
|
||||||
ready_nodes_->AddNode(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increment num_outputs_executed of the input nodes.
|
|
||||||
for (auto* input : node_state.inputs) {
|
|
||||||
auto& input_state = node_map_[input];
|
|
||||||
input_state.num_outputs_executed++;
|
|
||||||
if (input_state.num_outputs_executed == input_state.outputs.size()) {
|
|
||||||
// All the outputs are executed; no reference to this input nodel
|
|
||||||
input_state.time_no_reference = curr_time;
|
|
||||||
// TODO(dyoon): collect device memory usage; note that this input node
|
|
||||||
// use device memory between time_scheduled and time_no_reference.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the current node; assume FIFO.
|
|
||||||
ready_nodes_->RemoveCurrNode();
|
|
||||||
|
|
||||||
return !ready_nodes_->Empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
// Update graph_costs_ and per-op costs.
|
// Update graph_costs_ and per-op costs.
|
||||||
graph_costs_ = CombineCosts(graph_costs_, node_costs);
|
graph_costs_ = CombineCosts(graph_costs_, node_costs);
|
||||||
|
|
@ -402,7 +459,7 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
|
|
||||||
// Update node and device states.
|
// Update node and device states.
|
||||||
auto& node_state = node_map_[node];
|
auto& node_state = node_map_[node];
|
||||||
auto& device = device_[DeviceName(node)];
|
auto& device = device_[node_state.device_name];
|
||||||
device.nodes_executed.push_back(node);
|
device.nodes_executed.push_back(node);
|
||||||
// Node is scheduled when the device is available AND all the inputs are
|
// Node is scheduled when the device is available AND all the inputs are
|
||||||
// ready; hence, time_scheduled is time_ready if time_ready > device curr
|
// ready; hence, time_scheduled is time_ready if time_ready > device curr
|
||||||
|
|
@ -415,6 +472,21 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
auto curr_time = device.GetCurrTime();
|
auto curr_time = device.GetCurrTime();
|
||||||
node_state.time_finished = curr_time;
|
node_state.time_finished = curr_time;
|
||||||
|
|
||||||
|
// Update device memory usage.
|
||||||
|
if (!IsPersistentNode(node)) {
|
||||||
|
for (const auto& port_num_output_pair : node_state.outputs) {
|
||||||
|
int port_num = port_num_output_pair.first;
|
||||||
|
// There's a chance that a specific output is not used at all.
|
||||||
|
if (node_state.outputs[port_num].empty()) {
|
||||||
|
node_state.time_no_references[port_num] = curr_time;
|
||||||
|
} else {
|
||||||
|
device.memory_usage +=
|
||||||
|
CalculateOutputSize(node_state.output_properties, port_num);
|
||||||
|
device.nodes_in_memory.insert(std::make_pair(node, port_num));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update device's per-op cost.
|
// Update device's per-op cost.
|
||||||
auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
|
auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
|
||||||
device_op_cost = CombineCosts(device_op_cost, node_costs);
|
device_op_cost = CombineCosts(device_op_cost, node_costs);
|
||||||
|
|
@ -425,7 +497,52 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
|
||||||
<< ", scheduled: " << node_state.time_scheduled.count()
|
<< ", scheduled: " << node_state.time_scheduled.count()
|
||||||
<< ", finished: " << node_state.time_finished.count();
|
<< ", finished: " << node_state.time_finished.count();
|
||||||
|
|
||||||
return PopCurrNode();
|
// Increment num_inputs_ready of the output nodes
|
||||||
|
for (const auto& port_num_output_pair : node_state.outputs) {
|
||||||
|
for (auto* output_node : port_num_output_pair.second) {
|
||||||
|
auto& output_state = node_map_[output_node];
|
||||||
|
output_state.num_inputs_ready++;
|
||||||
|
if (output_state.num_inputs_ready == output_state.inputs.size()) {
|
||||||
|
// This output node is now ready.
|
||||||
|
output_state.time_ready = curr_time;
|
||||||
|
ready_nodes_->AddNode(output_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment num_outputs_executed of the input nodes.
|
||||||
|
for (const auto& input_port : node_state.inputs) {
|
||||||
|
auto* input = input_port.first;
|
||||||
|
auto port = input_port.second;
|
||||||
|
auto& input_state = node_map_[input];
|
||||||
|
input_state.num_outputs_executed[port]++;
|
||||||
|
if (input_state.num_outputs_executed[port] ==
|
||||||
|
input_state.outputs[port].size() &&
|
||||||
|
!IsPersistentNode(input)) {
|
||||||
|
// All the outputs are executed; no reference to this output port of
|
||||||
|
// input node.
|
||||||
|
input_state.time_no_references[port] = curr_time;
|
||||||
|
auto& input_device = device_[input_state.device_name];
|
||||||
|
input_device.memory_usage -=
|
||||||
|
CalculateOutputSize(input_state.output_properties, port);
|
||||||
|
|
||||||
|
input_device.nodes_in_memory.erase(std::make_pair(input, port));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!IsPersistentNode(node)) {
|
||||||
|
// Now that output memory is added and used up nodes are deallocated,
|
||||||
|
// check max memory usage.
|
||||||
|
if (device.memory_usage > device.max_memory_usage) {
|
||||||
|
device.max_memory_usage = device.memory_usage;
|
||||||
|
device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the current node; assume FIFO.
|
||||||
|
ready_nodes_->RemoveCurrNode();
|
||||||
|
|
||||||
|
return !ready_nodes_->Empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
Costs VirtualScheduler::Summary() const {
|
Costs VirtualScheduler::Summary() const {
|
||||||
|
|
@ -452,17 +569,59 @@ Costs VirtualScheduler::Summary() const {
|
||||||
for (const auto& device : device_) {
|
for (const auto& device : device_) {
|
||||||
const auto& name = device.first;
|
const auto& name = device.first;
|
||||||
const auto& state = device.second;
|
const auto& state = device.second;
|
||||||
|
|
||||||
|
std::map<string, int64> op_to_memory;
|
||||||
|
// First profile only persistent memory usage.
|
||||||
|
int64 persistent_memory_usage = 0;
|
||||||
|
std::set<string> persisent_ops;
|
||||||
|
for (const auto& node_port : state.persistent_nodes) {
|
||||||
|
const auto* node = node_port.first;
|
||||||
|
const auto port = node_port.second;
|
||||||
|
const auto output_size =
|
||||||
|
CalculateOutputSize(node_map_.at(node).output_properties, port);
|
||||||
|
persistent_memory_usage += output_size;
|
||||||
|
op_to_memory[node->op()] += output_size;
|
||||||
|
persisent_ops.insert(node->op());
|
||||||
|
}
|
||||||
|
int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
|
||||||
|
|
||||||
VLOG(1) << "Device = " << name
|
VLOG(1) << "Device = " << name
|
||||||
<< ", num_nodes = " << state.nodes_executed.size()
|
<< ", num_nodes = " << state.nodes_executed.size()
|
||||||
<< ", execution_time = " << state.GetCurrTime().count();
|
<< ", execution_time = " << state.GetCurrTime().count()
|
||||||
VLOG(1) << "Per-op execution time:";
|
<< ", memory usage: "
|
||||||
|
<< "persistenst = "
|
||||||
|
<< Round2(persistent_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||||
|
<< " GB, peak = "
|
||||||
|
<< Round2(state.max_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||||
|
<< " GB, total = "
|
||||||
|
<< Round2(max_memory_usage / 1024.0 / 1024.0 / 1024.0)
|
||||||
|
<< " GB, at the end: " << state.memory_usage << " B";
|
||||||
|
|
||||||
|
VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
|
||||||
|
// Profile non-persistent op memory usage.
|
||||||
|
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
|
||||||
|
const auto* node = node_port.first;
|
||||||
|
const auto port = node_port.second;
|
||||||
|
op_to_memory[node->op()] +=
|
||||||
|
CalculateOutputSize(node_map_.at(node).output_properties, port);
|
||||||
|
}
|
||||||
for (const auto& op_cost_pair : state.op_to_cost) {
|
for (const auto& op_cost_pair : state.op_to_cost) {
|
||||||
const auto& op = op_cost_pair.first;
|
const auto& op = op_cost_pair.first;
|
||||||
const auto& cost = op_cost_pair.second.execution_time.count();
|
const auto& cost = op_cost_pair.second.execution_time.count();
|
||||||
if (cost) { // Skip printing out zero-cost ops.
|
const float mem_usage_gb =
|
||||||
VLOG(1) << " + " << op << " : " << cost;
|
Round2(op_to_memory[op] / 1024.0 / 1024.0 / 1024.0);
|
||||||
|
int64 op_mem_usage = op_to_memory.at(op);
|
||||||
|
const float mem_usage_percent =
|
||||||
|
max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
|
||||||
|
: 0.0;
|
||||||
|
if (cost || mem_usage_percent > 1.0) {
|
||||||
|
// Print out only non-zero cost ops or ops with > 1% memory usage.
|
||||||
|
VLOG(1) << " + " << op << " : " << cost << " (" << mem_usage_gb
|
||||||
|
<< " GB [" << mem_usage_percent << "%] "
|
||||||
|
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
|
||||||
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
|
if (critical_path_costs.execution_time <= state.GetCurrTime()) {
|
||||||
critical_path_costs = state.device_costs;
|
critical_path_costs = state.device_costs;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,36 +29,79 @@ namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
struct NodeState {
|
struct NodeState {
|
||||||
std::vector<const NodeDef*> inputs;
|
// A node (i.e., an op) takes a set of input:port pairs and produces
|
||||||
std::vector<const NodeDef*> outputs;
|
// a set of output ports.
|
||||||
|
|
||||||
|
// Cross references to input and output nodes from graphdef.
|
||||||
|
std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs.
|
||||||
|
// List of output nodes (a list of nodes that takes this output port as input)
|
||||||
|
// keyed by port_num. Note that port_num -1 is used for control dependency.
|
||||||
|
std::unordered_map<int, std::vector<const NodeDef*>> outputs;
|
||||||
|
|
||||||
|
// Info from GraphProperties.
|
||||||
|
std::vector<OpInfo::TensorProperties> input_properties;
|
||||||
|
std::vector<OpInfo::TensorProperties> output_properties;
|
||||||
|
|
||||||
|
// Canonical device name used within VirtualScheduler.
|
||||||
|
string device_name;
|
||||||
|
|
||||||
|
// States updated as scheduling nodes.
|
||||||
int num_inputs_ready;
|
int num_inputs_ready;
|
||||||
int num_outputs_executed;
|
std::unordered_map<int, int> num_outputs_executed;
|
||||||
Costs::Duration time_ready;
|
Costs::Duration time_ready;
|
||||||
Costs::Duration time_scheduled;
|
Costs::Duration time_scheduled;
|
||||||
Costs::Duration time_finished;
|
Costs::Duration time_finished;
|
||||||
Costs::Duration time_no_reference;
|
// Time that all the consumers are executed (hence, no need to keep this
|
||||||
|
// output in memory), keyed by port_num.
|
||||||
|
std::unordered_map<int, Costs::Duration> time_no_references;
|
||||||
|
|
||||||
|
// Note that a node may have multiple output ports. The length of outputs,
|
||||||
|
// num_outputs_executed, and time_no_references should be
|
||||||
|
// identical when a NodeState is fully initialized.
|
||||||
|
// They should be 1 + output_properties.size() as we add [-1] for control
|
||||||
|
// dependency.
|
||||||
|
|
||||||
// Node will be ready to be executed at time_ready, scheduled at
|
// Node will be ready to be executed at time_ready, scheduled at
|
||||||
// time_scheduled, and finishes execution at time_finished.
|
// time_scheduled, and finishes execution at time_finished.
|
||||||
// Between time_scheduled and time_no_reference, the node's output tensor
|
// Each output port uses up memory space from time_scheduled to its
|
||||||
// needs to be on the device, using up device memory.
|
// time_no_references.
|
||||||
|
|
||||||
NodeState() {
|
NodeState() {
|
||||||
num_inputs_ready = 0;
|
num_inputs_ready = 0;
|
||||||
num_outputs_executed = 0;
|
|
||||||
time_ready = Costs::Duration::max();
|
time_ready = Costs::Duration::max();
|
||||||
time_scheduled = Costs::Duration::max();
|
time_scheduled = Costs::Duration::max();
|
||||||
time_finished = Costs::Duration::max();
|
time_finished = Costs::Duration::max();
|
||||||
time_no_reference = Costs::Duration::max();
|
// Note that num_outputs_executed and time_no_references are not initialized
|
||||||
|
// here, since we don't know the size (i.e., # outputs for this node).
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DeviceState {
|
struct DeviceState {
|
||||||
|
// Nodes executed on this device in execution order.
|
||||||
std::vector<const NodeDef*> nodes_executed;
|
std::vector<const NodeDef*> nodes_executed;
|
||||||
Costs device_costs;
|
|
||||||
std::map<string, Costs> op_to_cost; // Per-op cost.
|
|
||||||
|
|
||||||
DeviceState() { device_costs = Costs::ZeroCosts(); }
|
// Nodes currently allocated in memory: set of NodeDef* and port_num pairs
|
||||||
|
// so that we can track which output of the node is in memory.
|
||||||
|
std::set<std::pair<const NodeDef*, int>> nodes_in_memory;
|
||||||
|
|
||||||
|
// Nodes allocated in memory persistently: e.g., Variables.
|
||||||
|
std::set<std::pair<const NodeDef*, int>> persistent_nodes;
|
||||||
|
|
||||||
|
// Snapshot of nodes_in_memory, when memory usage is at peak.
|
||||||
|
// Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs.
|
||||||
|
std::set<std::pair<const NodeDef*, int>> mem_usage_snapshot_at_peak;
|
||||||
|
|
||||||
|
Costs device_costs;
|
||||||
|
std::map<string, Costs> op_to_cost; // Per-op cost.
|
||||||
|
std::map<string, int64> op_to_memory; // Per-op memory usage at peak usage.
|
||||||
|
int64 memory_usage;
|
||||||
|
int64 max_memory_usage;
|
||||||
|
|
||||||
|
DeviceState() {
|
||||||
|
device_costs = Costs::ZeroCosts();
|
||||||
|
memory_usage = 0;
|
||||||
|
max_memory_usage = 0;
|
||||||
|
}
|
||||||
|
|
||||||
Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
|
Costs::Duration GetCurrTime() const { return device_costs.execution_time; }
|
||||||
};
|
};
|
||||||
|
|
@ -106,48 +149,74 @@ class VirtualScheduler {
|
||||||
const string& default_device_type, Cluster* cluster,
|
const string& default_device_type, Cluster* cluster,
|
||||||
VirtualPlacer* placer);
|
VirtualPlacer* placer);
|
||||||
|
|
||||||
|
// Initializes NodeState and DeviceState from grappler_item_ and
|
||||||
|
// graph_properties_.
|
||||||
Status Init();
|
Status Init();
|
||||||
|
|
||||||
NodeInfo GetCurrNodeInfo() const;
|
NodeInfo GetCurrNodeInfo() const;
|
||||||
|
|
||||||
|
// Returns true if there is any node to be scheduled.
|
||||||
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
bool MarkCurrNodeExecuted(const Costs& node_costs);
|
||||||
|
|
||||||
|
// Prints out summary of execution (timing, memory usage, etc.)
|
||||||
Costs Summary() const;
|
Costs Summary() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// GetDeviceStates and GetNodeStates are currently for testing purpuse only.
|
||||||
|
// Retrieves detailed scheduling results.
|
||||||
|
const std::unordered_map<string, DeviceState>& GetDeviceStates() const {
|
||||||
|
return device_;
|
||||||
|
}
|
||||||
|
const std::unordered_map<const NodeDef*, NodeState>& GetNodeStates() const {
|
||||||
|
return node_map_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the size of output at port_num (unit: bytes). A special case is
|
||||||
|
// port_num -1, which is for control dependency and assumed to be 4 bytes.
|
||||||
|
int64 CalculateOutputSize(
|
||||||
|
const std::vector<OpInfo::TensorProperties>& output_properties,
|
||||||
|
const int port_num) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const string kSend = "_Send";
|
// Constants.
|
||||||
const string kRecv = "_Recv";
|
|
||||||
const string kAttrInputSrc = "input_source_";
|
const string kAttrInputSrc = "input_source_";
|
||||||
const string kAttrSrcDevice = "src_device_";
|
const string kAttrSrcDevice = "src_device_";
|
||||||
const string kAttrDstDevice = "dst_device_";
|
const string kAttrDstDevice = "dst_device_";
|
||||||
const string kChannelDevice = "Channel";
|
const string kChannelDevice = "Channel";
|
||||||
|
|
||||||
void MaybeUpdateInputProperties(
|
// Methods called from Init(). Fails if initialize_ is set.
|
||||||
const NodeDef* node, std::vector<OpInfo::TensorProperties>* inputs) const;
|
void MaybeUpdateInputOutput(const NodeDef* node);
|
||||||
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
|
||||||
std::pair<const NodeDef*, const NodeDef*> TransferNode(
|
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
|
||||||
const NodeDef* from, const NodeDef* to, const string& input_name);
|
const NodeDef* from, const NodeDef* to, const string& input_name);
|
||||||
string DeviceName(const NodeDef* node) const;
|
string DeviceName(const NodeDef* node) const;
|
||||||
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
|
||||||
|
|
||||||
|
// Helper methods.
|
||||||
Costs& FindOrCreateZero(const string& op_name,
|
Costs& FindOrCreateZero(const string& op_name,
|
||||||
std::map<string, Costs>* op_cost);
|
std::map<string, Costs>* op_cost);
|
||||||
|
float Round2(const float x) const;
|
||||||
|
bool IsPersistentNode(const NodeDef* node) const;
|
||||||
|
|
||||||
bool PopCurrNode();
|
// Scheduler states:
|
||||||
bool IsSendOp(const NodeDef* node) const;
|
|
||||||
bool IsRecvOp(const NodeDef* node) const;
|
|
||||||
|
|
||||||
GraphProperties graph_properties_;
|
|
||||||
std::map<string, int> op_counts_; // Op counts with key with input shape.
|
|
||||||
std::map<string, int> op_costs_; // Individual op costs (with input shapes).
|
|
||||||
Costs graph_costs_; // Graph cost.
|
|
||||||
std::map<string, Costs> op_to_cost_; // Per-op cost.
|
|
||||||
std::unique_ptr<ReadyNodeManager> ready_nodes_;
|
std::unique_ptr<ReadyNodeManager> ready_nodes_;
|
||||||
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
std::unordered_map<const NodeDef*, NodeState> node_map_;
|
||||||
std::unordered_map<string, DeviceState> device_;
|
std::unordered_map<string, DeviceState> device_;
|
||||||
|
|
||||||
// Pool of NodeDefs for SendRecv and Identity ops created.
|
// Pool of NodeDefs for SendRecv and Identity ops created.
|
||||||
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
|
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
|
||||||
// Cache of ops transferred to another device.
|
// Cache of nodes transferred to another device.
|
||||||
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
|
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
|
||||||
cached_ops_;
|
cached_recv_nodes_;
|
||||||
|
|
||||||
|
// Stats:
|
||||||
|
std::map<string, int> op_counts_; // Op counts with key with input shape.
|
||||||
|
std::map<string, int> op_costs_; // Individual op costs (with input shapes).
|
||||||
|
Costs graph_costs_; // Graph cost.
|
||||||
|
std::map<string, Costs> op_to_cost_; // Per-op cost.
|
||||||
|
|
||||||
|
// Auxilliary data structures for constructing NodeState and DeviceState.
|
||||||
|
GraphProperties graph_properties_;
|
||||||
Cluster* cluster_; // Not owned.
|
Cluster* cluster_; // Not owned.
|
||||||
const GrapplerItem* grappler_item_; // Not owned.
|
const GrapplerItem* grappler_item_; // Not owned.
|
||||||
bool use_static_shapes_;
|
bool use_static_shapes_;
|
||||||
|
|
|
||||||
|
|
@ -23,42 +23,49 @@ limitations under the License.
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
// Class for testing virtual scheduler.
|
||||||
|
class TestVirtualScheduler : public VirtualScheduler {
|
||||||
|
public:
|
||||||
|
TestVirtualScheduler(const GrapplerItem* grappler_item,
|
||||||
|
const bool use_static_shapes,
|
||||||
|
const string& default_device_type, Cluster* cluster,
|
||||||
|
VirtualPlacer* placer)
|
||||||
|
: VirtualScheduler(grappler_item, use_static_shapes, default_device_type,
|
||||||
|
cluster, placer) {}
|
||||||
|
|
||||||
|
FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
|
||||||
|
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
|
||||||
|
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
|
||||||
|
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
|
||||||
|
FRIEND_TEST(VirtualSchedulerTest, Variable);
|
||||||
|
};
|
||||||
|
|
||||||
class VirtualSchedulerTest : public ::testing::Test {
|
class VirtualSchedulerTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
|
const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
|
||||||
|
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
// Initializes cluster_ and placer_.
|
// Initializes cluster_ and placer_.
|
||||||
std::unordered_map<string, DeviceProperties> devices;
|
std::unordered_map<string, DeviceProperties> devices;
|
||||||
DeviceProperties cpu_device;
|
DeviceProperties cpu_device;
|
||||||
cpu_device.set_type("CPU");
|
cpu_device.set_type("CPU");
|
||||||
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
|
devices[kCPU0] = cpu_device;
|
||||||
DeviceProperties gpu_device;
|
|
||||||
gpu_device.set_type("GPU");
|
|
||||||
devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
|
|
||||||
|
|
||||||
cluster_.reset(new VirtualCluster(devices));
|
cluster_.reset(new VirtualCluster(devices));
|
||||||
placer_.reset(new VirtualPlacer(cluster_.get()));
|
placer_.reset(new VirtualPlacer(cluster_.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void CreateSchedulerWithConv2Ds() {
|
// Three Conv2Ds with only two in fetch nodes.
|
||||||
// Create a scheduler with a simple graph: 3 Conv2Ds, where only 2 are in
|
void CreateGrapplerItemWithConv2Ds() {
|
||||||
// fetch nodes.
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||||
const int bs = 4;
|
|
||||||
const int width = 10;
|
|
||||||
const int height = 10;
|
|
||||||
const int depth_in = 8;
|
|
||||||
const int kernel = 3;
|
|
||||||
const int depth_out = 16;
|
|
||||||
|
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
|
||||||
auto x = tensorflow::ops::RandomUniform(
|
auto x = tensorflow::ops::RandomUniform(
|
||||||
s.WithOpName("x"), {bs, width, height, depth_in}, DT_FLOAT);
|
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||||
auto y = tensorflow::ops::RandomUniform(
|
auto y = tensorflow::ops::RandomUniform(
|
||||||
s.WithOpName("y"), {bs, width, height, depth_in}, DT_FLOAT);
|
s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||||
auto z = tensorflow::ops::RandomUniform(
|
auto z = tensorflow::ops::RandomUniform(
|
||||||
s.WithOpName("z"), {bs, width, height, depth_in}, DT_FLOAT);
|
s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||||
auto f = tensorflow::ops::RandomUniform(
|
auto f = tensorflow::ops::RandomUniform(
|
||||||
s.WithOpName("f"), {kernel, kernel, depth_in, depth_out}, DT_FLOAT);
|
s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
|
||||||
std::vector<int> strides = {1, 1, 1, 1};
|
std::vector<int> strides = {1, 1, 1, 1};
|
||||||
auto c0 =
|
auto c0 =
|
||||||
tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
|
tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
|
||||||
|
|
@ -68,47 +75,253 @@ class VirtualSchedulerTest : public ::testing::Test {
|
||||||
tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
|
tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
|
||||||
GraphDef def;
|
GraphDef def;
|
||||||
TF_CHECK_OK(s.ToGraphDef(&def));
|
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||||
LOG(INFO) << def.DebugString();
|
|
||||||
|
|
||||||
grappler_item_.reset(new GrapplerItem);
|
grappler_item_.reset(new GrapplerItem);
|
||||||
grappler_item_->id = "test_conv2d_graph";
|
grappler_item_->id = "test_conv2d_graph";
|
||||||
grappler_item_->graph = def;
|
grappler_item_->graph = def;
|
||||||
grappler_item_->fetch = {"c0", "c1"};
|
grappler_item_->fetch = {"c0", "c1"};
|
||||||
|
|
||||||
scheduler_.reset(new VirtualScheduler(
|
dependency_["c0"] = {"x", "f"};
|
||||||
|
dependency_["c1"] = {"y", "f"};
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Conv2D with a variable.
|
||||||
|
void CreateGrapplerItemWithConv2DAndVariable() {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||||
|
auto x = tensorflow::ops::RandomUniform(
|
||||||
|
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||||
|
auto f = tensorflow::ops::Variable(
|
||||||
|
s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
|
||||||
|
std::vector<int> strides = {1, 1, 1, 1};
|
||||||
|
auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
|
||||||
|
GraphDef def;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||||
|
|
||||||
|
grappler_item_.reset(new GrapplerItem);
|
||||||
|
grappler_item_->id = "test_conv2d_var_graph";
|
||||||
|
grappler_item_->graph = def;
|
||||||
|
grappler_item_->fetch = {"y"};
|
||||||
|
|
||||||
|
dependency_["y"] = {"x", "f"};
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddN that takes 4 tensors with 10x10x10x10.
|
||||||
|
void CreateGrapplerItemWithAddN() {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||||
|
auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10},
|
||||||
|
DT_FLOAT);
|
||||||
|
auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10},
|
||||||
|
DT_FLOAT);
|
||||||
|
auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10},
|
||||||
|
DT_FLOAT);
|
||||||
|
auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10},
|
||||||
|
DT_FLOAT);
|
||||||
|
tensorflow::OutputList input_tensors = {x, y, z, w};
|
||||||
|
auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors);
|
||||||
|
GraphDef def;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||||
|
|
||||||
|
grappler_item_.reset(new GrapplerItem);
|
||||||
|
grappler_item_->id = "test_addn_graph";
|
||||||
|
grappler_item_->graph = def;
|
||||||
|
grappler_item_->fetch = {"out"};
|
||||||
|
|
||||||
|
dependency_["out"] = {"x", "y", "z", "w"};
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoOp that takes 7 NoOps as control dependency.
|
||||||
|
void CreateGrapplerItemWithControlDependency() {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||||
|
std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
|
||||||
|
std::vector<tensorflow::Operation> input_tensors;
|
||||||
|
for (const auto& input : input_noop_names) {
|
||||||
|
auto x = tensorflow::ops::NoOp(s.WithOpName(input));
|
||||||
|
input_tensors.push_back(x.operation);
|
||||||
|
}
|
||||||
|
auto out = tensorflow::ops::NoOp(
|
||||||
|
s.WithControlDependencies(input_tensors).WithOpName("out"));
|
||||||
|
GraphDef def;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||||
|
|
||||||
|
grappler_item_.reset(new GrapplerItem);
|
||||||
|
grappler_item_->id = "test_control_dependency_graph";
|
||||||
|
grappler_item_->graph = def;
|
||||||
|
grappler_item_->fetch = {"out"};
|
||||||
|
|
||||||
|
dependency_["out"] = input_noop_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
// FusedBN [an op with multiple outputs] with multiple consumers (including
|
||||||
|
// control dependency).
|
||||||
|
void CreateGrapplerItemWithBatchNorm() {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
|
||||||
|
auto x = tensorflow::ops::RandomUniform(
|
||||||
|
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
|
||||||
|
auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
|
||||||
|
{depth_in_}, DT_FLOAT);
|
||||||
|
auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
|
||||||
|
{depth_in_}, DT_FLOAT);
|
||||||
|
auto mean =
|
||||||
|
tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
|
||||||
|
auto var =
|
||||||
|
tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
|
||||||
|
|
||||||
|
auto batch_norm = tensorflow::ops::FusedBatchNorm(
|
||||||
|
s.WithOpName("bn"), x, scale, offset, mean, var,
|
||||||
|
ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
|
||||||
|
auto y = batch_norm.y;
|
||||||
|
auto batch_mean = batch_norm.batch_mean;
|
||||||
|
auto batch_var = batch_norm.batch_variance;
|
||||||
|
|
||||||
|
auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y);
|
||||||
|
auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var);
|
||||||
|
auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var);
|
||||||
|
std::vector<tensorflow::Operation> input_tensors = {
|
||||||
|
batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(),
|
||||||
|
};
|
||||||
|
auto z4 = tensorflow::ops::NoOp(
|
||||||
|
s.WithControlDependencies(batch_var).WithOpName("z4"));
|
||||||
|
|
||||||
|
GraphDef def;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&def));
|
||||||
|
|
||||||
|
grappler_item_.reset(new GrapplerItem);
|
||||||
|
grappler_item_->id = "test_complex_dependency_graph";
|
||||||
|
grappler_item_->graph = def;
|
||||||
|
grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
|
||||||
|
|
||||||
|
dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
|
||||||
|
dependency_["z1"] = {"x", "bn"};
|
||||||
|
dependency_["z2"] = {"bn"};
|
||||||
|
dependency_["z3"] = {"bn"};
|
||||||
|
dependency_["z4"] = {"bn"};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call this after creating grappler_item_ and setting up dependency_.
|
||||||
|
void InitScheduler() {
|
||||||
|
scheduler_.reset(new TestVirtualScheduler(
|
||||||
grappler_item_.get(), true /* use_static_shapes */,
|
grappler_item_.get(), true /* use_static_shapes */,
|
||||||
"CPU" /* default_device_type */, cluster_.get(), placer_.get()));
|
"CPU" /* default_device_type */, cluster_.get(), placer_.get()));
|
||||||
TF_CHECK_OK(scheduler_->Init());
|
TF_CHECK_OK(scheduler_->Init());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call this after init scheduler_. Scheduler stops after executing
|
||||||
|
// target_node.
|
||||||
|
std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) {
|
||||||
|
Costs zero_costs = Costs::ZeroCosts();
|
||||||
|
std::unordered_map<string, NodeInfo> ops_executed;
|
||||||
|
bool more_nodes = true;
|
||||||
|
do {
|
||||||
|
NodeInfo node_info = scheduler_->GetCurrNodeInfo();
|
||||||
|
ops_executed[node_info.name] = node_info;
|
||||||
|
|
||||||
|
// Check scheduling order.
|
||||||
|
auto it = dependency_.find(node_info.name);
|
||||||
|
if (it != dependency_.end()) {
|
||||||
|
for (const auto& preceding_node : it->second) {
|
||||||
|
EXPECT_GT(ops_executed.count(preceding_node), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs);
|
||||||
|
|
||||||
|
if (node_info.name == target_node) {
|
||||||
|
// Scheduler has the state after executing the target node.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (more_nodes);
|
||||||
|
return ops_executed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method for validating a vector.
|
||||||
|
template <typename T>
|
||||||
|
void ExpectVectorEq(const std::vector<T>& expected,
|
||||||
|
const std::vector<T>& test_elements) {
|
||||||
|
// Set of expected elements for an easy comparison.
|
||||||
|
std::set<T> expected_set(expected.begin(), expected.end());
|
||||||
|
for (const auto& element : test_elements) {
|
||||||
|
EXPECT_GT(expected_set.count(element), 0);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(expected.size(), test_elements.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method that checks the name of nodes.
|
||||||
|
void ValidateNodeDefs(const std::vector<string>& expected,
|
||||||
|
const std::vector<const NodeDef*>& node_defs) {
|
||||||
|
std::vector<string> node_names;
|
||||||
|
std::transform(node_defs.begin(), node_defs.end(),
|
||||||
|
std::back_inserter(node_names),
|
||||||
|
[](const NodeDef* node) { return node->name(); });
|
||||||
|
ExpectVectorEq(expected, node_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method for validating a set.
|
||||||
|
template <typename T>
|
||||||
|
void ExpectSetEq(const std::set<T>& expected,
|
||||||
|
const std::set<T>& test_elements) {
|
||||||
|
for (const auto& element : test_elements) {
|
||||||
|
EXPECT_GT(expected.count(element), 0);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(expected.size(), test_elements.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method tthat checks name - port pairs.
|
||||||
|
void ValidateMemoryUsageSnapshot(
|
||||||
|
const std::vector<string>& expected_names, const int port_num_expected,
|
||||||
|
const std::set<std::pair<const NodeDef*, int>>& mem_usage_snapshot) {
|
||||||
|
std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
|
||||||
|
std::transform(
|
||||||
|
mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
|
||||||
|
std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
|
||||||
|
[](const std::pair<const NodeDef*, int>& node_port) {
|
||||||
|
return std::make_pair(node_port.first->name(), node_port.second);
|
||||||
|
});
|
||||||
|
std::set<std::pair<string, int>> expected;
|
||||||
|
std::transform(expected_names.begin(), expected_names.end(),
|
||||||
|
std::inserter(expected, expected.begin()),
|
||||||
|
[port_num_expected](const string& name) {
|
||||||
|
return std::make_pair(name, port_num_expected);
|
||||||
|
});
|
||||||
|
ExpectSetEq(expected, nodes_at_peak_mem_usage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method for converting shape vector to TensorProperty.
|
||||||
|
OpInfo::TensorProperties ShapeToTensorProperty(
|
||||||
|
const std::vector<int> shape, const DataType& data_type) const {
|
||||||
|
OpInfo::TensorProperties tensor_property;
|
||||||
|
tensor_property.set_dtype(data_type);
|
||||||
|
for (const auto& x : shape) {
|
||||||
|
tensor_property.mutable_shape()->add_dim()->set_size(x);
|
||||||
|
}
|
||||||
|
return tensor_property;
|
||||||
|
}
|
||||||
|
|
||||||
// SetUp() inits cluster_ and placer_.
|
// SetUp() inits cluster_ and placer_.
|
||||||
std::unique_ptr<VirtualCluster> cluster_;
|
std::unique_ptr<VirtualCluster> cluster_;
|
||||||
std::unique_ptr<VirtualPlacer> placer_;
|
std::unique_ptr<VirtualPlacer> placer_;
|
||||||
|
|
||||||
// grappler_item_ and scheduler_ will be initialized differently for each test
|
// grappler_item_ and scheduler_ will be initialized differently for each test
|
||||||
// case
|
// case.
|
||||||
std::unique_ptr<GrapplerItem> grappler_item_;
|
std::unique_ptr<GrapplerItem> grappler_item_;
|
||||||
std::unique_ptr<VirtualScheduler> scheduler_;
|
std::unique_ptr<TestVirtualScheduler> scheduler_;
|
||||||
|
// Node name -> its preceding nodes map for testing scheduling order.
|
||||||
|
std::unordered_map<string, std::vector<string>> dependency_;
|
||||||
|
|
||||||
|
// Shared params for Conv2D related graphs:
|
||||||
|
const int batch_size_ = 4;
|
||||||
|
const int width_ = 10;
|
||||||
|
const int height_ = 10;
|
||||||
|
const int depth_in_ = 8;
|
||||||
|
const int kernel_ = 3;
|
||||||
|
const int depth_out_ = 16;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
||||||
CreateSchedulerWithConv2Ds(); // init scheduler_.
|
// Init.
|
||||||
|
CreateGrapplerItemWithConv2Ds();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
Costs zero_costs = Costs::ZeroCosts();
|
// Run the scheduler.
|
||||||
std::unordered_map<string, NodeInfo> ops_executed;
|
auto ops_executed = RunScheduler(""); // Run all the nodes.
|
||||||
do {
|
|
||||||
NodeInfo node_info = scheduler_->GetCurrNodeInfo();
|
|
||||||
ops_executed[node_info.name] = node_info;
|
|
||||||
|
|
||||||
// Check scheduling order: x and f before c0, and y and f before c1.
|
|
||||||
if (node_info.name == "c0") {
|
|
||||||
EXPECT_GT(ops_executed.count("x"), 0);
|
|
||||||
EXPECT_GT(ops_executed.count("f"), 0);
|
|
||||||
} else if (node_info.name == "c1") {
|
|
||||||
EXPECT_GT(ops_executed.count("y"), 0);
|
|
||||||
EXPECT_GT(ops_executed.count("f"), 0);
|
|
||||||
}
|
|
||||||
} while (scheduler_->MarkCurrNodeExecuted(zero_costs));
|
|
||||||
|
|
||||||
// [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
|
// [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be
|
||||||
// executed.
|
// executed.
|
||||||
|
|
@ -132,5 +345,162 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
|
||||||
EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
|
EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
|
||||||
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
|
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(VirtualSchedulerTest, CalculateOutputSize) {
|
||||||
|
// Init.
|
||||||
|
CreateGrapplerItemWithAddN();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
|
// Create a set of tensor properties.
|
||||||
|
std::vector<OpInfo::TensorProperties> output;
|
||||||
|
output.push_back(ShapeToTensorProperty({4, 4}, DT_FLOAT)); // 0
|
||||||
|
output.push_back(ShapeToTensorProperty({1}, DT_FLOAT)); // 1
|
||||||
|
output.push_back(ShapeToTensorProperty({10, 10, 10}, DT_HALF)); // 2
|
||||||
|
output.push_back(ShapeToTensorProperty({100, 7, 8, 99}, DT_FLOAT)); // 3
|
||||||
|
output.push_back(ShapeToTensorProperty({-1, 7, 8, 99}, DT_FLOAT)); // 4
|
||||||
|
output.push_back(ShapeToTensorProperty({-1, 7, -1, 99}, DT_FLOAT)); // 4
|
||||||
|
|
||||||
|
// port_num -1 is for control dependency: hard coded 4B.
|
||||||
|
EXPECT_EQ(4, scheduler_->CalculateOutputSize(output, -1));
|
||||||
|
|
||||||
|
// Test valid outputs.
|
||||||
|
EXPECT_EQ(4 * 4 * 4, scheduler_->CalculateOutputSize(output, 0));
|
||||||
|
EXPECT_EQ(4 * 1, scheduler_->CalculateOutputSize(output, 1));
|
||||||
|
EXPECT_EQ(2 * 10 * 10 * 10, scheduler_->CalculateOutputSize(output, 2));
|
||||||
|
EXPECT_EQ(4 * 100 * 7 * 8 * 99, scheduler_->CalculateOutputSize(output, 3));
|
||||||
|
|
||||||
|
// Any uknown shape (-1) shall yield zero output size.
|
||||||
|
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 4));
|
||||||
|
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 5));
|
||||||
|
|
||||||
|
// Invalid port_num (though it may be an error) shall yield zero
|
||||||
|
// output size.
|
||||||
|
EXPECT_EQ(0, scheduler_->CalculateOutputSize(output, 6));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(VirtualSchedulerTest, MemoryUsage) {
|
||||||
|
// Init.
|
||||||
|
CreateGrapplerItemWithAddN();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
|
// Run the scheduler.
|
||||||
|
RunScheduler("");
|
||||||
|
|
||||||
|
const auto& device_states = scheduler_->GetDeviceStates();
|
||||||
|
const auto& cpu_state = device_states.at(kCPU0);
|
||||||
|
|
||||||
|
// out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage
|
||||||
|
// is 4 x the input tensor size while executing the out node.
|
||||||
|
int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
|
||||||
|
const std::vector<string> expected_names = {"x", "y", "z", "w"};
|
||||||
|
EXPECT_EQ(expected_names.size() * one_input_node_size,
|
||||||
|
cpu_state.max_memory_usage);
|
||||||
|
ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
|
||||||
|
cpu_state.mem_usage_snapshot_at_peak);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(VirtualSchedulerTest, ControlDependency) {
|
||||||
|
// Init.
|
||||||
|
CreateGrapplerItemWithControlDependency();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
|
// Run the scheduler.
|
||||||
|
RunScheduler("");
|
||||||
|
|
||||||
|
const auto& device_states = scheduler_->GetDeviceStates();
|
||||||
|
const auto& cpu_state = device_states.at(kCPU0);
|
||||||
|
|
||||||
|
// The graph has a NoOp that takes control dependency from 7 NoOps. The peak
|
||||||
|
// memory usage is when executing the final NoOp.
|
||||||
|
int64 one_input_node_size = 4; // control dependency
|
||||||
|
const std::vector<string> expected_names = {"x", "y", "z", "w",
|
||||||
|
"u", "v", "t"};
|
||||||
|
EXPECT_EQ(expected_names.size() * one_input_node_size,
|
||||||
|
cpu_state.max_memory_usage);
|
||||||
|
ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
|
||||||
|
cpu_state.mem_usage_snapshot_at_peak);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(VirtualSchedulerTest, ComplexDependency) {
|
||||||
|
// Init.
|
||||||
|
CreateGrapplerItemWithBatchNorm();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
|
// Run the scheduler.
|
||||||
|
RunScheduler("bn");
|
||||||
|
|
||||||
|
const auto& device_states = scheduler_->GetDeviceStates();
|
||||||
|
const auto& cpu_state = device_states.at(kCPU0);
|
||||||
|
|
||||||
|
// The graph is
|
||||||
|
// bn = FusedBatchNorm(x, scale, offset, mean, var)
|
||||||
|
// z1 = bn.y + x
|
||||||
|
// z2 = bn.var + bn.var
|
||||||
|
// z3 = bn.var + bn.var
|
||||||
|
// z4 = control dependency from bn.
|
||||||
|
// Note that bn.mean doesn't have any consumer.
|
||||||
|
const int x_size = batch_size_ * width_ * height_ * depth_in_;
|
||||||
|
int64 expected_size =
|
||||||
|
4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
|
||||||
|
1 /* control dependency */);
|
||||||
|
EXPECT_EQ(expected_size, cpu_state.memory_usage);
|
||||||
|
|
||||||
|
// Nodes currrently in memory: bn's port -1, 0, and 2, and x's port 0.
|
||||||
|
std::set<std::pair<string, int>> nodes_in_memory;
|
||||||
|
std::transform(
|
||||||
|
cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
|
||||||
|
std::inserter(nodes_in_memory, nodes_in_memory.begin()),
|
||||||
|
[](const std::pair<const NodeDef*, int>& node_port) {
|
||||||
|
return std::make_pair(node_port.first->name(), node_port.second);
|
||||||
|
});
|
||||||
|
std::set<std::pair<string, int>> expected = {
|
||||||
|
std::make_pair("bn", -1), std::make_pair("bn", 0),
|
||||||
|
std::make_pair("bn", 2), std::make_pair("x", 0),
|
||||||
|
};
|
||||||
|
ExpectSetEq(expected, nodes_in_memory);
|
||||||
|
|
||||||
|
const auto& node_states = scheduler_->GetNodeStates();
|
||||||
|
const NodeState* bn_node = nullptr;
|
||||||
|
const NodeState* x_node = nullptr;
|
||||||
|
for (const auto& nodedef_node_state : node_states) {
|
||||||
|
const NodeDef* node = nodedef_node_state.first;
|
||||||
|
const NodeState& node_state = nodedef_node_state.second;
|
||||||
|
if (node->name() == "bn") {
|
||||||
|
bn_node = &node_state;
|
||||||
|
}
|
||||||
|
if (node->name() == "x") {
|
||||||
|
x_node = &node_state;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK_NOTNULL(bn_node);
|
||||||
|
CHECK_NOTNULL(x_node);
|
||||||
|
|
||||||
|
ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
|
||||||
|
ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
|
||||||
|
ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
|
||||||
|
// z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
|
||||||
|
ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(VirtualSchedulerTest, Variable) {
|
||||||
|
// Init.
|
||||||
|
CreateGrapplerItemWithConv2DAndVariable();
|
||||||
|
InitScheduler();
|
||||||
|
|
||||||
|
// Run the scheduler.
|
||||||
|
RunScheduler("");
|
||||||
|
|
||||||
|
const auto& device_states = scheduler_->GetDeviceStates();
|
||||||
|
const auto& cpu_state = device_states.at(kCPU0);
|
||||||
|
|
||||||
|
// There is one Conv2D that takes x and f, but f is variable, so it should be
|
||||||
|
// in persistent nodes.
|
||||||
|
// f is variable.
|
||||||
|
ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
|
||||||
|
cpu_state.persistent_nodes);
|
||||||
|
// Only x in peak memory usage snapshot.
|
||||||
|
ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
|
||||||
|
cpu_state.mem_usage_snapshot_at_peak);
|
||||||
|
}
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
|
||||||
|
|
@ -45,18 +45,33 @@ bool IsMerge(const NodeDef& node) {
|
||||||
return op == "Merge";
|
return op == "Merge";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsNoOp(const NodeDef& node) {
|
||||||
|
const auto op = node.op();
|
||||||
|
return op == "NoOp";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsPlaceholder(const NodeDef& node) {
|
bool IsPlaceholder(const NodeDef& node) {
|
||||||
const auto op = node.op();
|
const auto op = node.op();
|
||||||
return op == "Placeholder" || op == "PlaceholderV2" ||
|
return op == "Placeholder" || op == "PlaceholderV2" ||
|
||||||
op == "PlaceholderWithDefault";
|
op == "PlaceholderWithDefault";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsRecv(const NodeDef& node) {
|
||||||
|
const auto op = node.op();
|
||||||
|
return op == "_Recv";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsReduction(const NodeDef& node) {
|
bool IsReduction(const NodeDef& node) {
|
||||||
const auto& op = node.op();
|
const auto& op = node.op();
|
||||||
return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
|
return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
|
||||||
op == "Mean" || op == "Any" || op == "All";
|
op == "Mean" || op == "Any" || op == "All";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsSend(const NodeDef& node) {
|
||||||
|
const auto op = node.op();
|
||||||
|
return op == "_Send";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsSwitch(const NodeDef& node) {
|
bool IsSwitch(const NodeDef& node) {
|
||||||
const auto& op = node.op();
|
const auto& op = node.op();
|
||||||
return op == "Switch";
|
return op == "Switch";
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,11 @@ bool IsConstant(const NodeDef& node);
|
||||||
bool IsDequeueOp(const NodeDef& node);
|
bool IsDequeueOp(const NodeDef& node);
|
||||||
bool IsIdentity(const NodeDef& node);
|
bool IsIdentity(const NodeDef& node);
|
||||||
bool IsMerge(const NodeDef& node);
|
bool IsMerge(const NodeDef& node);
|
||||||
|
bool IsNoOp(const NodeDef& node);
|
||||||
bool IsPlaceholder(const NodeDef& node);
|
bool IsPlaceholder(const NodeDef& node);
|
||||||
|
bool IsRecv(const NodeDef& node);
|
||||||
bool IsReduction(const NodeDef& node);
|
bool IsReduction(const NodeDef& node);
|
||||||
|
bool IsSend(const NodeDef& node);
|
||||||
bool IsSwitch(const NodeDef& node);
|
bool IsSwitch(const NodeDef& node);
|
||||||
bool IsTranspose(const NodeDef& node);
|
bool IsTranspose(const NodeDef& node);
|
||||||
bool IsVariable(const NodeDef& node);
|
bool IsVariable(const NodeDef& node);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user