Make graph_callable compatible with functions that do not return anything

PiperOrigin-RevId: 171067061
This commit is contained in:
Akshay Agrawal 2017-10-04 14:52:13 -07:00 committed by TensorFlower Gardener
parent 39565c0cbc
commit 4486b4f69b
2 changed files with 29 additions and 1 deletions

View File

@ -324,7 +324,9 @@ def _graph_callable_internal(func, shape_and_dtypes):
captures):
func_outputs = func(*func_inputs)
outputs_list = nest.flatten(func_outputs)
output_shapes = [x.shape for x in outputs_list if x is not None]
if len(outputs_list) == 1 and outputs_list[0] is None:
outputs_list = []
output_shapes = [x.shape for x in outputs_list]
if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
raise ValueError("Found non-tensor output in %s" % str(outputs_list))
initializing_operations = tmp_graph.get_operations()
@ -420,6 +422,9 @@ def graph_callable(shape_and_dtypes):
Note that the wrapped function is not allowed to change the values of the
variables, just use them.
The return value of the wrapped function must be one of the following:
(1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
Example:
```python

View File

@ -45,6 +45,29 @@ class GraphCallableTest(test.TestCase):
self.assertEqual(
3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
def testFunctionWithoutReturnValue(self):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def my_function(x):
v = variable_scope.get_variable(
"v", initializer=init_ops.zeros_initializer(), shape=())
v.assign(x)
my_function(constant_op.constant(4, dtype=dtypes.float32))
self.assertEqual(4, my_function.variables[0].read_value().numpy())
def testFunctionWithoutReturnValueAndArgs(self):
@graph_callable.graph_callable([])
def my_function():
v = variable_scope.get_variable(
"v", initializer=init_ops.zeros_initializer(), shape=())
v.assign(4)
my_function()
self.assertEqual(4, my_function.variables[0].read_value().numpy())
def testVariableAPI(self):
@graph_callable.graph_callable(