mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
lib.model.layers.KResizeImages - Add arbritary resize (tf)
This commit is contained in:
parent
514f3d60a7
commit
3c9a0f9e53
|
|
@ -196,7 +196,7 @@ class KResizeImages(Layer):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
size: int, optional
|
||||
size: int or float, optional
|
||||
The scale to upsample to. Default: `2`
|
||||
interpolation: ["nearest", "bilinear"], optional
|
||||
The interpolation to use. Default: `"nearest"`
|
||||
|
|
@ -223,11 +223,20 @@ class KResizeImages(Layer):
|
|||
tensor
|
||||
A tensor or list/tuple of tensors
|
||||
"""
|
||||
return K.resize_images(inputs,
|
||||
self.size,
|
||||
self.size,
|
||||
"channels_last",
|
||||
interpolation=self.interpolation)
|
||||
if isinstance(self.size, int):
|
||||
retval = K.resize_images(inputs,
|
||||
self.size,
|
||||
self.size,
|
||||
"channels_last",
|
||||
interpolation=self.interpolation)
|
||||
else:
|
||||
# Arbitrary resizing
|
||||
size = int(round(K.int_shape(inputs)[1] * self.size))
|
||||
if get_backend() != "amd":
|
||||
retval = tf.image.resize(inputs, (size, size), method=self.interpolation)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return retval
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
"""Computes the output shape of the layer.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user