mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add fp16 support to convolutional layers that support it.
PiperOrigin-RevId: 158086284
This commit is contained in:
parent
7d3fbba48f
commit
cc1a02d37a
|
|
@ -371,6 +371,7 @@ def conv1d(inputs,
|
|||
activity_regularizer=activity_regularizer,
|
||||
trainable=trainable,
|
||||
name=name,
|
||||
dtype=inputs.dtype.base_dtype,
|
||||
_reuse=reuse,
|
||||
_scope=name)
|
||||
return layer.apply(inputs)
|
||||
|
|
@ -546,6 +547,7 @@ def conv2d(inputs,
|
|||
activity_regularizer=activity_regularizer,
|
||||
trainable=trainable,
|
||||
name=name,
|
||||
dtype=inputs.dtype.base_dtype,
|
||||
_reuse=reuse,
|
||||
_scope=name)
|
||||
return layer.apply(inputs)
|
||||
|
|
@ -1277,6 +1279,7 @@ def conv2d_transpose(inputs,
|
|||
activity_regularizer=activity_regularizer,
|
||||
trainable=trainable,
|
||||
name=name,
|
||||
dtype=inputs.dtype.base_dtype,
|
||||
_reuse=reuse,
|
||||
_scope=name)
|
||||
return layer.apply(inputs)
|
||||
|
|
|
|||
|
|
@ -69,6 +69,13 @@ class ConvTest(test.TestCase):
|
|||
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 4, 32])
|
||||
self.assertListEqual(layer.bias.get_shape().as_list(), [32])
|
||||
|
||||
def testConv2DFloat16(self):
|
||||
height, width = 7, 9
|
||||
images = random_ops.random_uniform((5, height, width, 4), dtype='float16')
|
||||
output = conv_layers.conv2d(images, 32, [3, 3], activation=nn_ops.relu)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[5, height - 2, width - 2, 32])
|
||||
|
||||
def testCreateConv2DIntegerKernelSize(self):
|
||||
height, width = 7, 9
|
||||
images = random_ops.random_uniform((5, height, width, 4))
|
||||
|
|
@ -144,6 +151,12 @@ class ConvTest(test.TestCase):
|
|||
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 4, 32])
|
||||
self.assertListEqual(layer.bias.get_shape().as_list(), [32])
|
||||
|
||||
def testConv1DFloat16(self):
|
||||
width = 7
|
||||
data = random_ops.random_uniform((5, width, 4), dtype='float16')
|
||||
output = conv_layers.conv1d(data, 32, 3, activation=nn_ops.relu)
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, width - 2, 32])
|
||||
|
||||
def testCreateConv1DChannelsFirst(self):
|
||||
width = 7
|
||||
data = random_ops.random_uniform((5, 4, width))
|
||||
|
|
@ -522,6 +535,14 @@ class Conv2DTransposeTest(test.TestCase):
|
|||
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 32, 4])
|
||||
self.assertListEqual(layer.bias.get_shape().as_list(), [32])
|
||||
|
||||
def testConv2DTransposeFloat16(self):
|
||||
height, width = 7, 9
|
||||
images = random_ops.random_uniform((5, height, width, 4), dtype='float16')
|
||||
output = conv_layers.conv2d_transpose(images, 32, [3, 3],
|
||||
activation=nn_ops.relu)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[5, height + 2, width + 2, 32])
|
||||
|
||||
def testCreateConv2DTransposeIntegerKernelSize(self):
|
||||
height, width = 7, 9
|
||||
images = random_ops.random_uniform((5, height, width, 4))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user