mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
clean up the TracingState API (#21514)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21514 ghimport-source-id: 6a9b6fdd7e696ea29e8715482708efe897230e4d Reviewed By: jamesr66a Differential Revision: D15719980 Pulled By: zdevito fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862
This commit is contained in:
parent
8c5f3acfc0
commit
dd0faf4366
|
|
@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
|
|||
return [state]() { tracer::setTracingState(state); };
|
||||
}
|
||||
|
||||
void delValueTrace(const Variable& var) {
|
||||
AT_ASSERT(var.defined());
|
||||
auto& env_stack = getTracingState()->env_stack;
|
||||
void delValueTrace(const IValue& var) {
|
||||
getTracingState()->delValue(var);
|
||||
}
|
||||
void TracingState::delValue(const IValue& var) {
|
||||
at::Tensor t = var.toTensor();
|
||||
AT_ASSERT(t.defined());
|
||||
for (size_t i = 0; i < env_stack.size(); ++i) {
|
||||
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
|
||||
|
||||
auto it = value_map.find(var);
|
||||
auto it = value_map.find(t);
|
||||
if (it == value_map.end()) {
|
||||
continue;
|
||||
}
|
||||
value_map.erase(it);
|
||||
}
|
||||
getTracingState()->env_stack.back().value_map.erase(var);
|
||||
}
|
||||
|
||||
// Given a IValue 'var', return the 'node' which represents the instruction
|
||||
|
|
@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
|
|||
// zero. This is one of the cases where a Variable can be created inside of a
|
||||
// trace, and if we treat it as a constant, everything will work out.
|
||||
Value* getValueTrace(const IValue& var) {
|
||||
auto& state = getTracingState();
|
||||
auto& env_stack = getTracingState()->env_stack;
|
||||
|
||||
return getTracingState()->getValue(var);
|
||||
}
|
||||
Value* TracingState::getValue(const IValue& var) {
|
||||
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
|
||||
if (var.isTensorList()) {
|
||||
return state->graph
|
||||
->insertNode(state->graph->createList(
|
||||
return graph
|
||||
->insertNode(graph->createList(
|
||||
TensorType::get(),
|
||||
fmap(
|
||||
var.toTensorListRef(),
|
||||
[](const IValue& val) { return getValueTrace(val); })))
|
||||
->output();
|
||||
} else if (var.isTuple()) {
|
||||
return state->graph
|
||||
->insertNode(state->graph->createTuple(fmap(
|
||||
return graph
|
||||
->insertNode(graph->createTuple(fmap(
|
||||
var.toTuple()->elements(),
|
||||
[](const IValue& val) { return getValueTrace(val); })))
|
||||
->output();
|
||||
} if (var.isTensor()) {
|
||||
auto ten = var.toTensor();
|
||||
if (!ten.defined()) {
|
||||
Node* n = state->graph->createNone(TensorType::get());
|
||||
return state->graph->insertNode(n)->output();
|
||||
Node* n = graph->createNone(TensorType::get());
|
||||
return graph->insertNode(n)->output();
|
||||
}
|
||||
for (size_t i = 0; i < env_stack.size(); ++i) {
|
||||
auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
|
||||
|
|
@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
|
|||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
|
||||
Value* constant = state->graph->insertConstant(ten);
|
||||
Value* constant = graph->insertConstant(ten);
|
||||
recordSourceLocation(constant->node());
|
||||
constant->inferTypeFrom(ten);
|
||||
auto it = env_stack.back().value_map.find(ten);
|
||||
|
|
@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
|
|||
} else {
|
||||
// If the values are non-tensors, we try to create constants
|
||||
// and bake those constants into the traced graph
|
||||
auto constant = tryInsertConstant(*state->graph, var);
|
||||
auto constant = tryInsertConstant(*graph, var);
|
||||
if (constant) {
|
||||
recordSourceLocation(constant.value()->node());
|
||||
return *constant;
|
||||
|
|
@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
|
|||
throw std::runtime_error(os.str());
|
||||
}
|
||||
}
|
||||
|
||||
Value* getOutputTrace(
|
||||
const std::shared_ptr<TracingState>& state,
|
||||
const Variable& var) {
|
||||
if (!var.defined()) {
|
||||
Node* n = state->graph->createNone(TensorType::get());
|
||||
return state->graph->insertNode(n)->output();
|
||||
bool TracingState::hasValue(const IValue& var) const {
|
||||
if (var.isTensor()) {
|
||||
at::Tensor t = var.toTensor();
|
||||
for(const auto & frame : env_stack) {
|
||||
if (frame.value_map.count(t)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& value_map = getTracingState()->env_stack.back().value_map;
|
||||
auto it = value_map.find(var);
|
||||
if (it == value_map.end()) {
|
||||
std::ostringstream os;
|
||||
os << "output of traced region did not have observable "
|
||||
<< "data dependence with trace inputs; this probably indicates your program "
|
||||
<< "cannot be understood by the tracer.";
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
return it->second;
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* getNestedOutputTrace(
|
||||
const std::shared_ptr<TracingState>& state,
|
||||
const IValue& iv) {
|
||||
if (iv.isTensor()) {
|
||||
return getOutputTrace(state, iv.toTensor());
|
||||
|
||||
Value* TracingState::getOutput(const IValue& iv) {
|
||||
if (iv.isTensor()) {
|
||||
at::Tensor var = iv.toTensor();
|
||||
if (!var.defined()) {
|
||||
Node *n = graph->createNone(TensorType::get());
|
||||
return graph->insertNode(n)->output();
|
||||
}
|
||||
|
||||
auto &value_map = getTracingState()->env_stack.back().value_map;
|
||||
auto it = value_map.find(var);
|
||||
if (it == value_map.end()) {
|
||||
std::ostringstream os;
|
||||
os << "output of traced region did not have observable "
|
||||
<< "data dependence with trace inputs; this probably indicates your "
|
||||
"program "
|
||||
<< "cannot be understood by the tracer.";
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
return it->second;
|
||||
} else if (iv.isTuple()) {
|
||||
const auto& elems = iv.toTuple()->elements();
|
||||
auto tuple_node =
|
||||
state->graph->createTuple(fmap(elems, [&state](const IValue& ival) {
|
||||
return getNestedOutputTrace(state, ival);
|
||||
graph->createTuple(fmap(elems, [&](const IValue& ival) {
|
||||
return getOutput(ival);
|
||||
}));
|
||||
state->graph->insertNode(tuple_node);
|
||||
graph->insertNode(tuple_node);
|
||||
return tuple_node->output();
|
||||
} else {
|
||||
AT_ERROR(
|
||||
|
|
@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
|
|||
if (type->isSubtypeOf(TensorType::get())) {
|
||||
auto input_tensor = input.toTensor();
|
||||
auto name = Variable(input_tensor).name();
|
||||
auto& value_map = state->env_stack.back().value_map;
|
||||
if (value_map.find(input_tensor) != value_map.end()) {
|
||||
if (state->hasValue(input)) {
|
||||
input_tensor = input_tensor.view(input_tensor.sizes());
|
||||
}
|
||||
value->setUniqueName(name);
|
||||
value_map[input_tensor] = value;
|
||||
state->setValue(input_tensor, value);
|
||||
return input_tensor;
|
||||
} else if (auto tuple_type = type->cast<TupleType>()) {
|
||||
auto unpack_node =
|
||||
|
|
@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
|
|||
auto& state = getTracingState();
|
||||
size_t i = 0;
|
||||
for (auto& output : outputs) {
|
||||
state->graph->registerOutput(getNestedOutputTrace(state, output));
|
||||
state->graph->registerOutput(state->getOutput(output));
|
||||
i++;
|
||||
}
|
||||
setTracingState(nullptr);
|
||||
|
|
@ -342,36 +349,36 @@ void abandon() {
|
|||
}
|
||||
|
||||
void setValueTrace(const IValue& v, Value* value) {
|
||||
return getTracingState()->setValue(v, value);
|
||||
}
|
||||
void TracingState::setValue(const IValue& v, Value* value) {
|
||||
if (v.isTensor()) {
|
||||
auto var = v.toTensor();
|
||||
AT_ASSERT(var.defined());
|
||||
getTracingState()->env_stack.back().value_map[var] = value;
|
||||
env_stack.back().value_map[var] = value;
|
||||
} else if (v.isTensorList()) {
|
||||
auto& outputs = v.toTensorList()->elements();
|
||||
auto graph = getTracingState()->graph;
|
||||
Node* unpack_node =
|
||||
graph->insertNode(graph->createListUnpack(value, outputs.size()));
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
setValueTrace(outputs[i], unpack_node->outputs()[i]);
|
||||
setValue(outputs[i], unpack_node->outputs()[i]);
|
||||
}
|
||||
} else if (v.isTuple()) {
|
||||
auto& outputs = v.toTuple()->elements();
|
||||
auto graph = getTracingState()->graph;
|
||||
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
setValueTrace(outputs[i], unpack_node->outputs()[i]);
|
||||
setValue(outputs[i], unpack_node->outputs()[i]);
|
||||
}
|
||||
} else if (v.isGenericList()) {
|
||||
auto elements = v.toGenericListRef();
|
||||
auto graph = getTracingState()->graph;
|
||||
Node* unpack_node =
|
||||
graph->insertNode(graph->createListUnpack(value, elements.size()));
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
setValueTrace(elements[i], unpack_node->outputs()[i]);
|
||||
setValue(elements[i], unpack_node->outputs()[i]);
|
||||
}
|
||||
} else if (v.isFuture()) {
|
||||
auto fut = v.toFuture();
|
||||
getTracingState()->env_stack.back().future_map[fut] = value;
|
||||
env_stack.back().future_map[fut] = value;
|
||||
} else {
|
||||
std::ostringstream os;
|
||||
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
|
||||
|
|
@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
|
|||
}
|
||||
|
||||
TracingState::TracingState()
|
||||
: env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {}
|
||||
: graph(new Graph()), env_stack{Frame()} {}
|
||||
|
||||
TracingState::~TracingState() = default;
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,27 @@ struct TORCH_API TracingState
|
|||
TracingState();
|
||||
~TracingState();
|
||||
|
||||
std::shared_ptr<Graph> graph;
|
||||
bool warn = true;
|
||||
bool force_outplace = false;
|
||||
std::function<std::string(const Variable& var)> lookup_var_name_fn =
|
||||
[](const Variable& var) { return ""; };
|
||||
|
||||
void enterFrame() {
|
||||
env_stack.emplace_back();
|
||||
}
|
||||
|
||||
void leaveFrame() {
|
||||
env_stack.pop_back();
|
||||
}
|
||||
|
||||
void setValue(const IValue& v, Value* value);
|
||||
void delValue(const IValue& var);
|
||||
Value* getValue(const IValue& var);
|
||||
Value* getOutput(const IValue& var);
|
||||
bool hasValue(const IValue& var) const;
|
||||
|
||||
private:
|
||||
using WeakTensor = at::WeakTensor;
|
||||
|
||||
struct WeakTensorHasher {
|
||||
|
|
@ -66,7 +87,7 @@ struct TORCH_API TracingState
|
|||
}
|
||||
};
|
||||
|
||||
struct TracingEnvironmentFrame {
|
||||
struct Frame {
|
||||
std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
|
||||
value_map;
|
||||
// TODO weak refcount
|
||||
|
|
@ -74,14 +95,8 @@ struct TORCH_API TracingState
|
|||
future_map;
|
||||
};
|
||||
|
||||
using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
|
||||
std::vector<Frame> env_stack;
|
||||
|
||||
TracingEnvironmentStack env_stack;
|
||||
std::shared_ptr<Graph> graph;
|
||||
bool warn = true;
|
||||
bool force_outplace = false;
|
||||
std::function<std::string(const Variable& var)> lookup_var_name_fn =
|
||||
[](const Variable& var) { return ""; };
|
||||
};
|
||||
|
||||
// This is meant to be used as a thread local place, where we can store extra
|
||||
|
|
@ -182,11 +197,11 @@ struct TORCH_API NoWarn {
|
|||
|
||||
struct WithNestedTracingFrame {
|
||||
WithNestedTracingFrame() {
|
||||
getTracingState()->env_stack.emplace_back();
|
||||
getTracingState()->enterFrame();
|
||||
}
|
||||
|
||||
~WithNestedTracingFrame() {
|
||||
getTracingState()->env_stack.pop_back();
|
||||
getTracingState()->leaveFrame();
|
||||
}
|
||||
};
|
||||
TORCH_API void recordSourceLocation(Node* n);
|
||||
|
|
@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
|
|||
// involving this variable know which node in the IR to reference.
|
||||
TORCH_API void setValueTrace(const IValue& v, Value* value);
|
||||
|
||||
TORCH_API void delValueTrace(const Variable& var);
|
||||
TORCH_API void delValueTrace(const IValue& var);
|
||||
|
||||
TORCH_API std::function<void()> pauseTracing();
|
||||
|
||||
TORCH_API Value* getValueTrace(const IValue& var);
|
||||
|
||||
TORCH_API Value* getOutputTrace(
|
||||
const std::shared_ptr<TracingState>& state,
|
||||
const Variable& var);
|
||||
|
||||
TORCH_API Value* getNestedOutputTrace(
|
||||
const std::shared_ptr<TracingState>& state,
|
||||
const IValue& iv);
|
||||
|
||||
struct TypedStack : public std::pair<Stack, TupleTypePtr>
|
||||
{
|
||||
using pair::pair;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user