mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add check to tf.device when called with a function in eager mode.
PiperOrigin-RevId: 173947845
This commit is contained in:
parent
3639aa7ff1
commit
1d6dae88ef
|
|
@ -4339,11 +4339,18 @@ def device(device_name_or_function):
|
|||
Returns:
|
||||
A context manager that specifies the default device to use for newly
|
||||
created ops.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If eager execution is enabled and a function is passed in.
|
||||
"""
|
||||
if context.in_graph_mode():
|
||||
return get_default_graph().device(device_name_or_function)
|
||||
else:
|
||||
# TODO(agarwal): support device functions in EAGER mode.
|
||||
if callable(device_name_or_function):
|
||||
raise RuntimeError(
|
||||
"tf.device does not support functions when eager execution "
|
||||
"is enabled.")
|
||||
return context.device(device_name_or_function)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user