mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Make graph_callable compatible with functions that do not return anything
PiperOrigin-RevId: 171067061
This commit is contained in:
parent
39565c0cbc
commit
4486b4f69b
|
|
@ -324,7 +324,9 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
|||
captures):
|
||||
func_outputs = func(*func_inputs)
|
||||
outputs_list = nest.flatten(func_outputs)
|
||||
output_shapes = [x.shape for x in outputs_list if x is not None]
|
||||
if len(outputs_list) == 1 and outputs_list[0] is None:
|
||||
outputs_list = []
|
||||
output_shapes = [x.shape for x in outputs_list]
|
||||
if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
|
||||
raise ValueError("Found non-tensor output in %s" % str(outputs_list))
|
||||
initializing_operations = tmp_graph.get_operations()
|
||||
|
|
@ -420,6 +422,9 @@ def graph_callable(shape_and_dtypes):
|
|||
Note that the wrapped function is not allowed to change the values of the
|
||||
variables, just use them.
|
||||
|
||||
The return value of the wrapped function must be one of the following:
|
||||
(1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -45,6 +45,29 @@ class GraphCallableTest(test.TestCase):
|
|||
self.assertEqual(
|
||||
3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
|
||||
|
||||
def testFunctionWithoutReturnValue(self):
|
||||
|
||||
@graph_callable.graph_callable(
|
||||
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||
def my_function(x):
|
||||
v = variable_scope.get_variable(
|
||||
"v", initializer=init_ops.zeros_initializer(), shape=())
|
||||
v.assign(x)
|
||||
|
||||
my_function(constant_op.constant(4, dtype=dtypes.float32))
|
||||
self.assertEqual(4, my_function.variables[0].read_value().numpy())
|
||||
|
||||
def testFunctionWithoutReturnValueAndArgs(self):
|
||||
|
||||
@graph_callable.graph_callable([])
|
||||
def my_function():
|
||||
v = variable_scope.get_variable(
|
||||
"v", initializer=init_ops.zeros_initializer(), shape=())
|
||||
v.assign(4)
|
||||
|
||||
my_function()
|
||||
self.assertEqual(4, my_function.variables[0].read_value().numpy())
|
||||
|
||||
def testVariableAPI(self):
|
||||
|
||||
@graph_callable.graph_callable(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user