Fix race condition in TensorForest tree traversal.

PiperOrigin-RevId: 170990425
This commit is contained in:
A. Unique TensorFlower 2017-10-04 03:57:59 -07:00 committed by TensorFlower Gardener
parent d016cb0205
commit 727d6270f9

View File

@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel {
string serialized_proto;
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
input_spec_.ParseFromString(serialized_proto);
data_set_ =
std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
}
void Compute(OpKernelContext* context) override {
@ -282,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel {
const Tensor& sparse_input_values = context->input(3);
const Tensor& sparse_input_shape = context->input(4);
data_set_->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel {
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
const int num_data = data_set_->NumItems();
const int num_data = data_set->NumItems();
Tensor* output_predictions = nullptr;
TensorShape output_shape;
@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
const int64 costPerTraverse = 500;
auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
int64 start, int64 end) {
auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource,
num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
static_cast<int32>(end), set_leaf_ids, nullptr);
};
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel {
private:
tensorforest::TensorForestDataSpec input_spec_;
std::unique_ptr<TensorDataSet> data_set_;
TensorForestParams param_proto_;
};