mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix race condition in TensorForest tree traversal.
PiperOrigin-RevId: 170990425
This commit is contained in:
parent
d016cb0205
commit
727d6270f9
|
|
@ -271,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel {
|
|||
string serialized_proto;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
|
||||
input_spec_.ParseFromString(serialized_proto);
|
||||
|
||||
data_set_ =
|
||||
std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
|
|
@ -282,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel {
|
|||
const Tensor& sparse_input_values = context->input(3);
|
||||
const Tensor& sparse_input_shape = context->input(4);
|
||||
|
||||
data_set_->set_input_tensors(input_data, sparse_input_indices,
|
||||
sparse_input_values, sparse_input_shape);
|
||||
std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
|
||||
data_set->set_input_tensors(input_data, sparse_input_indices,
|
||||
sparse_input_values, sparse_input_shape);
|
||||
|
||||
DecisionTreeResource* decision_tree_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
|
|
@ -291,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel {
|
|||
mutex_lock l(*decision_tree_resource->get_mutex());
|
||||
core::ScopedUnref unref_me(decision_tree_resource);
|
||||
|
||||
const int num_data = data_set_->NumItems();
|
||||
const int num_data = data_set->NumItems();
|
||||
|
||||
Tensor* output_predictions = nullptr;
|
||||
TensorShape output_shape;
|
||||
|
|
@ -306,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel {
|
|||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
int num_threads = worker_threads->num_threads;
|
||||
const int64 costPerTraverse = 500;
|
||||
auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
|
||||
int64 start, int64 end) {
|
||||
auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource,
|
||||
num_data](int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
|
||||
TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
|
||||
static_cast<int32>(end), set_leaf_ids, nullptr);
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
|
||||
|
|
@ -319,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel {
|
|||
|
||||
private:
|
||||
tensorforest::TensorForestDataSpec input_spec_;
|
||||
std::unique_ptr<TensorDataSet> data_set_;
|
||||
TensorForestParams param_proto_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user