Properly instantiate var handles when running eagerly in tests.

PiperOrigin-RevId: 413194043
Change-Id: I3f10eb0ba2caf8767c1d6b464f2ca38ca0a2cabf
This commit is contained in:
Sagun Bajra 2021-11-30 10:57:35 -08:00 committed by TensorFlower Gardener
parent 0fe244367c
commit ce2a70df4f

View File

@ -58,6 +58,17 @@ from tensorflow.python.training import training_util
from tensorflow.python.util import compat
def _eager_safe_var_handle_op(*args, **kwargs):
# When running in eager mode the `shared_name` should be set to the
# `anonymous_name` to avoid spurious sharing issues. The runtime generates a
# unique name on our behalf when the reserved `anonymous_name` is used as the
# `shared_name`.
if context.executing_eagerly() and "shared_name" not in kwargs:
kwargs["shared_name"] = context.anonymous_name()
return resource_variable_ops.var_handle_op(*args, **kwargs)
@test_util.with_eager_op_as_function
@test_util.with_control_flow_v2
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@ -72,7 +83,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_deprecated_v1
def testHandleDtypeShapeMatch(self):
with self.cached_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
with self.assertRaises(ValueError):
resource_variable_ops.assign_variable_op(
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
@ -114,7 +125,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
def testReadVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
handle = _eager_safe_var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
resource_variable_ops.assign_variable_op(handle, 1)
with self.assertRaisesRegex(
@ -201,7 +212,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_deprecated_v1
def testFetchHandle(self):
with self.cached_session():
handle = resource_variable_ops.var_handle_op(
handle = _eager_safe_var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
self.assertNotEmpty(self.evaluate(handle))
@ -215,7 +226,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
def testAssignVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
handle = _eager_safe_var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([1]))
@ -247,14 +258,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
def testFormatResourceHandle(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
handle = _eager_safe_var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
self.assertIn("<Resource Tensor>", str(handle))
self.assertIn("<Resource Tensor>", repr(handle))
@test_util.run_in_graph_and_eager_modes
def testDtypeSurvivesIdentity(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
self.evaluate(resource_variable_ops.assign_variable_op(
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
@ -265,7 +276,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testCreateRead(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)))
value = self.evaluate(
@ -274,7 +285,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testManyAssigns(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
create = resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32))
with ops.control_dependencies([create]):
@ -292,7 +303,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testAssignAdd(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)))
self.evaluate(resource_variable_ops.assign_add_variable_op(
@ -303,8 +314,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterAdd(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -384,8 +394,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterSub(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -397,8 +406,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMul(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -429,8 +437,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterDiv(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
@ -452,8 +459,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMin(self):
with ops.device("cpu:0"):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
@ -488,8 +494,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMax(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
@ -501,8 +506,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterAddScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -514,8 +518,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterSubScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -527,8 +530,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMulScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
@ -540,8 +542,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterDivScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
@ -553,8 +554,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMinScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
@ -566,8 +566,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_in_graph_and_eager_modes
def testScatterMaxScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
@ -690,8 +689,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_deprecated_v1
def testScatterUpdateString(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.string, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
self.evaluate(resource_variable_ops.assign_variable_op(
handle, constant_op.constant([["a"]], dtype=dtypes.string)))
self.evaluate(resource_variable_ops.resource_scatter_update(
@ -702,8 +700,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
@test_util.run_deprecated_v1
def testScatterUpdateStringScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.string, shape=[1, 1])
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
self.evaluate(
resource_variable_ops.assign_variable_op(handle,
constant_op.constant(
@ -1016,7 +1013,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
with self.assertRaises(errors.FailedPreconditionError):
self.evaluate(v.value())
# Handle to a resource not actually created.
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
# Should raise no exception
self.evaluate(resource_variable_ops.destroy_resource_op(
handle, ignore_lookup_error=True))
@ -1136,15 +1133,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
self.evaluate(variables.global_variables_initializer())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
w = _eager_safe_var_handle_op(
dtype=v.dtype.base_dtype,
shape=v.get_shape(),
shared_name="var4",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
x = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
x = _eager_safe_var_handle_op(
dtype=v.dtype.base_dtype,
shape=v.get_shape(),
shared_name="var5",
container=ops.get_default_graph()._container)
with self.assertRaisesOpError(
"(Resource .*/var5/.* does not exist|uninitialized)"):
@ -1159,8 +1160,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertEqual("foo/var6:0", v.name)
self.evaluate(variables.global_variables_initializer())
w = resource_variable_ops.var_handle_op(
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
w = _eager_safe_var_handle_op(
dtype=v.dtype.base_dtype,
shape=v.get_shape(),
shared_name="foo/var6",
# Needed in Eager since we get a unique container name by default.
container=ops.get_default_graph()._container)
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
@ -1366,11 +1369,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
def testScatterUpdateInvalidArgs(self):
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
# The exact error and message differ between graph construction (where the
# error is realized during shape inference at graph construction time) and
# eager execution (where the error is realized during kernel execution).
with self.assertRaisesRegex(Exception, r"shape.*2.*3"):
# error is realized during shape inference at graph construction time),
# eager execution (where the error is realized during kernel execution),
# and XLA auto-clustering execution (where the error is realized in the xla
# op kernel) which is triggered when running in eager op as function mode.
with self.assertRaisesRegex(Exception, r"shape.*2.*3|RET_CHECK failure"):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
@test_util.disable_xla("b/208334252") # XLA doesn't have a deterministic impl
def testScatterAddDeterministic(self):
with context.eager_mode(), test_util.deterministic_ops():
# Normally a nondeterministic codepath occurs when the variable has at