From dd0faf436653d94e02b7be33cfe528e90fab6a25 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Fri, 7 Jun 2019 20:52:25 -0700 Subject: [PATCH] clean up the TracingState API (#21514) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21514 ghimport-source-id: 6a9b6fdd7e696ea29e8715482708efe897230e4d Reviewed By: jamesr66a Differential Revision: D15719980 Pulled By: zdevito fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862 --- torch/csrc/jit/tracer.cpp | 117 ++++++++++++++++++++------------------ torch/csrc/jit/tracer.h | 45 ++++++++------- 2 files changed, 88 insertions(+), 74 deletions(-) diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 6ac263ffb38..fb082384505 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -52,19 +52,21 @@ std::function pauseTracing() { return [state]() { tracer::setTracingState(state); }; } -void delValueTrace(const Variable& var) { - AT_ASSERT(var.defined()); - auto& env_stack = getTracingState()->env_stack; +void delValueTrace(const IValue& var) { + getTracingState()->delValue(var); +} +void TracingState::delValue(const IValue& var) { + at::Tensor t = var.toTensor(); + AT_ASSERT(t.defined()); for (size_t i = 0; i < env_stack.size(); ++i) { auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map; - auto it = value_map.find(var); + auto it = value_map.find(t); if (it == value_map.end()) { continue; } value_map.erase(it); } - getTracingState()->env_stack.back().value_map.erase(var); } // Given a IValue 'var', return the 'node' which represents the instruction @@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) { // zero. This is one of the cases where a Variable can be created inside of a // trace, and if we treat it as a constant, everything will work out. Value* getValueTrace(const IValue& var) { - auto& state = getTracingState(); - auto& env_stack = getTracingState()->env_stack; - + return getTracingState()->getValue(var); +} +Value* TracingState::getValue(const IValue& var) { // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments if (var.isTensorList()) { - return state->graph - ->insertNode(state->graph->createList( + return graph + ->insertNode(graph->createList( TensorType::get(), fmap( var.toTensorListRef(), [](const IValue& val) { return getValueTrace(val); }))) ->output(); } else if (var.isTuple()) { - return state->graph - ->insertNode(state->graph->createTuple(fmap( + return graph + ->insertNode(graph->createTuple(fmap( var.toTuple()->elements(), [](const IValue& val) { return getValueTrace(val); }))) ->output(); } if (var.isTensor()) { auto ten = var.toTensor(); if (!ten.defined()) { - Node* n = state->graph->createNone(TensorType::get()); - return state->graph->insertNode(n)->output(); + Node* n = graph->createNone(TensorType::get()); + return graph->insertNode(n)->output(); } for (size_t i = 0; i < env_stack.size(); ++i) { auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map; @@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) { throw std::runtime_error(oss.str()); } - Value* constant = state->graph->insertConstant(ten); + Value* constant = graph->insertConstant(ten); recordSourceLocation(constant->node()); constant->inferTypeFrom(ten); auto it = env_stack.back().value_map.find(ten); @@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) { } else { // If the values are non-tensors, we try to create constants // and bake those constants into the traced graph - auto constant = tryInsertConstant(*state->graph, var); + auto constant = tryInsertConstant(*graph, var); if (constant) { recordSourceLocation(constant.value()->node()); return *constant; @@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) { throw std::runtime_error(os.str()); } } - -Value* getOutputTrace( - const std::shared_ptr& state, - const Variable& var) { - if (!var.defined()) { - Node* n = state->graph->createNone(TensorType::get()); - return state->graph->insertNode(n)->output(); +bool TracingState::hasValue(const IValue& var) const { + if (var.isTensor()) { + at::Tensor t = var.toTensor(); + for(const auto & frame : env_stack) { + if (frame.value_map.count(t)) { + return true; + } + } } - - auto& value_map = getTracingState()->env_stack.back().value_map; - auto it = value_map.find(var); - if (it == value_map.end()) { - std::ostringstream os; - os << "output of traced region did not have observable " - << "data dependence with trace inputs; this probably indicates your program " - << "cannot be understood by the tracer."; - throw std::runtime_error(os.str()); - } - return it->second; + return false; } -Value* getNestedOutputTrace( - const std::shared_ptr& state, - const IValue& iv) { - if (iv.isTensor()) { - return getOutputTrace(state, iv.toTensor()); + +Value* TracingState::getOutput(const IValue& iv) { + if (iv.isTensor()) { + at::Tensor var = iv.toTensor(); + if (!var.defined()) { + Node *n = graph->createNone(TensorType::get()); + return graph->insertNode(n)->output(); + } + + auto &value_map = getTracingState()->env_stack.back().value_map; + auto it = value_map.find(var); + if (it == value_map.end()) { + std::ostringstream os; + os << "output of traced region did not have observable " + << "data dependence with trace inputs; this probably indicates your " + "program " + << "cannot be understood by the tracer."; + throw std::runtime_error(os.str()); + } + return it->second; } else if (iv.isTuple()) { const auto& elems = iv.toTuple()->elements(); auto tuple_node = - state->graph->createTuple(fmap(elems, [&state](const IValue& ival) { - return getNestedOutputTrace(state, ival); + graph->createTuple(fmap(elems, [&](const IValue& ival) { + return getOutput(ival); })); - state->graph->insertNode(tuple_node); + graph->insertNode(tuple_node); return tuple_node->output(); } else { AT_ERROR( @@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr & state, const IValue if (type->isSubtypeOf(TensorType::get())) { auto input_tensor = input.toTensor(); auto name = Variable(input_tensor).name(); - auto& value_map = state->env_stack.back().value_map; - if (value_map.find(input_tensor) != value_map.end()) { + if (state->hasValue(input)) { input_tensor = input_tensor.view(input_tensor.sizes()); } value->setUniqueName(name); - value_map[input_tensor] = value; + state->setValue(input_tensor, value); return input_tensor; } else if (auto tuple_type = type->cast()) { auto unpack_node = @@ -330,7 +337,7 @@ void exit(const Stack& outputs) { auto& state = getTracingState(); size_t i = 0; for (auto& output : outputs) { - state->graph->registerOutput(getNestedOutputTrace(state, output)); + state->graph->registerOutput(state->getOutput(output)); i++; } setTracingState(nullptr); @@ -342,36 +349,36 @@ void abandon() { } void setValueTrace(const IValue& v, Value* value) { + return getTracingState()->setValue(v, value); +} +void TracingState::setValue(const IValue& v, Value* value) { if (v.isTensor()) { auto var = v.toTensor(); AT_ASSERT(var.defined()); - getTracingState()->env_stack.back().value_map[var] = value; + env_stack.back().value_map[var] = value; } else if (v.isTensorList()) { auto& outputs = v.toTensorList()->elements(); - auto graph = getTracingState()->graph; Node* unpack_node = graph->insertNode(graph->createListUnpack(value, outputs.size())); for (size_t i = 0; i < outputs.size(); ++i) { - setValueTrace(outputs[i], unpack_node->outputs()[i]); + setValue(outputs[i], unpack_node->outputs()[i]); } } else if (v.isTuple()) { auto& outputs = v.toTuple()->elements(); - auto graph = getTracingState()->graph; Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value)); for (size_t i = 0; i < outputs.size(); ++i) { - setValueTrace(outputs[i], unpack_node->outputs()[i]); + setValue(outputs[i], unpack_node->outputs()[i]); } } else if (v.isGenericList()) { auto elements = v.toGenericListRef(); - auto graph = getTracingState()->graph; Node* unpack_node = graph->insertNode(graph->createListUnpack(value, elements.size())); for (size_t i = 0; i < elements.size(); ++i) { - setValueTrace(elements[i], unpack_node->outputs()[i]); + setValue(elements[i], unpack_node->outputs()[i]); } } else if (v.isFuture()) { auto fut = v.toFuture(); - getTracingState()->env_stack.back().future_map[fut] = value; + env_stack.back().future_map[fut] = value; } else { std::ostringstream os; os << "Tracer cannot set value trace for type " << v.tagKind() << ". " @@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr state) { } TracingState::TracingState() - : env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {} + : graph(new Graph()), env_stack{Frame()} {} TracingState::~TracingState() = default; diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 5930f5d3f49..433bd9a164b 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -52,6 +52,27 @@ struct TORCH_API TracingState TracingState(); ~TracingState(); + std::shared_ptr graph; + bool warn = true; + bool force_outplace = false; + std::function lookup_var_name_fn = + [](const Variable& var) { return ""; }; + + void enterFrame() { + env_stack.emplace_back(); + } + + void leaveFrame() { + env_stack.pop_back(); + } + + void setValue(const IValue& v, Value* value); + void delValue(const IValue& var); + Value* getValue(const IValue& var); + Value* getOutput(const IValue& var); + bool hasValue(const IValue& var) const; + +private: using WeakTensor = at::WeakTensor; struct WeakTensorHasher { @@ -66,7 +87,7 @@ struct TORCH_API TracingState } }; - struct TracingEnvironmentFrame { + struct Frame { std::unordered_map value_map; // TODO weak refcount @@ -74,14 +95,8 @@ struct TORCH_API TracingState future_map; }; - using TracingEnvironmentStack = std::vector; + std::vector env_stack; - TracingEnvironmentStack env_stack; - std::shared_ptr graph; - bool warn = true; - bool force_outplace = false; - std::function lookup_var_name_fn = - [](const Variable& var) { return ""; }; }; // This is meant to be used as a thread local place, where we can store extra @@ -182,11 +197,11 @@ struct TORCH_API NoWarn { struct WithNestedTracingFrame { WithNestedTracingFrame() { - getTracingState()->env_stack.emplace_back(); + getTracingState()->enterFrame(); } ~WithNestedTracingFrame() { - getTracingState()->env_stack.pop_back(); + getTracingState()->leaveFrame(); } }; TORCH_API void recordSourceLocation(Node* n); @@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*)); // involving this variable know which node in the IR to reference. TORCH_API void setValueTrace(const IValue& v, Value* value); -TORCH_API void delValueTrace(const Variable& var); +TORCH_API void delValueTrace(const IValue& var); TORCH_API std::function pauseTracing(); TORCH_API Value* getValueTrace(const IValue& var); -TORCH_API Value* getOutputTrace( - const std::shared_ptr& state, - const Variable& var); - -TORCH_API Value* getNestedOutputTrace( - const std::shared_ptr& state, - const IValue& iv); - struct TypedStack : public std::pair { using pair::pair;