Merge pull request #50500 from geetachavan1/cherrypicks_0Q30D

[CherryPick:r2.6] Rollback breaking change
This commit is contained in:
Mihai Maruseac 2021-06-28 16:08:27 -07:00 committed by GitHub
commit 73fb51374d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 57 deletions

View File

@ -19,9 +19,6 @@ limitations under the License.
namespace tensorflow {
// Must be declared here for pre-C++17 compatibility.
/* static */ constexpr const char* ResourceHandle::ANONYMOUS_NAME;
ResourceHandle::ResourceHandle() {}
ResourceHandle::ResourceHandle(const ResourceHandleProto& proto) {

View File

@ -17,8 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/random.h"
namespace tensorflow {
@ -69,25 +67,6 @@ class Var : public ResourceBase {
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
// Serializes this variable's current value into the graph being built:
// emits a VarHandleOp (handle), a Const (snapshot of tensor_), an
// AssignVariableOp wiring the value into the handle, and returns via *out
// an Identity on the handle with the assign as a control input, so readers
// of *out observe the assigned value.
// NOTE(review): uses ResourceHandle::ANONYMOUS_NAME as shared_name —
// presumably to keep the handle private to this subgraph; confirm against
// ResourceHandle's documentation.
Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
mutex_lock l(mu_);  // hold mu_ while reading tensor_'s dtype/shape/value
Node* var = ops::SourceOp(
"VarHandleOp",
builder->opts()
.WithAttr("dtype", tensor_.dtype())
.WithAttr("shape", tensor_.shape())
.WithAttr("shared_name", ResourceHandle::ANONYMOUS_NAME));
// Constant node carrying the variable's value at serialization time.
Node* value = ops::SourceOp("Const", builder->opts()
.WithAttr("dtype", tensor_.dtype())
.WithAttr("value", tensor_));
Node* assign =
ops::BinaryOp("AssignVariableOp", var, value,
builder->opts().WithAttr("dtype", tensor_.dtype()));
// The control input guarantees the assign runs before the handle is used.
*out =
ops::UnaryOp("Identity", var, builder->opts().WithControlInput(assign));
return Status::OK();
}
std::string DebugString() const override {
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
tensor_.shape().DebugString());
@ -109,7 +88,7 @@ class Var : public ResourceBase {
std::atomic<bool> copy_on_read_mode{false};
private:
mutable mutex mu_;
mutex mu_;
Tensor tensor_;
~Var() override {}

View File

@ -120,17 +120,15 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset1 = replicated_ds[self._device1]
dataset2 = replicated_ds[self._device2]
self.evaluate(counter_var.initializer)
with ops.device(self._device1):
self.assertDatasetProduces(
dataset1, range(1, 101), requires_initialization=True)
with ops.device(self._device2):
self.assertDatasetProduces(
dataset2, range(1, 101), requires_initialization=True)
# Iterate through the original device last so that replication happens
# before counter_var is modified. The order only matters in graph mode.
with ops.device(self._device0):
self.assertDatasetProduces(
dataset0, range(1, 101), requires_initialization=True)
with ops.device(self._device1):
self.assertDatasetProduces(
dataset1, range(101, 201), requires_initialization=True)
with ops.device(self._device2):
self.assertDatasetProduces(
dataset2, range(201, 301), requires_initialization=True)
@combinations.generate(test_base.default_test_combinations())
def testExternalStatePolicyIgnore(self):
@ -326,10 +324,10 @@ class EagerClusterReplicateTest(test_base.DatasetTestBase,
dataset0, range(1, 101), requires_initialization=True)
with ops.device(self._device1):
self.assertDatasetProduces(
dataset1, range(1, 101), requires_initialization=True)
dataset1, range(101, 201), requires_initialization=True)
with ops.device(self._device2):
self.assertDatasetProduces(
dataset2, range(1, 101), requires_initialization=True)
dataset2, range(201, 301), requires_initialization=True)
class GraphClusterReplicateTest(test_base.DatasetTestBase,

View File

@ -44,8 +44,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
@ -679,26 +677,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
ds = self.make_distributed_dataset(ds, cluster)
self.assertDatasetProduces(ds, [tensor])
@combinations.generate(
combinations.times(test_base.graph_only_combinations(),
combinations.combine(use_resource=False)) +
combinations.times(test_base.default_test_combinations(),
combinations.combine(use_resource=True)))
def testVariables(self, use_resource):
# Verifies that a tf.data service pipeline can capture a variable:
# Dataset.range(3) mapped through `x + v` with v == 10 must produce
# [10, 11, 12]. Legacy ref variables (use_resource=False) are only
# exercised in graph mode (see the combinations above); resource
# variables run under the default test combinations.
cluster = data_service_test_base.TestCluster(num_workers=1)
if not use_resource:
with variable_scope.variable_scope("foo", use_resource=False):
v = variables.VariableV1(10, dtype=dtypes.int64)
else:
v = variables.Variable(10, dtype=dtypes.int64)
ds = dataset_ops.Dataset.range(3)
ds = ds.map(lambda x: x + v)
ds = self.make_distributed_dataset(ds, cluster)
# requires_initialization=True because v must be initialized before
# the distributed dataset is evaluated.
self.evaluate(v.initializer)
self.assertDatasetProduces(
ds, list(range(10, 13)), requires_initialization=True)
if __name__ == "__main__":
test.main()