mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #50500 from geetachavan1/cherrypicks_0Q30D
[CherryPick:r2.6]Rollback breaking change
This commit is contained in:
commit
73fb51374d
|
|
@ -19,9 +19,6 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
// Must be declared here for pre-C++17 compatibility.
|
||||
/* static */ constexpr const char* ResourceHandle::ANONYMOUS_NAME;
|
||||
|
||||
ResourceHandle::ResourceHandle() {}
|
||||
|
||||
ResourceHandle::ResourceHandle(const ResourceHandleProto& proto) {
|
||||
|
|
|
|||
|
|
@ -17,8 +17,6 @@ limitations under the License.
|
|||
#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
|
@ -69,25 +67,6 @@ class Var : public ResourceBase {
|
|||
mutex* mu() { return &mu_; }
|
||||
Tensor* tensor() { return &tensor_; }
|
||||
|
||||
Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
|
||||
mutex_lock l(mu_);
|
||||
Node* var = ops::SourceOp(
|
||||
"VarHandleOp",
|
||||
builder->opts()
|
||||
.WithAttr("dtype", tensor_.dtype())
|
||||
.WithAttr("shape", tensor_.shape())
|
||||
.WithAttr("shared_name", ResourceHandle::ANONYMOUS_NAME));
|
||||
Node* value = ops::SourceOp("Const", builder->opts()
|
||||
.WithAttr("dtype", tensor_.dtype())
|
||||
.WithAttr("value", tensor_));
|
||||
Node* assign =
|
||||
ops::BinaryOp("AssignVariableOp", var, value,
|
||||
builder->opts().WithAttr("dtype", tensor_.dtype()));
|
||||
*out =
|
||||
ops::UnaryOp("Identity", var, builder->opts().WithControlInput(assign));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string DebugString() const override {
|
||||
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
|
||||
tensor_.shape().DebugString());
|
||||
|
|
@ -109,7 +88,7 @@ class Var : public ResourceBase {
|
|||
std::atomic<bool> copy_on_read_mode{false};
|
||||
|
||||
private:
|
||||
mutable mutex mu_;
|
||||
mutex mu_;
|
||||
Tensor tensor_;
|
||||
|
||||
~Var() override {}
|
||||
|
|
|
|||
|
|
@ -120,17 +120,15 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
dataset1 = replicated_ds[self._device1]
|
||||
dataset2 = replicated_ds[self._device2]
|
||||
self.evaluate(counter_var.initializer)
|
||||
with ops.device(self._device1):
|
||||
self.assertDatasetProduces(
|
||||
dataset1, range(1, 101), requires_initialization=True)
|
||||
with ops.device(self._device2):
|
||||
self.assertDatasetProduces(
|
||||
dataset2, range(1, 101), requires_initialization=True)
|
||||
# Iterate through the original device last so that replication happens
|
||||
# before counter_var is modified. The order only matters in graph mode.
|
||||
with ops.device(self._device0):
|
||||
self.assertDatasetProduces(
|
||||
dataset0, range(1, 101), requires_initialization=True)
|
||||
with ops.device(self._device1):
|
||||
self.assertDatasetProduces(
|
||||
dataset1, range(101, 201), requires_initialization=True)
|
||||
with ops.device(self._device2):
|
||||
self.assertDatasetProduces(
|
||||
dataset2, range(201, 301), requires_initialization=True)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testExternalStatePolicyIgnore(self):
|
||||
|
|
@ -326,10 +324,10 @@ class EagerClusterReplicateTest(test_base.DatasetTestBase,
|
|||
dataset0, range(1, 101), requires_initialization=True)
|
||||
with ops.device(self._device1):
|
||||
self.assertDatasetProduces(
|
||||
dataset1, range(1, 101), requires_initialization=True)
|
||||
dataset1, range(101, 201), requires_initialization=True)
|
||||
with ops.device(self._device2):
|
||||
self.assertDatasetProduces(
|
||||
dataset2, range(1, 101), requires_initialization=True)
|
||||
dataset2, range(201, 301), requires_initialization=True)
|
||||
|
||||
|
||||
class GraphClusterReplicateTest(test_base.DatasetTestBase,
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@ from tensorflow.python.ops import random_ops
|
|||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
|
||||
|
|
@ -679,26 +677,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
|||
ds = self.make_distributed_dataset(ds, cluster)
|
||||
self.assertDatasetProduces(ds, [tensor])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.graph_only_combinations(),
|
||||
combinations.combine(use_resource=False)) +
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(use_resource=True)))
|
||||
def testVariables(self, use_resource):
|
||||
cluster = data_service_test_base.TestCluster(num_workers=1)
|
||||
if not use_resource:
|
||||
with variable_scope.variable_scope("foo", use_resource=False):
|
||||
v = variables.VariableV1(10, dtype=dtypes.int64)
|
||||
else:
|
||||
v = variables.Variable(10, dtype=dtypes.int64)
|
||||
|
||||
ds = dataset_ops.Dataset.range(3)
|
||||
ds = ds.map(lambda x: x + v)
|
||||
ds = self.make_distributed_dataset(ds, cluster)
|
||||
self.evaluate(v.initializer)
|
||||
self.assertDatasetProduces(
|
||||
ds, list(range(10, 13)), requires_initialization=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user