tfdbg: dump debug data from different devices in separate directories

Fixes: #7051, wherein TFDBG failed to load the data dumped by a Session.run() call involving multiple GPUs.

The root cause of the bug was that TFDBG previously assumed that node names are unique across all partition graphs. However, this does not hold when multiple GPUs are used: the Send/Recv nodes in the partition graphs of different GPUs can have duplicate names. Similar name collisions may arise in the future for other reasons (e.g., distributed sessions and/or graph optimizations).

This CL relaxes that assumption by dumping the GraphDef and tensor data from each device into its own sub-directory under the dump root directory.
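
As a quick illustration of the new on-disk layout, the device-name-to-directory mapping can be sketched in a few lines of standalone Python (the helper names below are illustrative stand-ins; the real logic lives in DebugNodeKey::DeviceNameToDevicePath on the C++ side and in the debug_data module on the Python side):

METADATA_FILE_PREFIX = "_tfdbg_"  # mirrors DebugIO::kMetadataFilePrefix in the C++ diff below
DEVICE_TAG = "device_"            # mirrors DebugIO::kDeviceTag

def device_name_to_device_path(device_name):
  """Sketch: map a device name to the name of its dump sub-directory."""
  # ":" is replaced with "_" first, then "/" with ",".
  return (METADATA_FILE_PREFIX + DEVICE_TAG +
          device_name.replace(":", "_").replace("/", ","))

def device_path_to_device_name(device_path):
  """Sketch of the inverse mapping, used when loading a dump directory."""
  stripped = device_path[len(METADATA_FILE_PREFIX + DEVICE_TAG):]
  return stripped.replace(",", "/").replace("_", ":")

# E.g., tensors dumped on /job:worker/replica:1/task:0/gpu:2 now land under
# <dump_root>/_tfdbg_device_,job_worker,replica_1,task_0,gpu_2/<node>_<slot>_<debug_op>_<timestamp>
assert (device_name_to_device_path("/job:worker/replica:1/task:0/gpu:2") ==
        "_tfdbg_device_,job_worker,replica_1,task_0,gpu_2")

Replacing "/" with "," keeps each device's dump under a single path component while the mapping stays invertible, which lets the Python debug_data loader recover the device name from the directory name alone.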

PiperOrigin-RevId: 158029814
Authored by Shanqing Cai on 2017-06-05 10:29:50 -07:00; committed by TensorFlower Gardener
parent a5909d6432
commit cc2dd4ac85
10 changed files with 880 additions and 322 deletions


@ -119,6 +119,18 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
} // namespace
// static
const char* const DebugIO::kMetadataFilePrefix = "_tfdbg_";
// static
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";
// static
const char* const DebugIO::kDeviceTag = "device_";
// static
const char* const DebugIO::kGraphTag = "graph_";
DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op)
: device_name(device_name),
@ -126,7 +138,8 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
output_slot(output_slot),
debug_op(debug_op),
debug_node_name(
strings::StrCat(node_name, ":", output_slot, ":", debug_op)) {}
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
device_path(DeviceNameToDevicePath(device_name)) {}
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
Env* env(Env::Default());
@ -157,6 +170,15 @@ Status ReadEventFromFile(const string& dump_file_path, Event* event) {
return Status::OK();
}
// static
const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
return strings::StrCat(
DebugIO::kMetadataFilePrefix, DebugIO::kDeviceTag,
str_util::StringReplace(
str_util::StringReplace(device_name, ":", "_", true), "/", ",",
true));
}
// static
const char* const DebugIO::kFileURLScheme = "file://";
// static
@ -236,7 +258,8 @@ Status DebugIO::PublishDebugMetadata(
const string core_metadata_path = AppendTimestampToFilePath(
io::JoinPath(
dump_root_dir,
strings::StrCat("_tfdbg_core_metadata_", "sessionrun",
strings::StrCat(DebugIO::kMetadataFilePrefix,
DebugIO::kCoreMetadataTag, "sessionrun",
strings::Printf("%.14lld", session_run_index))),
Env::Default()->NowMicros());
status.Update(DebugFileIO::DumpEventProtoToFile(
@ -325,10 +348,11 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
Status status = Status::OK();
for (const string& debug_url : debug_urls) {
if (debug_url.find(kFileURLScheme) == 0) {
const string dump_root_dir = debug_url.substr(strlen(kFileURLScheme));
// TODO(cais): (b/38325442) Serialize the GraphDef to a directory that
// reflects the device name.
const string file_name = strings::StrCat("_tfdbg_graph_", now_micros);
const string dump_root_dir =
io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
DebugNodeKey::DeviceNameToDevicePath(device_name));
const string file_name = strings::StrCat(DebugIO::kMetadataFilePrefix,
DebugIO::kGraphTag, now_micros);
status.Update(
DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
@ -437,7 +461,7 @@ string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
const DebugNodeKey& debug_node_key,
const uint64 wall_time_us) {
return AppendTimestampToFilePath(
io::JoinPath(dump_root_dir,
io::JoinPath(dump_root_dir, debug_node_key.device_path,
strings::StrCat(debug_node_key.node_name, "_",
debug_node_key.output_slot, "_",
debug_node_key.debug_op)),


@ -44,11 +44,14 @@ struct DebugNodeKey {
DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op);
static const string DeviceNameToDevicePath(const string& device_name);
const string device_name;
const string node_name;
const int32 output_slot;
const string debug_op;
const string debug_node_name;
const string device_path;
};
class DebugIO {
@ -136,6 +139,11 @@ class DebugIO {
static Status CloseDebugURL(const string& debug_url);
static const char* const kMetadataFilePrefix;
static const char* const kCoreMetadataTag;
static const char* const kDeviceTag;
static const char* const kGraphTag;
static const char* const kFileURLScheme;
static const char* const kGrpcURLScheme;
};


@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/event.pb.h"
@ -47,6 +48,18 @@ class DebugIOUtilsTest : public ::testing::Test {
std::unique_ptr<Tensor> tensor_b_;
};
TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0, "DebugIdentity");
EXPECT_EQ("/job:worker/replica:1/task:0/gpu:2", debug_node_key.device_name);
EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name);
EXPECT_EQ(0, debug_node_key.output_slot);
EXPECT_EQ("DebugIdentity", debug_node_key.debug_op);
EXPECT_EQ("hidden_1/MatMul:0:DebugIdentity", debug_node_key.debug_node_name);
EXPECT_EQ("_tfdbg_device_,job_worker,replica_1,task_0,gpu_2",
debug_node_key.device_path);
}
TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
Initialize();
@ -138,10 +151,14 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) {
// First, create the file at the path.
const string test_dir = testing::TmpDir();
const string txt_file_name = strings::StrCat(test_dir, "/baz");
if (!env_->FileExists(test_dir).ok()) {
ASSERT_TRUE(env_->CreateDir(test_dir).ok());
const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
const DebugNodeKey kDebugNodeKey(kDeviceName, "baz/tensor_a", 0,
"DebugIdentity");
const string txt_file_dir =
io::JoinPath(test_dir, DebugNodeKey::DeviceNameToDevicePath(kDeviceName));
const string txt_file_name = io::JoinPath(txt_file_dir, "baz");
if (!env_->FileExists(txt_file_dir).ok()) {
ASSERT_TRUE(env_->RecursivelyCreateDir(txt_file_dir).ok());
}
ASSERT_EQ(error::Code::NOT_FOUND, env_->FileExists(txt_file_name).code());
@ -157,8 +174,7 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) {
// Second, try to dump the tensor to a path that requires "baz" to be a
// directory, which should lead to an error.
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
"baz/tensor_a", 0, "DebugIdentity");
const uint64 wall_time = env_->NowMicros();
string dump_file_name;


@ -187,7 +187,10 @@ TEST_F(GrpcSessionDebugTest, FileDebugURL) {
IsSingleFloatValue(outputs[0], 4.0);
std::vector<Tensor> dumped_tensors;
LoadTensorDumps("n", &dumped_tensors);
LoadTensorDumps(io::JoinPath(DebugNodeKey::DeviceNameToDevicePath(
cluster->devices()[0].name()),
"n"),
&dumped_tensors);
if (i == 0 || i == 5) {
ASSERT_EQ(0, dumped_tensors.size());
@ -267,7 +270,10 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
TF_CHECK_OK(session->Close());
std::vector<Tensor> dumped_tensors;
LoadTensorDumps("n", &dumped_tensors);
LoadTensorDumps(
io::JoinPath(DebugNodeKey::DeviceNameToDevicePath(a_dev.name()),
"n"),
&dumped_tensors);
ASSERT_EQ(1, dumped_tensors.size());
ASSERT_EQ(TensorShape({2, 2}), dumped_tensors[0].shape());
for (size_t i = 0; i < 4; ++i) {


@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@ -94,7 +95,22 @@ TEST_F(DebugIdentityOpTest, Int32Success_6_FileURLs) {
ASSERT_TRUE(env_->FileExists(dump_roots[i]).ok());
ASSERT_TRUE(env_->IsDirectory(dump_roots[i]).ok());
DIR* dir = opendir(dump_roots[i].c_str());
std::vector<string> device_roots;
DIR* dir0 = opendir(dump_roots[i].c_str());
struct dirent* ent0;
const string kDeviceDirPrefix =
strings::StrCat(DebugIO::kMetadataFilePrefix, DebugIO::kDeviceTag);
while ((ent0 = readdir(dir0)) != nullptr) {
if (!strncmp(ent0->d_name, kDeviceDirPrefix.c_str(),
kDeviceDirPrefix.size())) {
device_roots.push_back(io::JoinPath(dump_roots[i], ent0->d_name));
}
}
ASSERT_EQ(1, device_roots.size());
closedir(dir0);
const string& device_root = device_roots[0];
DIR* dir = opendir(device_root.c_str());
struct dirent* ent;
int dump_files_found = 0;
while ((ent = readdir(dir)) != nullptr) {
@ -102,8 +118,7 @@ TEST_F(DebugIdentityOpTest, Int32Success_6_FileURLs) {
dump_files_found++;
// Try reading the file into a Event proto.
const string dump_file_path =
strings::StrCat(dump_roots[i], "/", ent->d_name);
const string dump_file_path = io::JoinPath(device_root, ent->d_name);
std::fstream ifs(dump_file_path, std::ios::in | std::ios::binary);
Event event;
event.ParseFromIstream(&ifs);


@ -545,6 +545,22 @@ cuda_py_test(
tags = ["notsan"],
)
cuda_py_test(
name = "session_debug_multi_gpu_test",
size = "small",
srcs = ["lib/session_debug_multi_gpu_test.py"],
additional_deps = [
":debug_data",
":debug_utils",
"//tensorflow/python:client",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
py_test(
name = "debugger_cli_common_test",
size = "small",

File diff suppressed because it is too large.


@ -23,12 +23,29 @@ import tempfile
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class DeviceNamePathConversionTest(test_util.TensorFlowTestCase):
def testDeviceNameToDevicePath(self):
self.assertEqual(
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_ps,replica_1,task_2,cpu_0",
debug_data.device_name_to_device_path("/job:ps/replica:1/task:2/cpu:0"))
def testDevicePathToDeviceName(self):
self.assertEqual(
"/job:ps/replica:1/task:2/cpu:0",
debug_data.device_path_to_device_name(
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_ps,replica_1,task_2,cpu_0"))
class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
def testParseNodeName(self):
@ -163,7 +180,10 @@ class DebugTensorDatumTest(test_util.TensorFlowTestCase):
def testDebugDatum(self):
dump_root = "/tmp/tfdbg_1"
debug_dump_rel_path = "ns1/ns2/node_a_1_2_DebugIdentity_1472563253536385"
debug_dump_rel_path = (
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_localhost,replica_0,task_0,cpu_0" +
"/ns1/ns2/node_a_1_2_DebugIdentity_1472563253536385")
datum = debug_data.DebugTensorDatum(dump_root, debug_dump_rel_path)
@ -175,16 +195,18 @@ class DebugTensorDatumTest(test_util.TensorFlowTestCase):
self.assertEqual("ns1/ns2/node_a_1:2:DebugIdentity", datum.watch_key)
self.assertEqual(
os.path.join(dump_root, debug_dump_rel_path), datum.file_path)
self.assertEqual("{DebugTensorDatum: %s:%d @ %s @ %d}" % (datum.node_name,
datum.output_slot,
datum.debug_op,
datum.timestamp),
str(datum))
self.assertEqual("{DebugTensorDatum: %s:%d @ %s @ %d}" % (datum.node_name,
datum.output_slot,
datum.debug_op,
datum.timestamp),
repr(datum))
self.assertEqual(
"{DebugTensorDatum (/job:localhost/replica:0/task:0/cpu:0) "
"%s:%d @ %s @ %d}" % (datum.node_name,
datum.output_slot,
datum.debug_op,
datum.timestamp), str(datum))
self.assertEqual(
"{DebugTensorDatum (/job:localhost/replica:0/task:0/cpu:0) "
"%s:%d @ %s @ %d}" % (datum.node_name,
datum.output_slot,
datum.debug_op,
datum.timestamp), repr(datum))
def testDumpSizeBytesIsNoneForNonexistentFilePath(self):
dump_root = "/tmp/tfdbg_1"
@ -204,18 +226,112 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
# Tear down temporary dump directory.
shutil.rmtree(self._dump_root)
def _makeDataDirWithMultipleDevicesAndDuplicateNodeNames(self):
cpu_0_dir = os.path.join(
self._dump_root,
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_localhost,replica_0,task_0,cpu_0")
gpu_0_dir = os.path.join(
self._dump_root,
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_localhost,replica_0,task_0,gpu_0")
gpu_1_dir = os.path.join(
self._dump_root,
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_localhost,replica_0,task_0,gpu_1")
os.makedirs(cpu_0_dir)
os.makedirs(gpu_0_dir)
os.makedirs(gpu_1_dir)
open(os.path.join(
cpu_0_dir, "node_foo_1_2_DebugIdentity_1472563253536386"), "wb")
open(os.path.join(
gpu_0_dir, "node_foo_1_2_DebugIdentity_1472563253536385"), "wb")
open(os.path.join(
gpu_1_dir, "node_foo_1_2_DebugIdentity_1472563253536387"), "wb")
def testDebugDumpDir_nonexistentDumpRoot(self):
with self.assertRaisesRegexp(IOError, "does not exist"):
debug_data.DebugDumpDir(tempfile.mktemp() + "_foo")
def testDebugDumpDir_invalidFileNamingPattern(self):
# File name with too few underscores should lead to an exception.
open(os.path.join(self._dump_root, "node1_DebugIdentity_1234"), "wb")
device_dir = os.path.join(
self._dump_root,
debug_data.METADATA_FILE_PREFIX + debug_data.DEVICE_TAG +
",job_localhost,replica_0,task_0,cpu_0")
os.makedirs(device_dir)
open(os.path.join(device_dir, "node1_DebugIdentity_1234"), "wb")
with self.assertRaisesRegexp(ValueError,
"does not conform to the naming pattern"):
debug_data.DebugDumpDir(self._dump_root)
def testDebugDumpDir_validDuplicateNodeNamesWithMultipleDevices(self):
self._makeDataDirWithMultipleDevicesAndDuplicateNodeNames()
graph_cpu_0 = graph_pb2.GraphDef()
node = graph_cpu_0.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/cpu:0"
graph_gpu_0 = graph_pb2.GraphDef()
node = graph_gpu_0.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/gpu:0"
graph_gpu_1 = graph_pb2.GraphDef()
node = graph_gpu_1.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/gpu:1"
dump_dir = debug_data.DebugDumpDir(
self._dump_root,
partition_graphs=[graph_cpu_0, graph_gpu_0, graph_gpu_1])
self.assertItemsEqual(
["/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/gpu:0",
"/job:localhost/replica:0/task:0/gpu:1"], dump_dir.devices())
self.assertEqual(1472563253536385, dump_dir.t0)
self.assertEqual(3, dump_dir.size)
with self.assertRaisesRegexp(
ValueError,
r"There are multiple \(3\) devices, but device_name is not specified"):
dump_dir.nodes()
self.assertItemsEqual(
["node_foo_1"],
dump_dir.nodes(device_name="/job:localhost/replica:0/task:0/cpu:0"))
def testDuplicateNodeNamesInGraphDefOfSingleDeviceRaisesException(self):
self._makeDataDirWithMultipleDevicesAndDuplicateNodeNames()
graph_cpu_0 = graph_pb2.GraphDef()
node = graph_cpu_0.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/cpu:0"
graph_gpu_0 = graph_pb2.GraphDef()
node = graph_gpu_0.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/gpu:0"
graph_gpu_1 = graph_pb2.GraphDef()
node = graph_gpu_1.node.add()
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/gpu:1"
node = graph_gpu_1.node.add() # Here is the duplicate.
node.name = "node_foo_1"
node.op = "FooOp"
node.device = "/job:localhost/replica:0/task:0/gpu:1"
with self.assertRaisesRegexp(
ValueError, r"Duplicate node name on device "):
debug_data.DebugDumpDir(
self._dump_root,
partition_graphs=[graph_cpu_0, graph_gpu_0, graph_gpu_1])
def testDebugDumpDir_emptyDumpDir(self):
dump_dir = debug_data.DebugDumpDir(self._dump_root)


@ -0,0 +1,93 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for debugger functionalities under multiple (i.e., >1) GPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import tempfile
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class SessionDebugMultiGPUTest(test_util.TensorFlowTestCase):
def setUp(self):
self._dump_root = tempfile.mkdtemp()
def tearDown(self):
ops.reset_default_graph()
# Tear down temporary dump directory.
if os.path.isdir(self._dump_root):
shutil.rmtree(self._dump_root)
def testMultiGPUSessionRun(self):
local_devices = device_lib.list_local_devices()
gpu_device_names = []
for device in local_devices:
if device.device_type == "GPU":
gpu_device_names.append(device.name)
gpu_device_names = sorted(gpu_device_names)
if len(gpu_device_names) < 2:
self.skipTest(
"This test requires at least 2 GPUs, but only %d is available." %
len(gpu_device_names))
with session.Session() as sess:
v = variables.Variable([10.0, 15.0], dtype=dtypes.float32, name="v")
with ops.device(gpu_device_names[0]):
u0 = math_ops.add(v, v, name="u0")
with ops.device(gpu_device_names[1]):
u1 = math_ops.multiply(v, v, name="u1")
w = math_ops.subtract(u1, u0, name="w")
sess.run(v.initializer)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(run_options, sess.graph,
debug_urls="file://" + self._dump_root)
run_metadata = config_pb2.RunMetadata()
self.assertAllClose(
[80.0, 195.0],
sess.run(w, options=run_options, run_metadata=run_metadata))
debug_dump_dir = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=run_metadata.partition_graphs)
self.assertEqual(3, len(debug_dump_dir.devices()))
self.assertAllClose(
[10.0, 15.0], debug_dump_dir.get_tensors("v", 0, "DebugIdentity")[0])
self.assertAllClose(
[20.0, 30.0], debug_dump_dir.get_tensors("u0", 0, "DebugIdentity")[0])
self.assertAllClose(
[100.0, 225.0],
debug_dump_dir.get_tensors("u1", 0, "DebugIdentity")[0])
if __name__ == "__main__":
googletest.main()


@ -249,8 +249,12 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertIn(results.v.op.type, results.dump.node_op_type(results.v_name))
self.assertIn(results.w.op.type, results.dump.node_op_type(results.w_name))
with self.assertRaisesRegexp(
ValueError, "Node 'foo_bar' does not exist in partition graphs."):
if test_util.gpu_device_name():
expected_error_regexp = r"None of the .* devices has a node named "
else:
expected_error_regexp = (
r"Node \'foo_bar\' does not exist in the partition graph of device")
with self.assertRaisesRegexp(ValueError, expected_error_regexp):
results.dump.node_op_type("foo_bar")
def testDumpStringTensorsWorks(self):
@ -436,9 +440,11 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
# Verify dump files
self.assertTrue(os.path.isdir(self._dump_root))
self.assertTrue(os.path.isdir(os.path.join(self._dump_root, u_namespace)))
self.assertTrue(
os.path.isdir(os.path.join(self._dump_root, v_namespace, "v")))
u_glob_out = glob.glob(os.path.join(self._dump_root, "*", u_namespace))
v_glob_out = glob.glob(os.path.join(
self._dump_root, "*", v_namespace, "v"))
self.assertTrue(os.path.isdir(u_glob_out[0]))
self.assertTrue(os.path.isdir(v_glob_out[0]))
dump = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=run_metadata.partition_graphs)
@ -688,7 +694,11 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
u_read_name = u_name + "/read"
# Test node name list lookup of the DebugDumpDir object.
node_names = dump.nodes()
if test_util.gpu_device_name():
node_names = dump.nodes(
device_name="/job:localhost/replica:0/task:0/gpu:0")
else:
node_names = dump.nodes()
self.assertTrue(u_name in node_names)
self.assertTrue(u_read_name in node_names)
@ -698,7 +708,11 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(1, len(u_attr["shape"].shape.dim))
self.assertEqual(2, u_attr["shape"].shape.dim[0].size)
with self.assertRaisesRegexp(ValueError, "No node named \"foo\" exists"):
if test_util.gpu_device_name():
expected_error_regexp = r"None of the .* devices has a node named "
else:
expected_error_regexp = r"No node named \"foo\" exists"
with self.assertRaisesRegexp(ValueError, expected_error_regexp):
dump.node_attributes("foo")
def testGraphStructureLookupGivesDebugWatchKeys(self):
@ -721,7 +735,6 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(0, u_data[0].output_slot)
self.assertEqual("DebugIdentity", u_data[0].debug_op)
self.assertGreaterEqual(u_data[0].timestamp, 0)
self.assertEqual([], dump.watch_key_to_data("foo"))
def testGraphStructureLookupGivesNodeInputsAndRecipients(self):
@ -752,12 +765,13 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual([], dump.node_recipients(w_name, is_control=True))
# Test errors raised on invalid node names.
with self.assertRaisesRegexp(ValueError,
"does not exist in partition graphs"):
if test_util.gpu_device_name():
expected_error_regexp = r"None of the .* devices has a node named "
else:
expected_error_regexp = "does not exist in the partition graph of device "
with self.assertRaisesRegexp(ValueError, expected_error_regexp):
dump.node_inputs(u_name + "foo")
with self.assertRaisesRegexp(ValueError,
"does not exist in partition graphs"):
with self.assertRaisesRegexp(ValueError, expected_error_regexp):
dump.node_recipients(u_name + "foo")
# Test transitive_inputs().
@ -768,8 +782,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(
set([u_name, u_read_name, v_name]), set(dump.transitive_inputs(w_name)))
with self.assertRaisesRegexp(ValueError,
"does not exist in partition graphs"):
with self.assertRaisesRegexp(ValueError, expected_error_regexp):
dump.transitive_inputs(u_name + "foo")
def testGraphStructureLookupWithoutPartitionGraphsDoesNotErrorOut(self):
@ -1066,10 +1079,12 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
y = array_ops.squeeze(ph, name="mismatch/y")
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
debug_utils.watch_graph(
run_options, sess.graph, debug_urls=self._debug_urls(), global_step=1)
sess.run(x, feed_dict={ph: np.array([[7.0, 8.0]])}, options=run_options)
sess.run(x, feed_dict={ph: np.array([[7.0, 8.0]])}, options=run_options,
run_metadata=run_metadata)
dump1 = debug_data.DebugDumpDir(self._dump_root)
self.assertEqual(1, dump1.core_metadata.global_step)
self.assertGreaterEqual(dump1.core_metadata.session_run_index, 0)