mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Call maxcut algorithm in the model_based_cost_estimator.
PiperOrigin-RevId: 158078511
This commit is contained in:
parent
7d76a90be4
commit
0cc851c08f
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
|
@ -94,6 +95,9 @@ struct Costs {
|
|||
// streams from main memory.
|
||||
// If the time estimation is inaccurate.
|
||||
bool inaccurate = false;
|
||||
|
||||
// Max possible memory usage per device.
|
||||
std::unordered_map<string, uint64> estimated_max_memory_per_device;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
|
||||
|
|
|
|||
|
|
@ -201,14 +201,13 @@ DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
|
|||
return GetDeviceInfo(node.device());
|
||||
}
|
||||
|
||||
OpInfo BuildOpInfo(
|
||||
const NodeDef& node, const string& device_str,
|
||||
OpInfo BuildOpInfoWithoutDevice(
|
||||
const NodeDef& node,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
const std::vector<OpInfo::TensorProperties>& inputs) {
|
||||
OpInfo op_info;
|
||||
op_info.set_op(node.op());
|
||||
*op_info.mutable_attr() = node.attr();
|
||||
*op_info.mutable_device() = GetDeviceInfo(device_str);
|
||||
for (auto& input : inputs) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
|
|
@ -263,8 +262,8 @@ OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph,
|
|||
|
||||
std::vector<OpInfo::TensorProperties> inputs =
|
||||
FindInputFeatures(node, name_to_cost, name_to_node);
|
||||
(*perf->mutable_op()) =
|
||||
BuildOpInfo(node, cost_node->device(), name_to_node, inputs);
|
||||
*perf->mutable_op() = BuildOpInfoWithoutDevice(node, name_to_node, inputs);
|
||||
*perf->mutable_op()->mutable_device() = GetDeviceInfo(cost_node->device());
|
||||
|
||||
perf->set_temporary_memory_size(cost_node->temporary_memory_size());
|
||||
// Note that CostGraphDef::Node::compute_cost is microseconds, while
|
||||
|
|
|
|||
|
|
@ -49,11 +49,10 @@ DeviceProperties GetDeviceInfo(const string& device_str);
|
|||
// Return a string describing a node given a nodeinfo.
|
||||
string GetOpDescription(const OpInfo& op_info);
|
||||
|
||||
// Builds the OpInfo proto for node, given all nodes in the graph, the node's
|
||||
// device and its input properties which are typically built by shape inference
|
||||
// or calling FindInputFeatures.
|
||||
OpInfo BuildOpInfo(
|
||||
const NodeDef& node, const string& device_str,
|
||||
// Builds the OpInfo for node without filling its device information, given all
|
||||
// nodes in the graph and its input properties.
|
||||
OpInfo BuildOpInfoWithoutDevice(
|
||||
const NodeDef& node,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
const std::vector<OpInfo::TensorProperties>& inputs);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user