Split up session_test.py -> session_list_devices_test.py

session_test.py has gotten very large. Additionally, recently it has become
flaky. In order to both (1) improve overall code health, and (2) to facilitate
root-causing the test flakiness, this CL begins to split apart session_test
into focused subsets.

I've suffixed the scoping of the session_test in order to preserve filesystem
sort-order grouping.

PiperOrigin-RevId: 157640788
This commit is contained in:
Brennan Saeta 2017-05-31 15:05:07 -07:00 committed by TensorFlower Gardener
parent 8e868cf6a1
commit d310de4fac
3 changed files with 94 additions and 38 deletions

View File

@ -2945,7 +2945,7 @@ tf_cuda_library(
# Disabled due to http://b/62145493 # Disabled due to http://b/62145493
# py_test( # py_test(
# name = "session_test", # name = "session_test",
# size = "medium", # http://62144199 # size = "medium", # http://b/62144199
# srcs = ["client/session_test.py"], # srcs = ["client/session_test.py"],
# srcs_version = "PY2AND3", # srcs_version = "PY2AND3",
# tags = [ # tags = [
@ -2975,6 +2975,25 @@ tf_cuda_library(
# ], # ],
# ) # )
py_test(
name = "session_list_devices_test",
size = "small",
srcs = ["client/session_list_devices_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_gpu",
],
deps = [
":client",
":framework",
":framework_test_lib",
":platform_test",
":training",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
],
)
cuda_py_test( cuda_py_test(
name = "timeline_test", name = "timeline_test",
size = "small", size = "small",

View File

@ -0,0 +1,74 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.session.Session's list_devices API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
ops._USE_C_API = True
class SessionListDevicesTest(test_util.TensorFlowTestCase):
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevices(self):
with session.Session() as sess:
devices = sess.list_devices()
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
self.assertGreaterEqual(1, len(devices), devices)
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess:
devices = sess.list_devices()
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
self.assertGreaterEqual(1, len(devices), devices)
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevicesClusterSpecPropagation(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with session.Session(server1.target, config=config) as sess:
devices = sess.list_devices()
device_names = set([d.name for d in devices])
self.assertTrue(
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
self.assertTrue(
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
self.assertGreaterEqual(2, len(devices), devices)
if __name__ == '__main__':
googletest.main()

View File

@ -1970,43 +1970,6 @@ class SessionTest(test_util.TensorFlowTestCase):
str_repr = '%s' % attrs str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevices(self):
with session.Session() as sess:
devices = sess.list_devices()
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
self.assertGreaterEqual(1, len(devices), devices)
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess:
devices = sess.list_devices()
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
self.assertGreaterEqual(1, len(devices), devices)
@test_util.disable_c_api # list_devices doesn't work with C API
def testListDevicesClusterSpecPropagation(self):
server1 = server_lib.Server.create_local_server()
server2 = server_lib.Server.create_local_server()
cluster_def = cluster_pb2.ClusterDef()
job = cluster_def.job.add()
job.name = 'worker'
job.tasks[0] = server1.target[len('grpc://'):]
job.tasks[1] = server2.target[len('grpc://'):]
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with session.Session(server1.target, config=config) as sess:
devices = sess.list_devices()
device_names = set([d.name for d in devices])
self.assertTrue(
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
self.assertTrue(
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
self.assertGreaterEqual(2, len(devices), devices)
class PartialRunTest(test_util.TensorFlowTestCase): class PartialRunTest(test_util.TensorFlowTestCase):