mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Bugfix: number of input channels is not necessarily in the last dimension, after introduction of data_format param.
PiperOrigin-RevId: 164889729
This commit is contained in:
parent
8f9b1af8ae
commit
58c4a4cb1b
|
|
@ -2548,7 +2548,8 @@ def separable_convolution2d(
|
||||||
dtype = inputs.dtype.base_dtype
|
dtype = inputs.dtype.base_dtype
|
||||||
kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
|
kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
|
||||||
stride_h, stride_w = utils.two_element_tuple(stride)
|
stride_h, stride_w = utils.two_element_tuple(stride)
|
||||||
num_filters_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
|
num_filters_in = utils.channel_dimension(
|
||||||
|
inputs.get_shape(), df, min_rank=4)
|
||||||
weights_collections = utils.get_variable_collections(
|
weights_collections = utils.get_variable_collections(
|
||||||
variables_collections, 'weights')
|
variables_collections, 'weights')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3230,12 +3230,13 @@ class SeparableConv2dTest(test.TestCase):
|
||||||
def testConvNCHW(self):
|
def testConvNCHW(self):
|
||||||
for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
|
for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
height, width = 3, 3
|
batch, height, width = 4, 5, 6
|
||||||
images = random_ops.random_uniform((5, 3, height, width), seed=1)
|
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
|
||||||
output = layers_lib.separable_conv2d(
|
output = layers_lib.separable_conv2d(
|
||||||
images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
|
images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
output.get_shape().as_list(), [5, correct_output_filters, 1, 1])
|
output.get_shape().as_list(), [batch, correct_output_filters,
|
||||||
|
height - 2, width - 2])
|
||||||
|
|
||||||
|
|
||||||
class ScaleGradientTests(test.TestCase):
|
class ScaleGradientTests(test.TestCase):
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,8 @@ __all__ = ['collect_named_outputs',
|
||||||
'get_variable_collections',
|
'get_variable_collections',
|
||||||
'two_element_tuple',
|
'two_element_tuple',
|
||||||
'n_positive_integers',
|
'n_positive_integers',
|
||||||
'last_dimension',
|
'channel_dimension',
|
||||||
'first_dimension']
|
'last_dimension']
|
||||||
|
|
||||||
NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])
|
NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])
|
||||||
|
|
||||||
|
|
@ -220,15 +220,16 @@ def get_variable_collections(variables_collections, name):
|
||||||
return variable_collections
|
return variable_collections
|
||||||
|
|
||||||
|
|
||||||
def first_dimension(shape, min_rank=1):
|
def _get_dimension(shape, dim, min_rank=1):
|
||||||
"""Returns the first dimension of shape while checking it has min_rank.
|
"""Returns the `dim` dimension of `shape`, while checking it has `min_rank`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape: A `TensorShape`.
|
shape: A `TensorShape`.
|
||||||
|
dim: Integer, which dimension to return.
|
||||||
min_rank: Integer, minimum rank of shape.
|
min_rank: Integer, minimum rank of shape.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The value of the first dimension.
|
The value of the `dim` dimension.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||||
|
|
@ -240,12 +241,32 @@ def first_dimension(shape, min_rank=1):
|
||||||
if len(dims) < min_rank:
|
if len(dims) < min_rank:
|
||||||
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
|
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
|
||||||
len(dims)))
|
len(dims)))
|
||||||
value = dims[0].value
|
value = dims[dim].value
|
||||||
if value is None:
|
if value is None:
|
||||||
raise ValueError('first dimension shape must be known but is None')
|
raise ValueError(
|
||||||
|
'dimension %d of shape must be known but is None: %s' % (dim, shape))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def channel_dimension(shape, data_format, min_rank=1):
|
||||||
|
"""Returns the channel dimension of shape, while checking it has min_rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: A `TensorShape`.
|
||||||
|
data_format: `channels_first` or `channels_last`.
|
||||||
|
min_rank: Integer, minimum rank of shape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the first dimension.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||||
|
first dimension value is not defined.
|
||||||
|
"""
|
||||||
|
return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
|
||||||
|
min_rank=min_rank)
|
||||||
|
|
||||||
|
|
||||||
def last_dimension(shape, min_rank=1):
|
def last_dimension(shape, min_rank=1):
|
||||||
"""Returns the last dimension of shape while checking it has min_rank.
|
"""Returns the last dimension of shape while checking it has min_rank.
|
||||||
|
|
||||||
|
|
@ -260,16 +281,7 @@ def last_dimension(shape, min_rank=1):
|
||||||
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||||
last dimension value is not defined.
|
last dimension value is not defined.
|
||||||
"""
|
"""
|
||||||
dims = shape.dims
|
return _get_dimension(shape, -1, min_rank=min_rank)
|
||||||
if dims is None:
|
|
||||||
raise ValueError('dims of shape must be known but is None')
|
|
||||||
if len(dims) < min_rank:
|
|
||||||
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
|
|
||||||
len(dims)))
|
|
||||||
value = dims[-1].value
|
|
||||||
if value is None:
|
|
||||||
raise ValueError('last dimension shape must be known but is None')
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def two_element_tuple(int_or_tuple):
|
def two_element_tuple(int_or_tuple):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user