clean up the TracingState API (#21514)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21514
ghimport-source-id: 6a9b6fdd7e696ea29e8715482708efe897230e4d

Reviewed By: jamesr66a

Differential Revision: D15719980

Pulled By: zdevito

fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862
This commit is contained in:
Zachary DeVito 2019-06-07 20:52:25 -07:00 committed by Facebook Github Bot
parent 8c5f3acfc0
commit dd0faf4366
2 changed files with 88 additions and 74 deletions

View File

@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
return [state]() { tracer::setTracingState(state); }; return [state]() { tracer::setTracingState(state); };
} }
void delValueTrace(const Variable& var) { void delValueTrace(const IValue& var) {
AT_ASSERT(var.defined()); getTracingState()->delValue(var);
auto& env_stack = getTracingState()->env_stack; }
void TracingState::delValue(const IValue& var) {
at::Tensor t = var.toTensor();
AT_ASSERT(t.defined());
for (size_t i = 0; i < env_stack.size(); ++i) { for (size_t i = 0; i < env_stack.size(); ++i) {
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map; auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
auto it = value_map.find(var); auto it = value_map.find(t);
if (it == value_map.end()) { if (it == value_map.end()) {
continue; continue;
} }
value_map.erase(it); value_map.erase(it);
} }
getTracingState()->env_stack.back().value_map.erase(var);
} }
// Given a IValue 'var', return the 'node' which represents the instruction // Given a IValue 'var', return the 'node' which represents the instruction
@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
// zero. This is one of the cases where a Variable can be created inside of a // zero. This is one of the cases where a Variable can be created inside of a
// trace, and if we treat it as a constant, everything will work out. // trace, and if we treat it as a constant, everything will work out.
Value* getValueTrace(const IValue& var) { Value* getValueTrace(const IValue& var) {
auto& state = getTracingState(); return getTracingState()->getValue(var);
auto& env_stack = getTracingState()->env_stack; }
Value* TracingState::getValue(const IValue& var) {
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
if (var.isTensorList()) { if (var.isTensorList()) {
return state->graph return graph
->insertNode(state->graph->createList( ->insertNode(graph->createList(
TensorType::get(), TensorType::get(),
fmap( fmap(
var.toTensorListRef(), var.toTensorListRef(),
[](const IValue& val) { return getValueTrace(val); }))) [](const IValue& val) { return getValueTrace(val); })))
->output(); ->output();
} else if (var.isTuple()) { } else if (var.isTuple()) {
return state->graph return graph
->insertNode(state->graph->createTuple(fmap( ->insertNode(graph->createTuple(fmap(
var.toTuple()->elements(), var.toTuple()->elements(),
[](const IValue& val) { return getValueTrace(val); }))) [](const IValue& val) { return getValueTrace(val); })))
->output(); ->output();
} if (var.isTensor()) { } if (var.isTensor()) {
auto ten = var.toTensor(); auto ten = var.toTensor();
if (!ten.defined()) { if (!ten.defined()) {
Node* n = state->graph->createNone(TensorType::get()); Node* n = graph->createNone(TensorType::get());
return state->graph->insertNode(n)->output(); return graph->insertNode(n)->output();
} }
for (size_t i = 0; i < env_stack.size(); ++i) { for (size_t i = 0; i < env_stack.size(); ++i) {
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map; auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
throw std::runtime_error(oss.str()); throw std::runtime_error(oss.str());
} }
Value* constant = state->graph->insertConstant(ten); Value* constant = graph->insertConstant(ten);
recordSourceLocation(constant->node()); recordSourceLocation(constant->node());
constant->inferTypeFrom(ten); constant->inferTypeFrom(ten);
auto it = env_stack.back().value_map.find(ten); auto it = env_stack.back().value_map.find(ten);
@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
} else { } else {
// If the values are non-tensors, we try to create constants // If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph // and bake those constants into the traced graph
auto constant = tryInsertConstant(*state->graph, var); auto constant = tryInsertConstant(*graph, var);
if (constant) { if (constant) {
recordSourceLocation(constant.value()->node()); recordSourceLocation(constant.value()->node());
return *constant; return *constant;
@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
} }
bool TracingState::hasValue(const IValue& var) const {
Value* getOutputTrace( if (var.isTensor()) {
const std::shared_ptr<TracingState>& state, at::Tensor t = var.toTensor();
const Variable& var) { for(const auto & frame : env_stack) {
if (!var.defined()) { if (frame.value_map.count(t)) {
Node* n = state->graph->createNone(TensorType::get()); return true;
return state->graph->insertNode(n)->output(); }
}
} }
return false;
auto& value_map = getTracingState()->env_stack.back().value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
std::ostringstream os;
os << "output of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
return it->second;
} }
Value* getNestedOutputTrace(
const std::shared_ptr<TracingState>& state, Value* TracingState::getOutput(const IValue& iv) {
const IValue& iv) { if (iv.isTensor()) {
if (iv.isTensor()) { at::Tensor var = iv.toTensor();
return getOutputTrace(state, iv.toTensor()); if (!var.defined()) {
Node *n = graph->createNone(TensorType::get());
return graph->insertNode(n)->output();
}
auto &value_map = getTracingState()->env_stack.back().value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
std::ostringstream os;
os << "output of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
return it->second;
} else if (iv.isTuple()) { } else if (iv.isTuple()) {
const auto& elems = iv.toTuple()->elements(); const auto& elems = iv.toTuple()->elements();
auto tuple_node = auto tuple_node =
state->graph->createTuple(fmap(elems, [&state](const IValue& ival) { graph->createTuple(fmap(elems, [&](const IValue& ival) {
return getNestedOutputTrace(state, ival); return getOutput(ival);
})); }));
state->graph->insertNode(tuple_node); graph->insertNode(tuple_node);
return tuple_node->output(); return tuple_node->output();
} else { } else {
AT_ERROR( AT_ERROR(
@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
if (type->isSubtypeOf(TensorType::get())) { if (type->isSubtypeOf(TensorType::get())) {
auto input_tensor = input.toTensor(); auto input_tensor = input.toTensor();
auto name = Variable(input_tensor).name(); auto name = Variable(input_tensor).name();
auto& value_map = state->env_stack.back().value_map; if (state->hasValue(input)) {
if (value_map.find(input_tensor) != value_map.end()) {
input_tensor = input_tensor.view(input_tensor.sizes()); input_tensor = input_tensor.view(input_tensor.sizes());
} }
value->setUniqueName(name); value->setUniqueName(name);
value_map[input_tensor] = value; state->setValue(input_tensor, value);
return input_tensor; return input_tensor;
} else if (auto tuple_type = type->cast<TupleType>()) { } else if (auto tuple_type = type->cast<TupleType>()) {
auto unpack_node = auto unpack_node =
@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
auto& state = getTracingState(); auto& state = getTracingState();
size_t i = 0; size_t i = 0;
for (auto& output : outputs) { for (auto& output : outputs) {
state->graph->registerOutput(getNestedOutputTrace(state, output)); state->graph->registerOutput(state->getOutput(output));
i++; i++;
} }
setTracingState(nullptr); setTracingState(nullptr);
@ -342,36 +349,36 @@ void abandon() {
} }
void setValueTrace(const IValue& v, Value* value) { void setValueTrace(const IValue& v, Value* value) {
return getTracingState()->setValue(v, value);
}
void TracingState::setValue(const IValue& v, Value* value) {
if (v.isTensor()) { if (v.isTensor()) {
auto var = v.toTensor(); auto var = v.toTensor();
AT_ASSERT(var.defined()); AT_ASSERT(var.defined());
getTracingState()->env_stack.back().value_map[var] = value; env_stack.back().value_map[var] = value;
} else if (v.isTensorList()) { } else if (v.isTensorList()) {
auto& outputs = v.toTensorList()->elements(); auto& outputs = v.toTensorList()->elements();
auto graph = getTracingState()->graph;
Node* unpack_node = Node* unpack_node =
graph->insertNode(graph->createListUnpack(value, outputs.size())); graph->insertNode(graph->createListUnpack(value, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]); setValue(outputs[i], unpack_node->outputs()[i]);
} }
} else if (v.isTuple()) { } else if (v.isTuple()) {
auto& outputs = v.toTuple()->elements(); auto& outputs = v.toTuple()->elements();
auto graph = getTracingState()->graph;
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value)); Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]); setValue(outputs[i], unpack_node->outputs()[i]);
} }
} else if (v.isGenericList()) { } else if (v.isGenericList()) {
auto elements = v.toGenericListRef(); auto elements = v.toGenericListRef();
auto graph = getTracingState()->graph;
Node* unpack_node = Node* unpack_node =
graph->insertNode(graph->createListUnpack(value, elements.size())); graph->insertNode(graph->createListUnpack(value, elements.size()));
for (size_t i = 0; i < elements.size(); ++i) { for (size_t i = 0; i < elements.size(); ++i) {
setValueTrace(elements[i], unpack_node->outputs()[i]); setValue(elements[i], unpack_node->outputs()[i]);
} }
} else if (v.isFuture()) { } else if (v.isFuture()) {
auto fut = v.toFuture(); auto fut = v.toFuture();
getTracingState()->env_stack.back().future_map[fut] = value; env_stack.back().future_map[fut] = value;
} else { } else {
std::ostringstream os; std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". " os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
} }
TracingState::TracingState() TracingState::TracingState()
: env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {} : graph(new Graph()), env_stack{Frame()} {}
TracingState::~TracingState() = default; TracingState::~TracingState() = default;

View File

@ -52,6 +52,27 @@ struct TORCH_API TracingState
TracingState(); TracingState();
~TracingState(); ~TracingState();
std::shared_ptr<Graph> graph;
bool warn = true;
bool force_outplace = false;
std::function<std::string(const Variable& var)> lookup_var_name_fn =
[](const Variable& var) { return ""; };
void enterFrame() {
env_stack.emplace_back();
}
void leaveFrame() {
env_stack.pop_back();
}
void setValue(const IValue& v, Value* value);
void delValue(const IValue& var);
Value* getValue(const IValue& var);
Value* getOutput(const IValue& var);
bool hasValue(const IValue& var) const;
private:
using WeakTensor = at::WeakTensor; using WeakTensor = at::WeakTensor;
struct WeakTensorHasher { struct WeakTensorHasher {
@ -66,7 +87,7 @@ struct TORCH_API TracingState
} }
}; };
struct TracingEnvironmentFrame { struct Frame {
std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq> std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
value_map; value_map;
// TODO weak refcount // TODO weak refcount
@ -74,14 +95,8 @@ struct TORCH_API TracingState
future_map; future_map;
}; };
using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>; std::vector<Frame> env_stack;
TracingEnvironmentStack env_stack;
std::shared_ptr<Graph> graph;
bool warn = true;
bool force_outplace = false;
std::function<std::string(const Variable& var)> lookup_var_name_fn =
[](const Variable& var) { return ""; };
}; };
// This is meant to be used as a thread local place, where we can store extra // This is meant to be used as a thread local place, where we can store extra
@ -182,11 +197,11 @@ struct TORCH_API NoWarn {
struct WithNestedTracingFrame { struct WithNestedTracingFrame {
WithNestedTracingFrame() { WithNestedTracingFrame() {
getTracingState()->env_stack.emplace_back(); getTracingState()->enterFrame();
} }
~WithNestedTracingFrame() { ~WithNestedTracingFrame() {
getTracingState()->env_stack.pop_back(); getTracingState()->leaveFrame();
} }
}; };
TORCH_API void recordSourceLocation(Node* n); TORCH_API void recordSourceLocation(Node* n);
@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
// involving this variable know which node in the IR to reference. // involving this variable know which node in the IR to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value); TORCH_API void setValueTrace(const IValue& v, Value* value);
TORCH_API void delValueTrace(const Variable& var); TORCH_API void delValueTrace(const IValue& var);
TORCH_API std::function<void()> pauseTracing(); TORCH_API std::function<void()> pauseTracing();
TORCH_API Value* getValueTrace(const IValue& var); TORCH_API Value* getValueTrace(const IValue& var);
TORCH_API Value* getOutputTrace(
const std::shared_ptr<TracingState>& state,
const Variable& var);
TORCH_API Value* getNestedOutputTrace(
const std::shared_ptr<TracingState>& state,
const IValue& iv);
struct TypedStack : public std::pair<Stack, TupleTypePtr> struct TypedStack : public std::pair<Stack, TupleTypePtr>
{ {
using pair::pair; using pair::pair;