mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Properly instantiate var handles when running eagerly in tests.
PiperOrigin-RevId: 413194043 Change-Id: I3f10eb0ba2caf8767c1d6b464f2ca38ca0a2cabf
This commit is contained in:
parent
0fe244367c
commit
ce2a70df4f
|
|
@ -58,6 +58,17 @@ from tensorflow.python.training import training_util
|
|||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
def _eager_safe_var_handle_op(*args, **kwargs):
|
||||
# When running in eager mode the `shared_name` should be set to the
|
||||
# `anonymous_name` to avoid spurious sharing issues. The runtime generates a
|
||||
# unique name on our behalf when the reserved `anonymous_name` is used as the
|
||||
# `shared_name`.
|
||||
if context.executing_eagerly() and "shared_name" not in kwargs:
|
||||
kwargs["shared_name"] = context.anonymous_name()
|
||||
return resource_variable_ops.var_handle_op(*args, **kwargs)
|
||||
|
||||
|
||||
@test_util.with_eager_op_as_function
|
||||
@test_util.with_control_flow_v2
|
||||
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
|
@ -72,7 +83,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
@test_util.run_deprecated_v1
|
||||
def testHandleDtypeShapeMatch(self):
|
||||
with self.cached_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
with self.assertRaises(ValueError):
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
|
||||
|
|
@ -114,7 +125,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
def testReadVariableDtypeMismatchEager(self):
|
||||
with context.eager_mode():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
handle = _eager_safe_var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1], name="foo")
|
||||
resource_variable_ops.assign_variable_op(handle, 1)
|
||||
with self.assertRaisesRegex(
|
||||
|
|
@ -201,7 +212,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
@test_util.run_deprecated_v1
|
||||
def testFetchHandle(self):
|
||||
with self.cached_session():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
handle = _eager_safe_var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1], name="foo")
|
||||
self.assertNotEmpty(self.evaluate(handle))
|
||||
|
||||
|
|
@ -215,7 +226,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
def testAssignVariableDtypeMismatchEager(self):
|
||||
with context.eager_mode():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
handle = _eager_safe_var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1], name="foo")
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([1]))
|
||||
|
|
@ -247,14 +258,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
def testFormatResourceHandle(self):
|
||||
with context.eager_mode():
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
handle = _eager_safe_var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1], name="foo")
|
||||
self.assertIn("<Resource Tensor>", str(handle))
|
||||
self.assertIn("<Resource Tensor>", repr(handle))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDtypeSurvivesIdentity(self):
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
id_handle = array_ops.identity(handle)
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
|
||||
|
|
@ -265,7 +276,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCreateRead(self):
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||
value = self.evaluate(
|
||||
|
|
@ -274,7 +285,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testManyAssigns(self):
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
create = resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32))
|
||||
with ops.control_dependencies([create]):
|
||||
|
|
@ -292,7 +303,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssignAdd(self):
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||
self.evaluate(resource_variable_ops.assign_add_variable_op(
|
||||
|
|
@ -303,8 +314,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterAdd(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -384,8 +394,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterSub(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -397,8 +406,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMul(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -429,8 +437,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterDiv(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||
|
|
@ -452,8 +459,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMin(self):
|
||||
with ops.device("cpu:0"):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
|
|
@ -488,8 +494,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMax(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||
|
|
@ -501,8 +506,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterAddScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -514,8 +518,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterSubScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -527,8 +530,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMulScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||
|
|
@ -540,8 +542,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterDivScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||
|
|
@ -553,8 +554,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMinScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||
|
|
@ -566,8 +566,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testScatterMaxScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.int32, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||
|
|
@ -690,8 +689,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testScatterUpdateString(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.string, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
|
||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||
handle, constant_op.constant([["a"]], dtype=dtypes.string)))
|
||||
self.evaluate(resource_variable_ops.resource_scatter_update(
|
||||
|
|
@ -702,8 +700,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testScatterUpdateStringScalar(self):
|
||||
handle = resource_variable_ops.var_handle_op(
|
||||
dtype=dtypes.string, shape=[1, 1])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
|
||||
self.evaluate(
|
||||
resource_variable_ops.assign_variable_op(handle,
|
||||
constant_op.constant(
|
||||
|
|
@ -1016,7 +1013,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
self.evaluate(v.value())
|
||||
# Handle to a resource not actually created.
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
# Should raise no exception
|
||||
self.evaluate(resource_variable_ops.destroy_resource_op(
|
||||
handle, ignore_lookup_error=True))
|
||||
|
|
@ -1136,15 +1133,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
||||
w = _eager_safe_var_handle_op(
|
||||
dtype=v.dtype.base_dtype,
|
||||
shape=v.get_shape(),
|
||||
shared_name="var4",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
self.assertEqual(300.0, self.evaluate(w_read))
|
||||
|
||||
x = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
||||
x = _eager_safe_var_handle_op(
|
||||
dtype=v.dtype.base_dtype,
|
||||
shape=v.get_shape(),
|
||||
shared_name="var5",
|
||||
container=ops.get_default_graph()._container)
|
||||
with self.assertRaisesOpError(
|
||||
"(Resource .*/var5/.* does not exist|uninitialized)"):
|
||||
|
|
@ -1159,8 +1160,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
self.assertEqual("foo/var6:0", v.name)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
w = resource_variable_ops.var_handle_op(
|
||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
|
||||
w = _eager_safe_var_handle_op(
|
||||
dtype=v.dtype.base_dtype,
|
||||
shape=v.get_shape(),
|
||||
shared_name="foo/var6",
|
||||
# Needed in Eager since we get a unique container name by default.
|
||||
container=ops.get_default_graph()._container)
|
||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||
|
|
@ -1366,11 +1369,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||
def testScatterUpdateInvalidArgs(self):
|
||||
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
|
||||
# The exact error and message differ between graph construction (where the
|
||||
# error is realized during shape inference at graph construction time) and
|
||||
# eager execution (where the error is realized during kernel execution).
|
||||
with self.assertRaisesRegex(Exception, r"shape.*2.*3"):
|
||||
# error is realized during shape inference at graph construction time),
|
||||
# eager execution (where the error is realized during kernel execution),
|
||||
# and XLA auto-clustering execution (where the error is realized in the xla
|
||||
# op kernel) which is triggered when running in eager op as function mode.
|
||||
with self.assertRaisesRegex(Exception, r"shape.*2.*3|RET_CHECK failure"):
|
||||
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
|
||||
|
||||
@test_util.disable_xla("b/208334252") # XLA doesn't have a deterministic impl
|
||||
def testScatterAddDeterministic(self):
|
||||
with context.eager_mode(), test_util.deterministic_ops():
|
||||
# Normally a nondeterministic codepath occurs when the variable has at
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user