clean up the TracingState API (#21514)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21514
ghimport-source-id: 6a9b6fdd7e696ea29e8715482708efe897230e4d

Reviewed By: jamesr66a

Differential Revision: D15719980

Pulled By: zdevito

fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862
Zachary DeVito 2019-06-07 20:52:25 -07:00 committed by Facebook Github Bot
parent 8c5f3acfc0
commit dd0faf4366
2 changed files with 88 additions and 74 deletions
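For orientation: the hunks below move the tracer's per-frame bookkeeping behind TracingState member functions (setValue, getValue, delValue, hasValue, getOutput, plus enterFrame/leaveFrame). The setValueTrace/getValueTrace/delValueTrace free functions in torch::jit::tracer stay as thin wrappers, while getOutputTrace/getNestedOutputTrace fold into TracingState::getOutput. A minimal sketch of the consolidated surface, not taken from this PR; `iv` and `v` are placeholder names, and the header path is as of this commit:

#include <torch/csrc/jit/tracer.h>

// Hedged sketch: exercising the member-function API, assuming a trace is active.
void tracing_state_api_sketch(const c10::IValue& iv, torch::jit::Value* v) {
  auto& state = torch::jit::tracer::getTracingState();
  state->setValue(iv, v);            // previously: direct writes to env_stack.back().value_map
  torch::jit::Value* traced = state->getValue(iv);  // searches frames innermost-first
  if (state->hasValue(iv)) {         // non-throwing lookup introduced in this diff
    state->delValue(iv);             // erases the mapping from any frame that holds it
  }
  (void)traced;                      // silence unused-variable warning in the sketch
}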


@@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
return [state]() { tracer::setTracingState(state); };
}
void delValueTrace(const Variable& var) {
AT_ASSERT(var.defined());
auto& env_stack = getTracingState()->env_stack;
void delValueTrace(const IValue& var) {
getTracingState()->delValue(var);
}
void TracingState::delValue(const IValue& var) {
at::Tensor t = var.toTensor();
AT_ASSERT(t.defined());
for (size_t i = 0; i < env_stack.size(); ++i) {
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
auto it = value_map.find(var);
auto it = value_map.find(t);
if (it == value_map.end()) {
continue;
}
value_map.erase(it);
}
getTracingState()->env_stack.back().value_map.erase(var);
}
// Given an IValue 'var', return the 'node' which represents the instruction
@@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
// zero. This is one of the cases where a Variable can be created inside of a
// trace, and if we treat it as a constant, everything will work out.
Value* getValueTrace(const IValue& var) {
auto& state = getTracingState();
auto& env_stack = getTracingState()->env_stack;
return getTracingState()->getValue(var);
}
Value* TracingState::getValue(const IValue& var) {
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
if (var.isTensorList()) {
return state->graph
->insertNode(state->graph->createList(
return graph
->insertNode(graph->createList(
TensorType::get(),
fmap(
var.toTensorListRef(),
[](const IValue& val) { return getValueTrace(val); })))
->output();
} else if (var.isTuple()) {
return state->graph
->insertNode(state->graph->createTuple(fmap(
return graph
->insertNode(graph->createTuple(fmap(
var.toTuple()->elements(),
[](const IValue& val) { return getValueTrace(val); })))
->output();
} if (var.isTensor()) {
auto ten = var.toTensor();
if (!ten.defined()) {
Node* n = state->graph->createNone(TensorType::get());
return state->graph->insertNode(n)->output();
Node* n = graph->createNone(TensorType::get());
return graph->insertNode(n)->output();
}
for (size_t i = 0; i < env_stack.size(); ++i) {
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
@@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
throw std::runtime_error(oss.str());
}
Value* constant = state->graph->insertConstant(ten);
Value* constant = graph->insertConstant(ten);
recordSourceLocation(constant->node());
constant->inferTypeFrom(ten);
auto it = env_stack.back().value_map.find(ten);
@@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
} else {
// If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph
auto constant = tryInsertConstant(*state->graph, var);
auto constant = tryInsertConstant(*graph, var);
if (constant) {
recordSourceLocation(constant.value()->node());
return *constant;
@@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
throw std::runtime_error(os.str());
}
}
Value* getOutputTrace(
const std::shared_ptr<TracingState>& state,
const Variable& var) {
if (!var.defined()) {
Node* n = state->graph->createNone(TensorType::get());
return state->graph->insertNode(n)->output();
bool TracingState::hasValue(const IValue& var) const {
if (var.isTensor()) {
at::Tensor t = var.toTensor();
for(const auto & frame : env_stack) {
if (frame.value_map.count(t)) {
return true;
}
}
}
auto& value_map = getTracingState()->env_stack.back().value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
std::ostringstream os;
os << "output of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
return it->second;
return false;
}
Value* getNestedOutputTrace(
const std::shared_ptr<TracingState>& state,
const IValue& iv) {
if (iv.isTensor()) {
return getOutputTrace(state, iv.toTensor());
Value* TracingState::getOutput(const IValue& iv) {
if (iv.isTensor()) {
at::Tensor var = iv.toTensor();
if (!var.defined()) {
Node *n = graph->createNone(TensorType::get());
return graph->insertNode(n)->output();
}
auto &value_map = getTracingState()->env_stack.back().value_map;
auto it = value_map.find(var);
if (it == value_map.end()) {
std::ostringstream os;
os << "output of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
return it->second;
} else if (iv.isTuple()) {
const auto& elems = iv.toTuple()->elements();
auto tuple_node =
state->graph->createTuple(fmap(elems, [&state](const IValue& ival) {
return getNestedOutputTrace(state, ival);
graph->createTuple(fmap(elems, [&](const IValue& ival) {
return getOutput(ival);
}));
state->graph->insertNode(tuple_node);
graph->insertNode(tuple_node);
return tuple_node->output();
} else {
AT_ERROR(
@@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
if (type->isSubtypeOf(TensorType::get())) {
auto input_tensor = input.toTensor();
auto name = Variable(input_tensor).name();
auto& value_map = state->env_stack.back().value_map;
if (value_map.find(input_tensor) != value_map.end()) {
if (state->hasValue(input)) {
input_tensor = input_tensor.view(input_tensor.sizes());
}
value->setUniqueName(name);
value_map[input_tensor] = value;
state->setValue(input_tensor, value);
return input_tensor;
} else if (auto tuple_type = type->cast<TupleType>()) {
auto unpack_node =
@@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
auto& state = getTracingState();
size_t i = 0;
for (auto& output : outputs) {
state->graph->registerOutput(getNestedOutputTrace(state, output));
state->graph->registerOutput(state->getOutput(output));
i++;
}
setTracingState(nullptr);
@@ -342,36 +349,36 @@ void abandon() {
}
void setValueTrace(const IValue& v, Value* value) {
return getTracingState()->setValue(v, value);
}
void TracingState::setValue(const IValue& v, Value* value) {
if (v.isTensor()) {
auto var = v.toTensor();
AT_ASSERT(var.defined());
getTracingState()->env_stack.back().value_map[var] = value;
env_stack.back().value_map[var] = value;
} else if (v.isTensorList()) {
auto& outputs = v.toTensorList()->elements();
auto graph = getTracingState()->graph;
Node* unpack_node =
graph->insertNode(graph->createListUnpack(value, outputs.size()));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
setValue(outputs[i], unpack_node->outputs()[i]);
}
} else if (v.isTuple()) {
auto& outputs = v.toTuple()->elements();
auto graph = getTracingState()->graph;
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
for (size_t i = 0; i < outputs.size(); ++i) {
setValueTrace(outputs[i], unpack_node->outputs()[i]);
setValue(outputs[i], unpack_node->outputs()[i]);
}
} else if (v.isGenericList()) {
auto elements = v.toGenericListRef();
auto graph = getTracingState()->graph;
Node* unpack_node =
graph->insertNode(graph->createListUnpack(value, elements.size()));
for (size_t i = 0; i < elements.size(); ++i) {
setValueTrace(elements[i], unpack_node->outputs()[i]);
setValue(elements[i], unpack_node->outputs()[i]);
}
} else if (v.isFuture()) {
auto fut = v.toFuture();
getTracingState()->env_stack.back().future_map[fut] = value;
env_stack.back().future_map[fut] = value;
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
@@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
}
TracingState::TracingState()
: env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {}
: graph(new Graph()), env_stack{Frame()} {}
TracingState::~TracingState() = default;
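As the .cpp hunks above show, the structured-value handling is relocated rather than changed: TracingState::setValue still unpacks tuples and tensor lists with TupleUnpack/ListUnpack nodes and recursively registers the elements. A hedged sketch of the tuple case; the function name and arguments are illustrative, not from this PR:

#include <ATen/ATen.h>
#include <torch/csrc/jit/tracer.h>

// Sketch: record a tuple-valued result through the new member API. setValue()
// inserts a TupleUnpack of `tuple_val` and maps each element tensor to the
// corresponding unpacked output in the innermost frame.
void record_tuple_result_sketch(const at::Tensor& t0,
                                const at::Tensor& t1,
                                torch::jit::Value* tuple_val) {
  auto& state = torch::jit::tracer::getTracingState();
  c10::IValue result =
      c10::ivalue::Tuple::create({c10::IValue(t0), c10::IValue(t1)});
  state->setValue(result, tuple_val);  // tuple_val: a Value* of TupleType in state->graph
}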


@@ -52,6 +52,27 @@ struct TORCH_API TracingState
TracingState();
~TracingState();
std::shared_ptr<Graph> graph;
bool warn = true;
bool force_outplace = false;
std::function<std::string(const Variable& var)> lookup_var_name_fn =
[](const Variable& var) { return ""; };
void enterFrame() {
env_stack.emplace_back();
}
void leaveFrame() {
env_stack.pop_back();
}
void setValue(const IValue& v, Value* value);
void delValue(const IValue& var);
Value* getValue(const IValue& var);
Value* getOutput(const IValue& var);
bool hasValue(const IValue& var) const;
private:
using WeakTensor = at::WeakTensor;
struct WeakTensorHasher {
@@ -66,7 +87,7 @@ struct TORCH_API TracingState
}
};
struct TracingEnvironmentFrame {
struct Frame {
std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
value_map;
// TODO weak refcount
@@ -74,14 +95,8 @@ struct TORCH_API TracingState
future_map;
};
using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
std::vector<Frame> env_stack;
TracingEnvironmentStack env_stack;
std::shared_ptr<Graph> graph;
bool warn = true;
bool force_outplace = false;
std::function<std::string(const Variable& var)> lookup_var_name_fn =
[](const Variable& var) { return ""; };
};
// This is meant to be used as a thread local place, where we can store extra
@@ -182,11 +197,11 @@ struct TORCH_API NoWarn {
struct WithNestedTracingFrame {
WithNestedTracingFrame() {
getTracingState()->env_stack.emplace_back();
getTracingState()->enterFrame();
}
~WithNestedTracingFrame() {
getTracingState()->env_stack.pop_back();
getTracingState()->leaveFrame();
}
};
TORCH_API void recordSourceLocation(Node* n);
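On the header side, WithNestedTracingFrame now goes through the new enterFrame()/leaveFrame() helpers instead of touching env_stack directly; the RAII usage pattern is unchanged. A small sketch of the intended scoping, assuming an active trace:

#include <torch/csrc/jit/tracer.h>

// Sketch: scope a nested value environment while tracing a sub-block.
void trace_nested_block_sketch() {
  torch::jit::tracer::WithNestedTracingFrame guard;  // ctor: enterFrame() pushes a fresh Frame
  // ... trace the nested block here; setValue() records into this innermost
  // frame, while getValue()/hasValue() still consult the outer frames ...
}  // dtor: leaveFrame() pops the Frame again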
@@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
// involving this variable know which node in the IR to reference.
TORCH_API void setValueTrace(const IValue& v, Value* value);
TORCH_API void delValueTrace(const Variable& var);
TORCH_API void delValueTrace(const IValue& var);
TORCH_API std::function<void()> pauseTracing();
TORCH_API Value* getValueTrace(const IValue& var);
TORCH_API Value* getOutputTrace(
const std::shared_ptr<TracingState>& state,
const Variable& var);
TORCH_API Value* getNestedOutputTrace(
const std::shared_ptr<TracingState>& state,
const IValue& iv);
struct TypedStack : public std::pair<Stack, TupleTypePtr>
{
using pair::pair;