mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Properly instantiate var handles when running eagerly in tests.
PiperOrigin-RevId: 413194043 Change-Id: I3f10eb0ba2caf8767c1d6b464f2ca38ca0a2cabf
This commit is contained in:
parent
0fe244367c
commit
ce2a70df4f
|
|
@ -58,6 +58,17 @@ from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
|
def _eager_safe_var_handle_op(*args, **kwargs):
|
||||||
|
# When running in eager mode the `shared_name` should be set to the
|
||||||
|
# `anonymous_name` to avoid spurious sharing issues. The runtime generates a
|
||||||
|
# unique name on our behalf when the reserved `anonymous_name` is used as the
|
||||||
|
# `shared_name`.
|
||||||
|
if context.executing_eagerly() and "shared_name" not in kwargs:
|
||||||
|
kwargs["shared_name"] = context.anonymous_name()
|
||||||
|
return resource_variable_ops.var_handle_op(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.with_eager_op_as_function
|
||||||
@test_util.with_control_flow_v2
|
@test_util.with_control_flow_v2
|
||||||
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
|
@ -72,7 +83,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testHandleDtypeShapeMatch(self):
|
def testHandleDtypeShapeMatch(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
|
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
|
||||||
|
|
@ -114,7 +125,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
def testReadVariableDtypeMismatchEager(self):
|
def testReadVariableDtypeMismatchEager(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(
|
||||||
dtype=dtypes.int32, shape=[1], name="foo")
|
dtype=dtypes.int32, shape=[1], name="foo")
|
||||||
resource_variable_ops.assign_variable_op(handle, 1)
|
resource_variable_ops.assign_variable_op(handle, 1)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
|
|
@ -201,7 +212,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testFetchHandle(self):
|
def testFetchHandle(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(
|
||||||
dtype=dtypes.int32, shape=[1], name="foo")
|
dtype=dtypes.int32, shape=[1], name="foo")
|
||||||
self.assertNotEmpty(self.evaluate(handle))
|
self.assertNotEmpty(self.evaluate(handle))
|
||||||
|
|
||||||
|
|
@ -215,7 +226,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
def testAssignVariableDtypeMismatchEager(self):
|
def testAssignVariableDtypeMismatchEager(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(
|
||||||
dtype=dtypes.int32, shape=[1], name="foo")
|
dtype=dtypes.int32, shape=[1], name="foo")
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([1]))
|
handle, constant_op.constant([1]))
|
||||||
|
|
@ -247,14 +258,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
def testFormatResourceHandle(self):
|
def testFormatResourceHandle(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(
|
||||||
dtype=dtypes.int32, shape=[1], name="foo")
|
dtype=dtypes.int32, shape=[1], name="foo")
|
||||||
self.assertIn("<Resource Tensor>", str(handle))
|
self.assertIn("<Resource Tensor>", str(handle))
|
||||||
self.assertIn("<Resource Tensor>", repr(handle))
|
self.assertIn("<Resource Tensor>", repr(handle))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testDtypeSurvivesIdentity(self):
|
def testDtypeSurvivesIdentity(self):
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
id_handle = array_ops.identity(handle)
|
id_handle = array_ops.identity(handle)
|
||||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||||
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
|
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
|
||||||
|
|
@ -265,7 +276,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testCreateRead(self):
|
def testCreateRead(self):
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||||
value = self.evaluate(
|
value = self.evaluate(
|
||||||
|
|
@ -274,7 +285,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testManyAssigns(self):
|
def testManyAssigns(self):
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
create = resource_variable_ops.assign_variable_op(
|
create = resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant(1, dtype=dtypes.int32))
|
handle, constant_op.constant(1, dtype=dtypes.int32))
|
||||||
with ops.control_dependencies([create]):
|
with ops.control_dependencies([create]):
|
||||||
|
|
@ -292,7 +303,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testAssignAdd(self):
|
def testAssignAdd(self):
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
handle, constant_op.constant(1, dtype=dtypes.int32)))
|
||||||
self.evaluate(resource_variable_ops.assign_add_variable_op(
|
self.evaluate(resource_variable_ops.assign_add_variable_op(
|
||||||
|
|
@ -303,8 +314,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterAdd(self):
|
def testScatterAdd(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -384,8 +394,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterSub(self):
|
def testScatterSub(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -397,8 +406,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMul(self):
|
def testScatterMul(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -429,8 +437,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterDiv(self):
|
def testScatterDiv(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||||
|
|
@ -452,8 +459,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMin(self):
|
def testScatterMin(self):
|
||||||
with ops.device("cpu:0"):
|
with ops.device("cpu:0"):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(handle,
|
resource_variable_ops.assign_variable_op(handle,
|
||||||
constant_op.constant(
|
constant_op.constant(
|
||||||
|
|
@ -488,8 +494,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMax(self):
|
def testScatterMax(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||||
|
|
@ -501,8 +506,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterAddScalar(self):
|
def testScatterAddScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -514,8 +518,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterSubScalar(self):
|
def testScatterSubScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -527,8 +530,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMulScalar(self):
|
def testScatterMulScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[1]], dtype=dtypes.int32)))
|
||||||
|
|
@ -540,8 +542,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterDivScalar(self):
|
def testScatterDivScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||||
|
|
@ -553,8 +554,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMinScalar(self):
|
def testScatterMinScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||||
|
|
@ -566,8 +566,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testScatterMaxScalar(self):
|
def testScatterMaxScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[1, 1])
|
||||||
dtype=dtypes.int32, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(
|
resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
handle, constant_op.constant([[6]], dtype=dtypes.int32)))
|
||||||
|
|
@ -690,8 +689,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testScatterUpdateString(self):
|
def testScatterUpdateString(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
|
||||||
dtype=dtypes.string, shape=[1, 1])
|
|
||||||
self.evaluate(resource_variable_ops.assign_variable_op(
|
self.evaluate(resource_variable_ops.assign_variable_op(
|
||||||
handle, constant_op.constant([["a"]], dtype=dtypes.string)))
|
handle, constant_op.constant([["a"]], dtype=dtypes.string)))
|
||||||
self.evaluate(resource_variable_ops.resource_scatter_update(
|
self.evaluate(resource_variable_ops.resource_scatter_update(
|
||||||
|
|
@ -702,8 +700,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testScatterUpdateStringScalar(self):
|
def testScatterUpdateStringScalar(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = _eager_safe_var_handle_op(dtype=dtypes.string, shape=[1, 1])
|
||||||
dtype=dtypes.string, shape=[1, 1])
|
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
resource_variable_ops.assign_variable_op(handle,
|
resource_variable_ops.assign_variable_op(handle,
|
||||||
constant_op.constant(
|
constant_op.constant(
|
||||||
|
|
@ -1016,7 +1013,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
with self.assertRaises(errors.FailedPreconditionError):
|
with self.assertRaises(errors.FailedPreconditionError):
|
||||||
self.evaluate(v.value())
|
self.evaluate(v.value())
|
||||||
# Handle to a resource not actually created.
|
# Handle to a resource not actually created.
|
||||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
handle = _eager_safe_var_handle_op(dtype=dtypes.int32, shape=[])
|
||||||
# Should raise no exception
|
# Should raise no exception
|
||||||
self.evaluate(resource_variable_ops.destroy_resource_op(
|
self.evaluate(resource_variable_ops.destroy_resource_op(
|
||||||
handle, ignore_lookup_error=True))
|
handle, ignore_lookup_error=True))
|
||||||
|
|
@ -1136,15 +1133,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
|
||||||
w = resource_variable_ops.var_handle_op(
|
w = _eager_safe_var_handle_op(
|
||||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4",
|
dtype=v.dtype.base_dtype,
|
||||||
|
shape=v.get_shape(),
|
||||||
|
shared_name="var4",
|
||||||
# Needed in Eager since we get a unique container name by default.
|
# Needed in Eager since we get a unique container name by default.
|
||||||
container=ops.get_default_graph()._container)
|
container=ops.get_default_graph()._container)
|
||||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||||
self.assertEqual(300.0, self.evaluate(w_read))
|
self.assertEqual(300.0, self.evaluate(w_read))
|
||||||
|
|
||||||
x = resource_variable_ops.var_handle_op(
|
x = _eager_safe_var_handle_op(
|
||||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5",
|
dtype=v.dtype.base_dtype,
|
||||||
|
shape=v.get_shape(),
|
||||||
|
shared_name="var5",
|
||||||
container=ops.get_default_graph()._container)
|
container=ops.get_default_graph()._container)
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
"(Resource .*/var5/.* does not exist|uninitialized)"):
|
"(Resource .*/var5/.* does not exist|uninitialized)"):
|
||||||
|
|
@ -1159,8 +1160,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
self.assertEqual("foo/var6:0", v.name)
|
self.assertEqual("foo/var6:0", v.name)
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
|
||||||
w = resource_variable_ops.var_handle_op(
|
w = _eager_safe_var_handle_op(
|
||||||
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6",
|
dtype=v.dtype.base_dtype,
|
||||||
|
shape=v.get_shape(),
|
||||||
|
shared_name="foo/var6",
|
||||||
# Needed in Eager since we get a unique container name by default.
|
# Needed in Eager since we get a unique container name by default.
|
||||||
container=ops.get_default_graph()._container)
|
container=ops.get_default_graph()._container)
|
||||||
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
|
||||||
|
|
@ -1366,11 +1369,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
def testScatterUpdateInvalidArgs(self):
|
def testScatterUpdateInvalidArgs(self):
|
||||||
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
|
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
|
||||||
# The exact error and message differ between graph construction (where the
|
# The exact error and message differ between graph construction (where the
|
||||||
# error is realized during shape inference at graph construction time) and
|
# error is realized during shape inference at graph construction time),
|
||||||
# eager execution (where the error is realized during kernel execution).
|
# eager execution (where the error is realized during kernel execution),
|
||||||
with self.assertRaisesRegex(Exception, r"shape.*2.*3"):
|
# and XLA auto-clustering execution (where the error is realized in the xla
|
||||||
|
# op kernel) which is triggered when running in eager op as function mode.
|
||||||
|
with self.assertRaisesRegex(Exception, r"shape.*2.*3|RET_CHECK failure"):
|
||||||
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
|
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
|
||||||
|
|
||||||
|
@test_util.disable_xla("b/208334252") # XLA doesn't have a deterministic impl
|
||||||
def testScatterAddDeterministic(self):
|
def testScatterAddDeterministic(self):
|
||||||
with context.eager_mode(), test_util.deterministic_ops():
|
with context.eager_mode(), test_util.deterministic_ops():
|
||||||
# Normally a nondeterministic codepath occurs when the variable has at
|
# Normally a nondeterministic codepath occurs when the variable has at
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user