mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix race condition in TensorForest tree traversal.
PiperOrigin-RevId: 170990425
This commit is contained in:
parent
d016cb0205
commit
727d6270f9
|
|
@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel {
|
||||||
string serialized_proto;
|
string serialized_proto;
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
|
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
|
||||||
input_spec_.ParseFromString(serialized_proto);
|
input_spec_.ParseFromString(serialized_proto);
|
||||||
|
|
||||||
data_set_ =
|
|
||||||
std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
|
|
@ -282,7 +279,8 @@ class TraverseTreeV4Op : public OpKernel {
|
||||||
const Tensor& sparse_input_values = context->input(3);
|
const Tensor& sparse_input_values = context->input(3);
|
||||||
const Tensor& sparse_input_shape = context->input(4);
|
const Tensor& sparse_input_shape = context->input(4);
|
||||||
|
|
||||||
data_set_->set_input_tensors(input_data, sparse_input_indices,
|
std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
|
||||||
|
data_set->set_input_tensors(input_data, sparse_input_indices,
|
||||||
sparse_input_values, sparse_input_shape);
|
sparse_input_values, sparse_input_shape);
|
||||||
|
|
||||||
DecisionTreeResource* decision_tree_resource;
|
DecisionTreeResource* decision_tree_resource;
|
||||||
|
|
@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel {
|
||||||
mutex_lock l(*decision_tree_resource->get_mutex());
|
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||||
core::ScopedUnref unref_me(decision_tree_resource);
|
core::ScopedUnref unref_me(decision_tree_resource);
|
||||||
|
|
||||||
const int num_data = data_set_->NumItems();
|
const int num_data = data_set->NumItems();
|
||||||
|
|
||||||
Tensor* output_predictions = nullptr;
|
Tensor* output_predictions = nullptr;
|
||||||
TensorShape output_shape;
|
TensorShape output_shape;
|
||||||
|
|
@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel {
|
||||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||||
int num_threads = worker_threads->num_threads;
|
int num_threads = worker_threads->num_threads;
|
||||||
const int64 costPerTraverse = 500;
|
const int64 costPerTraverse = 500;
|
||||||
auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
|
auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource,
|
||||||
int64 start, int64 end) {
|
num_data](int64 start, int64 end) {
|
||||||
CHECK(start <= end);
|
CHECK(start <= end);
|
||||||
CHECK(end <= num_data);
|
CHECK(end <= num_data);
|
||||||
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
|
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
|
||||||
static_cast<int32>(end), set_leaf_ids, nullptr);
|
static_cast<int32>(end), set_leaf_ids, nullptr);
|
||||||
};
|
};
|
||||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||||
|
|
@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
tensorforest::TensorForestDataSpec input_spec_;
|
tensorforest::TensorForestDataSpec input_spec_;
|
||||||
std::unique_ptr<TensorDataSet> data_set_;
|
|
||||||
TensorForestParams param_proto_;
|
TensorForestParams param_proto_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user