Add check to tf.device when called with a function in eager mode.

PiperOrigin-RevId: 173947845
This commit is contained in:
A. Unique TensorFlower 2017-10-30 14:16:18 -07:00 committed by TensorFlower Gardener
parent 3639aa7ff1
commit 1d6dae88ef

View File

@ -4339,11 +4339,18 @@ def device(device_name_or_function):
Returns:
A context manager that specifies the default device to use for newly
created ops.
Raises:
RuntimeError: If eager execution is enabled and a function is passed in.
"""
if context.in_graph_mode():
return get_default_graph().device(device_name_or_function)
else:
# TODO(agarwal): support device functions in EAGER mode.
if callable(device_name_or_function):
raise RuntimeError(
"tf.device does not support functions when eager execution "
"is enabled.")
return context.device(device_name_or_function)