mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Sets the incarnation number even when the attribute is set.
PiperOrigin-RevId: 163299121
This commit is contained in:
parent
a49fe03668
commit
a524701723
|
|
@ -909,8 +909,13 @@ void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
|
||||||
// No known send_device. The runtime will detect it later.
|
// No known send_device. The runtime will detect it later.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int64 incarnation = opts.get_incarnation(send_device);
|
int64 incarnation = PartitionOptions::kIllegalIncarnation;
|
||||||
AddNodeAttr("send_device_incarnation", incarnation, ndef);
|
if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
|
||||||
|
(incarnation == PartitionOptions::kIllegalIncarnation)) {
|
||||||
|
incarnation = opts.get_incarnation(send_device);
|
||||||
|
SetAttrValue(incarnation,
|
||||||
|
&((*ndef->mutable_attr())["send_device_incarnation"]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets attribute send_device_incarnation of all Send/Recv nodes in
|
// Sets attribute send_device_incarnation of all Send/Recv nodes in
|
||||||
|
|
|
||||||
|
|
@ -445,6 +445,41 @@ TEST_F(GraphPartitionTest, Functions) {
|
||||||
ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
|
ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphPartitionTest, SetIncarnation) {
|
||||||
|
GraphDef gdef;
|
||||||
|
const char* const kSendRecvAttrs = R"proto(
|
||||||
|
attr { key: 'T' value { type: DT_FLOAT } }
|
||||||
|
attr { key: 'client_terminated' value { b: false } }
|
||||||
|
attr { key: 'recv_device' value { s: 'B' } }
|
||||||
|
attr { key: 'send_device' value { s: 'A' } }
|
||||||
|
attr { key: 'send_device_incarnation' value { i: 0 } }
|
||||||
|
attr { key: 'tensor_name' value { s: 'test' } }
|
||||||
|
)proto";
|
||||||
|
CHECK(protobuf::TextFormat::ParseFromString(
|
||||||
|
StrCat("node { name: 'A/Pi' op: 'Const' ",
|
||||||
|
" attr { key: 'dtype' value { type: DT_FLOAT } } ",
|
||||||
|
" attr { key: 'value' value { tensor { ",
|
||||||
|
" dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }",
|
||||||
|
"node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}",
|
||||||
|
"node { name: 'B' op: '_Recv' ", kSendRecvAttrs,
|
||||||
|
" attr { key: 'tensor_type' value { type:DT_FLOAT}}}"),
|
||||||
|
&gdef));
|
||||||
|
gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION);
|
||||||
|
Partition(gdef, &partitions_);
|
||||||
|
EXPECT_EQ(2, partitions_.size());
|
||||||
|
|
||||||
|
for (const auto& kv : partitions_) {
|
||||||
|
const GraphDef& gdef = kv.second;
|
||||||
|
for (const NodeDef& ndef : gdef.node()) {
|
||||||
|
if (ndef.name() == "A" || ndef.name() == "B") {
|
||||||
|
int64 val;
|
||||||
|
TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val));
|
||||||
|
EXPECT_EQ(val, 100); // Send device is "A".
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) {
|
TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) {
|
||||||
// Create placeholders, shuffle them so the order in the graph is not strictly
|
// Create placeholders, shuffle them so the order in the graph is not strictly
|
||||||
// increasing.
|
// increasing.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user