Speed up atrous_convolution_test by combining evaluations.

To make this test run faster (and prevent it from timing out under
certain circumstances), this change combines all tensor evaluations in
each test method into a single Session.run call, eliminating the
per-call overhead.

This reduces the test time from about 40 seconds to 10 seconds.

RELNOTES: n/a
PiperOrigin-RevId: 158175227
A. Unique TensorFlower 2017-06-06 12:36:14 -07:00 committed by TensorFlower Gardener
parent 81cf61fdb6
commit b65eb3f9b5
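
As a rough, standalone illustration of the idea behind this change (not part of the commit; the session and tensors below are made up), fetching many tensors in one Session.run call avoids paying the fixed dispatch cost once per tensor:

import tensorflow as tf

with tf.Session() as sess:
  tensors = [tf.constant(float(i)) * 2.0 for i in range(100)]

  # Before: one round trip through Session.run per tensor.
  separate = [sess.run(t) for t in tensors]

  # After: a single Session.run call evaluates every tensor at once.
  combined = sess.run(tensors)

  assert separate == list(combined)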


@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import numpy as np
from tensorflow.python.framework import constant_op
@@ -53,124 +55,170 @@ def upsample_filters(filters, rate):
return output
@contextlib.contextmanager
def delay_checks(sess):
"""Context manager for combining checks depending on tensor evaluations.
Each call to Session.run has some overhead, and this overhead can easily
account for the majority of the time spent in tests that call Session.run (or
Tensor.eval) many times.
This context manager provides a mechanism for registering callback functions
and associated tensors. When the context is exited, all of the tensors
associated with all of the registrations are evaluated with a single call to
Session.run, and then each registered callback function is called with the
values of its associated tensors.
Args:
sess: The session to use to evaluate the tensors.
Yields:
A function `add_check(check, *args, **kwargs)` where `check` is the
callback function to be invoked, and `*args` and `**kwargs` specify the
associated Tensors.
"""
checks = []
def add_check(check, *args, **kwargs):
checks.append((check, args, kwargs))
yield add_check
all_values = sess.run([[args, kwargs] for _, args, kwargs in checks])
for (check, _, _), (args, kwargs) in zip(checks, all_values):
check(*args, **kwargs)
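For illustration, here is a minimal usage sketch of delay_checks (hypothetical; the tensors and test class below are not part of this change). Two comparisons are registered via add_check, and both are evaluated by a single Session.run call when the context exits:

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class DelayChecksExampleTest(test.TestCase):

  def testCombinedEvaluation(self):
    with self.test_session() as sess:
      with delay_checks(sess) as add_check:
        a = constant_op.constant([1.0, 2.0])
        b = math_ops.add(a, a)

        def check(a_val, b_val):
          self.assertAllClose(2.0 * a_val, b_val)

        # Nothing is evaluated yet; the checks and their tensors are only
        # recorded.
        add_check(check, a, b)
        add_check(check, b, math_ops.add(b, b))
      # On exit, all registered tensors were fetched in one sess.run call and
      # each check ran on the resulting numpy values.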
class AtrousConvolutionTest(test.TestCase):
def _test_atrous_convolution(self, input_shape, filter_shape, dilation_rate,
**kwargs):
filters = np.arange(
np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
def _test_atrous_convolution(self, add_check, input_shape, filter_shape,
dilation_rate, **kwargs):
filters = np.arange(np.prod(filter_shape),
dtype=np.float32).reshape(filter_shape)
filters_upsampled = upsample_filters(filters, dilation_rate)
x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
y1 = nn_ops.convolution(
input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
y1 = nn_ops.convolution(input=x, filter=filters,
dilation_rate=dilation_rate, **kwargs)
y2 = nn_ops.convolution(input=x, filter=filters_upsampled, **kwargs)
self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)
def check(y1_eval, y2_eval):
self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2)
add_check(check, y1, y2)
def testAtrousConvolution2D(self):
with self.test_session():
for padding in ["SAME", "VALID"]:
for height, width in [[9, 9], [9, 10]]:
for kernel_height, kernel_width in [[1, 1], [2, 2], [2, 3]]:
for dilation_rate in [[1, 1], [3, 2], [2, 1]]:
self._test_atrous_convolution(
input_shape=[2, height, width, 2],
filter_shape=[kernel_height, kernel_width, 2, 2],
padding=padding,
dilation_rate=dilation_rate)
with self.test_session() as sess:
with delay_checks(sess) as add_check:
for padding in ["SAME", "VALID"]:
for height, width in [[9, 9], [9, 10]]:
for kernel_height, kernel_width in [[1, 1], [2, 2], [2, 3]]:
for dilation_rate in [[1, 1], [3, 2], [2, 1]]:
self._test_atrous_convolution(
add_check=add_check,
input_shape=[2, height, width, 2],
filter_shape=[kernel_height, kernel_width, 2, 2],
padding=padding,
dilation_rate=dilation_rate,
)
def testAtrousConvolution3D(self):
with self.test_session():
for padding in ["SAME", "VALID"]:
for depth, height, width in [[9, 9, 10], [9, 10, 9]]:
for kernel_depth, kernel_height, kernel_width in [[3, 3, 3],
[3, 2, 2],
[2, 1, 3]]:
for dilation_rate in [[1, 1, 1], [3, 3, 3], [3, 2, 3], [3, 1, 2]]:
self._test_atrous_convolution(
input_shape=[2, depth, height, width, 2],
filter_shape=[
kernel_depth, kernel_height, kernel_width, 2, 2
],
padding=padding,
dilation_rate=dilation_rate)
with self.test_session() as sess:
with delay_checks(sess) as add_check:
for padding in ["SAME", "VALID"]:
for depth, height, width in [[9, 9, 10], [9, 10, 9]]:
for kernel_depth, kernel_height, kernel_width in [[3, 3, 3],
[3, 2, 2],
[2, 1, 3]]:
for dilation_rate in [[1, 1, 1], [3, 3, 3], [3, 2, 3], [3, 1, 2]]:
self._test_atrous_convolution(
add_check=add_check,
input_shape=[2, depth, height, width, 2],
filter_shape=[
kernel_depth, kernel_height, kernel_width, 2, 2
],
padding=padding,
dilation_rate=dilation_rate,
)
def testAtrousConvolution1D(self):
with self.test_session():
for padding in ["SAME", "VALID"]:
for width in [9, 10]:
for kernel_width in range(1, 4):
for rate in range(1, 4):
self._test_atrous_convolution(
input_shape=[2, width, 2],
filter_shape=[kernel_width, 2, 2],
padding=padding,
dilation_rate=[rate])
with self.test_session() as sess:
with delay_checks(sess) as add_check:
for padding in ["SAME", "VALID"]:
for width in [9, 10]:
for kernel_width in range(1, 4):
for rate in range(1, 4):
self._test_atrous_convolution(
add_check=add_check,
input_shape=[2, width, 2],
filter_shape=[kernel_width, 2, 2],
padding=padding,
dilation_rate=[rate],
)
def testAtrousConvolutionNC(self):
if test.is_gpu_available(cuda_only=True):
# "NCW" and "NCHW" formats are currently supported only on CUDA.
with self.test_session(use_gpu=True):
for padding in ["SAME", "VALID"]:
self._test_atrous_convolution(
input_shape=[2, 2, 9],
padding=padding,
filter_shape=[3, 2, 2],
dilation_rate=[2],
data_format="NCW")
self._test_atrous_convolution(
input_shape=[2, 2, 9, 5],
padding=padding,
filter_shape=[3, 3, 2, 2],
dilation_rate=[2, 1],
data_format="NCHW")
with self.test_session(use_gpu=True) as sess:
with delay_checks(sess) as add_check:
for padding in ["SAME", "VALID"]:
self._test_atrous_convolution(
add_check=add_check,
input_shape=[2, 2, 9],
padding=padding,
filter_shape=[3, 2, 2],
dilation_rate=[2],
data_format="NCW",
)
self._test_atrous_convolution(
add_check=add_check,
input_shape=[2, 2, 9, 5],
padding=padding,
filter_shape=[3, 3, 2, 2],
dilation_rate=[2, 1],
data_format="NCHW",
)
def testAtrousSequence(self):
"""Tests optimization of sequence of atrous convolutions.
See the documentation of with_space_to_batch.
"""
with self.test_session():
for padding in ["SAME", "VALID"]:
for height in range(15, 17):
for width in range(15, 17):
x_shape = [3, height, width, 2]
x = np.random.random_sample(x_shape).astype(np.float32)
with self.test_session() as sess:
with delay_checks(sess) as add_check:
for padding in ["SAME", "VALID"]:
for height in range(15, 17):
for width in range(15, 17):
x_shape = [3, height, width, 2]
x = np.random.random_sample(x_shape).astype(np.float32)
kernel_sizes = [1, 3] if padding == "SAME" else range(1, 3)
for kernel in kernel_sizes:
f_shape = [kernel, kernel, 2, 2]
f1 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
f2 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
kernel_sizes = [1, 3] if padding == "SAME" else range(1, 3)
for kernel in kernel_sizes:
f_shape = [kernel, kernel, 2, 2]
f1 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
f2 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
def combined_op(converted_input, num_spatial_dims, padding_arg): # pylint: disable=unused-argument
result = nn_ops.convolution(
input=converted_input, filter=f1,
padding=padding) # pylint: disable=cell-var-from-loop
result = nn_ops.convolution(
input=result, filter=f2,
padding=padding) # pylint: disable=cell-var-from-loop
return result
def combined_op(converted_input, num_spatial_dims, padding_arg): # pylint: disable=unused-argument
# pylint: disable=cell-var-from-loop
result = nn_ops.convolution(input=converted_input, filter=f1,
padding=padding)
result = nn_ops.convolution(input=result, filter=f2,
padding=padding)
# pylint: enable=cell-var-from-loop
return result
for rate_height in range(2, 4):
for rate_width in range(2, 4):
dilation_rate = [rate_height, rate_width]
y1 = nn_ops.convolution(
input=x,
filter=f1,
padding=padding,
dilation_rate=dilation_rate)
y1 = nn_ops.convolution(
input=y1,
filter=f2,
padding=padding,
dilation_rate=dilation_rate)
y2 = nn_ops.with_space_to_batch(
input=x,
dilation_rate=dilation_rate,
op=combined_op,
padding="VALID")
self.assertAllClose(
y1.eval(), y2.eval(), rtol=1e-2, atol=1e-2)
for rate_height in range(2, 4):
for rate_width in range(2, 4):
dilation_rate = [rate_height, rate_width]
y1 = nn_ops.convolution(input=x, filter=f1, padding=padding,
dilation_rate=dilation_rate)
y1 = nn_ops.convolution(input=y1, filter=f2,
padding=padding,
dilation_rate=dilation_rate)
y2 = nn_ops.with_space_to_batch(
input=x, dilation_rate=dilation_rate, op=combined_op,
padding="VALID")
def check(y1_eval, y2_eval):
self.assertAllClose(y1_eval, y2_eval, rtol=1e-2,
atol=1e-2)
add_check(check, y1, y2)
def _test_gradient(self, x_shape, f_shape, dilation_rate, padding):
x_val = np.random.random_sample(x_shape).astype(np.float32)