mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Split up session_test.py -> session_list_devices_test.py
session_test.py has gotten very large. Additionally, recently it has become flaky. In order to both (1) improve overall code health, and (2) to facilitate root-causing the test flakiness, this CL begins to split apart session_test into focused subsets. I've suffixed the scoping of the session_test in order to preserve filesystem sort-order grouping. PiperOrigin-RevId: 157640788
This commit is contained in:
parent
8e868cf6a1
commit
d310de4fac
|
|
@ -2945,7 +2945,7 @@ tf_cuda_library(
|
||||||
# Disabled due to http://b/62145493
|
# Disabled due to http://b/62145493
|
||||||
# py_test(
|
# py_test(
|
||||||
# name = "session_test",
|
# name = "session_test",
|
||||||
# size = "medium", # http://62144199
|
# size = "medium", # http://b/62144199
|
||||||
# srcs = ["client/session_test.py"],
|
# srcs = ["client/session_test.py"],
|
||||||
# srcs_version = "PY2AND3",
|
# srcs_version = "PY2AND3",
|
||||||
# tags = [
|
# tags = [
|
||||||
|
|
@ -2975,6 +2975,25 @@ tf_cuda_library(
|
||||||
# ],
|
# ],
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "session_list_devices_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["client/session_list_devices_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = [
|
||||||
|
"no_gpu",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":client",
|
||||||
|
":framework",
|
||||||
|
":framework_test_lib",
|
||||||
|
":platform_test",
|
||||||
|
":training",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "timeline_test",
|
name = "timeline_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
|
|
||||||
74
tensorflow/python/client/session_list_devices_test.py
Normal file
74
tensorflow/python/client/session_list_devices_test.py
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Tests for tensorflow.python.client.session.Session's list_devices API."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import cluster_pb2
|
||||||
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
from tensorflow.python.training import server_lib
|
||||||
|
|
||||||
|
ops._USE_C_API = True
|
||||||
|
|
||||||
|
|
||||||
|
class SessionListDevicesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.disable_c_api # list_devices doesn't work with C API
|
||||||
|
def testListDevices(self):
|
||||||
|
with session.Session() as sess:
|
||||||
|
devices = sess.list_devices()
|
||||||
|
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
|
||||||
|
[d.name for d in devices]), devices)
|
||||||
|
self.assertGreaterEqual(1, len(devices), devices)
|
||||||
|
|
||||||
|
@test_util.disable_c_api # list_devices doesn't work with C API
|
||||||
|
def testListDevicesGrpcSession(self):
|
||||||
|
server = server_lib.Server.create_local_server()
|
||||||
|
with session.Session(server.target) as sess:
|
||||||
|
devices = sess.list_devices()
|
||||||
|
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
|
||||||
|
[d.name for d in devices]), devices)
|
||||||
|
self.assertGreaterEqual(1, len(devices), devices)
|
||||||
|
|
||||||
|
@test_util.disable_c_api # list_devices doesn't work with C API
|
||||||
|
def testListDevicesClusterSpecPropagation(self):
|
||||||
|
server1 = server_lib.Server.create_local_server()
|
||||||
|
server2 = server_lib.Server.create_local_server()
|
||||||
|
|
||||||
|
cluster_def = cluster_pb2.ClusterDef()
|
||||||
|
job = cluster_def.job.add()
|
||||||
|
job.name = 'worker'
|
||||||
|
job.tasks[0] = server1.target[len('grpc://'):]
|
||||||
|
job.tasks[1] = server2.target[len('grpc://'):]
|
||||||
|
config = config_pb2.ConfigProto(cluster_def=cluster_def)
|
||||||
|
with session.Session(server1.target, config=config) as sess:
|
||||||
|
devices = sess.list_devices()
|
||||||
|
device_names = set([d.name for d in devices])
|
||||||
|
self.assertTrue(
|
||||||
|
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
|
||||||
|
self.assertTrue(
|
||||||
|
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
|
||||||
|
self.assertGreaterEqual(2, len(devices), devices)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
googletest.main()
|
||||||
|
|
@ -1970,43 +1970,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||||
str_repr = '%s' % attrs
|
str_repr = '%s' % attrs
|
||||||
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
|
||||||
|
|
||||||
@test_util.disable_c_api # list_devices doesn't work with C API
|
|
||||||
def testListDevices(self):
|
|
||||||
with session.Session() as sess:
|
|
||||||
devices = sess.list_devices()
|
|
||||||
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
|
|
||||||
[d.name for d in devices]), devices)
|
|
||||||
self.assertGreaterEqual(1, len(devices), devices)
|
|
||||||
|
|
||||||
@test_util.disable_c_api # list_devices doesn't work with C API
|
|
||||||
def testListDevicesGrpcSession(self):
|
|
||||||
server = server_lib.Server.create_local_server()
|
|
||||||
with session.Session(server.target) as sess:
|
|
||||||
devices = sess.list_devices()
|
|
||||||
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
|
|
||||||
[d.name for d in devices]), devices)
|
|
||||||
self.assertGreaterEqual(1, len(devices), devices)
|
|
||||||
|
|
||||||
@test_util.disable_c_api # list_devices doesn't work with C API
|
|
||||||
def testListDevicesClusterSpecPropagation(self):
|
|
||||||
server1 = server_lib.Server.create_local_server()
|
|
||||||
server2 = server_lib.Server.create_local_server()
|
|
||||||
|
|
||||||
cluster_def = cluster_pb2.ClusterDef()
|
|
||||||
job = cluster_def.job.add()
|
|
||||||
job.name = 'worker'
|
|
||||||
job.tasks[0] = server1.target[len('grpc://'):]
|
|
||||||
job.tasks[1] = server2.target[len('grpc://'):]
|
|
||||||
config = config_pb2.ConfigProto(cluster_def=cluster_def)
|
|
||||||
with session.Session(server1.target, config=config) as sess:
|
|
||||||
devices = sess.list_devices()
|
|
||||||
device_names = set([d.name for d in devices])
|
|
||||||
self.assertTrue(
|
|
||||||
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
|
|
||||||
self.assertTrue(
|
|
||||||
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
|
|
||||||
self.assertGreaterEqual(2, len(devices), devices)
|
|
||||||
|
|
||||||
|
|
||||||
class PartialRunTest(test_util.TensorFlowTestCase):
|
class PartialRunTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user