mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Make graph_callable compatible with functions that do not return anything
PiperOrigin-RevId: 171067061
This commit is contained in:
parent
39565c0cbc
commit
4486b4f69b
|
|
@ -324,7 +324,9 @@ def _graph_callable_internal(func, shape_and_dtypes):
|
||||||
captures):
|
captures):
|
||||||
func_outputs = func(*func_inputs)
|
func_outputs = func(*func_inputs)
|
||||||
outputs_list = nest.flatten(func_outputs)
|
outputs_list = nest.flatten(func_outputs)
|
||||||
output_shapes = [x.shape for x in outputs_list if x is not None]
|
if len(outputs_list) == 1 and outputs_list[0] is None:
|
||||||
|
outputs_list = []
|
||||||
|
output_shapes = [x.shape for x in outputs_list]
|
||||||
if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
|
if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
|
||||||
raise ValueError("Found non-tensor output in %s" % str(outputs_list))
|
raise ValueError("Found non-tensor output in %s" % str(outputs_list))
|
||||||
initializing_operations = tmp_graph.get_operations()
|
initializing_operations = tmp_graph.get_operations()
|
||||||
|
|
@ -420,6 +422,9 @@ def graph_callable(shape_and_dtypes):
|
||||||
Note that the wrapped function is not allowed to change the values of the
|
Note that the wrapped function is not allowed to change the values of the
|
||||||
variables, just use them.
|
variables, just use them.
|
||||||
|
|
||||||
|
The return value of the wrapped function must be one of the following:
|
||||||
|
(1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,29 @@ class GraphCallableTest(test.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
|
3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
def testFunctionWithoutReturnValue(self):
|
||||||
|
|
||||||
|
@graph_callable.graph_callable(
|
||||||
|
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
|
||||||
|
def my_function(x):
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
"v", initializer=init_ops.zeros_initializer(), shape=())
|
||||||
|
v.assign(x)
|
||||||
|
|
||||||
|
my_function(constant_op.constant(4, dtype=dtypes.float32))
|
||||||
|
self.assertEqual(4, my_function.variables[0].read_value().numpy())
|
||||||
|
|
||||||
|
def testFunctionWithoutReturnValueAndArgs(self):
|
||||||
|
|
||||||
|
@graph_callable.graph_callable([])
|
||||||
|
def my_function():
|
||||||
|
v = variable_scope.get_variable(
|
||||||
|
"v", initializer=init_ops.zeros_initializer(), shape=())
|
||||||
|
v.assign(4)
|
||||||
|
|
||||||
|
my_function()
|
||||||
|
self.assertEqual(4, my_function.variables[0].read_value().numpy())
|
||||||
|
|
||||||
def testVariableAPI(self):
|
def testVariableAPI(self):
|
||||||
|
|
||||||
@graph_callable.graph_callable(
|
@graph_callable.graph_callable(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user