Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
clean up the TracingState API (#21514)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21514
ghimport-source-id: 6a9b6fdd7e696ea29e8715482708efe897230e4d

Reviewed By: jamesr66a
Differential Revision: D15719980
Pulled By: zdevito
fbshipit-source-id: 3de2746c3f3c3de4111b4cb73f4c4acedbf28862
This commit is contained in:
parent 8c5f3acfc0
commit dd0faf4366
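At a glance, the change consolidates the tracer's value bookkeeping onto TracingState itself: the struct gains setValue/delValue/getValue/getOutput/hasValue plus enterFrame()/leaveFrame(), and the existing free functions become thin wrappers that forward to the current tracing state instead of reaching into env_stack directly. A minimal sketch of the resulting call shape, condensed from the diff below (the second group of hunks, starting at the struct TORCH_API TracingState hunk, is the header side of the change):

void setValueTrace(const IValue& v, Value* value) {
  return getTracingState()->setValue(v, value);  // forwards to the member function
}
void delValueTrace(const IValue& var) {
  getTracingState()->delValue(var);              // likewise a thin wrapper
}
Value* getValueTrace(const IValue& var) {
  return getTracingState()->getValue(var);       // likewise a thin wrapper
}
// Frame management also moves behind the API: WithNestedTracingFrame now calls
// enterFrame()/leaveFrame() instead of pushing and popping env_stack by hand.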
@@ -52,19 +52,21 @@ std::function<void()> pauseTracing() {
   return [state]() { tracer::setTracingState(state); };
 }
 
-void delValueTrace(const Variable& var) {
-  AT_ASSERT(var.defined());
-  auto& env_stack = getTracingState()->env_stack;
+void delValueTrace(const IValue& var) {
+  getTracingState()->delValue(var);
+}
+void TracingState::delValue(const IValue& var) {
+  at::Tensor t = var.toTensor();
+  AT_ASSERT(t.defined());
   for (size_t i = 0; i < env_stack.size(); ++i) {
     auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
 
-    auto it = value_map.find(var);
+    auto it = value_map.find(t);
     if (it == value_map.end()) {
       continue;
     }
     value_map.erase(it);
   }
-  getTracingState()->env_stack.back().value_map.erase(var);
 }
 
 // Given a IValue 'var', return the 'node' which represents the instruction
@@ -82,29 +84,29 @@ void delValueTrace(const Variable& var) {
 // zero. This is one of the cases where a Variable can be created inside of a
 // trace, and if we treat it as a constant, everything will work out.
 Value* getValueTrace(const IValue& var) {
-  auto& state = getTracingState();
-  auto& env_stack = getTracingState()->env_stack;
-
+  return getTracingState()->getValue(var);
+}
+Value* TracingState::getValue(const IValue& var) {
   // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
   if (var.isTensorList()) {
-    return state->graph
-        ->insertNode(state->graph->createList(
+    return graph
+        ->insertNode(graph->createList(
             TensorType::get(),
             fmap(
                 var.toTensorListRef(),
                 [](const IValue& val) { return getValueTrace(val); })))
         ->output();
   } else if (var.isTuple()) {
-    return state->graph
-        ->insertNode(state->graph->createTuple(fmap(
+    return graph
+        ->insertNode(graph->createTuple(fmap(
             var.toTuple()->elements(),
             [](const IValue& val) { return getValueTrace(val); })))
         ->output();
   } if (var.isTensor()) {
     auto ten = var.toTensor();
     if (!ten.defined()) {
-      Node* n = state->graph->createNone(TensorType::get());
-      return state->graph->insertNode(n)->output();
+      Node* n = graph->createNone(TensorType::get());
+      return graph->insertNode(n)->output();
     }
     for (size_t i = 0; i < env_stack.size(); ++i) {
       auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
@@ -132,7 +134,7 @@ Value* getValueTrace(const IValue& var) {
       throw std::runtime_error(oss.str());
     }
 
-    Value* constant = state->graph->insertConstant(ten);
+    Value* constant = graph->insertConstant(ten);
     recordSourceLocation(constant->node());
     constant->inferTypeFrom(ten);
     auto it = env_stack.back().value_map.find(ten);
@@ -155,7 +157,7 @@ Value* getValueTrace(const IValue& var) {
   } else {
     // If the values are non-tensors, we try to create constants
     // and bake those constants into the traced graph
-    auto constant = tryInsertConstant(*state->graph, var);
+    auto constant = tryInsertConstant(*graph, var);
     if (constant) {
       recordSourceLocation(constant.value()->node());
       return *constant;
@@ -167,39 +169,45 @@ Value* getValueTrace(const IValue& var) {
     throw std::runtime_error(os.str());
   }
 }
+bool TracingState::hasValue(const IValue& var) const {
+  if (var.isTensor()) {
+    at::Tensor t = var.toTensor();
+    for(const auto & frame : env_stack) {
+      if (frame.value_map.count(t)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
 
-Value* getOutputTrace(
-    const std::shared_ptr<TracingState>& state,
-    const Variable& var) {
-  if (!var.defined()) {
-    Node* n = state->graph->createNone(TensorType::get());
-    return state->graph->insertNode(n)->output();
-  }
-
-  auto& value_map = getTracingState()->env_stack.back().value_map;
-  auto it = value_map.find(var);
-  if (it == value_map.end()) {
-    std::ostringstream os;
-    os << "output of traced region did not have observable "
-       << "data dependence with trace inputs; this probably indicates your program "
-       << "cannot be understood by the tracer.";
-    throw std::runtime_error(os.str());
-  }
-  return it->second;
-}
-
-Value* getNestedOutputTrace(
-    const std::shared_ptr<TracingState>& state,
-    const IValue& iv) {
+Value* TracingState::getOutput(const IValue& iv) {
   if (iv.isTensor()) {
-    return getOutputTrace(state, iv.toTensor());
+    at::Tensor var = iv.toTensor();
+    if (!var.defined()) {
+      Node *n = graph->createNone(TensorType::get());
+      return graph->insertNode(n)->output();
+    }
+
+    auto &value_map = getTracingState()->env_stack.back().value_map;
+    auto it = value_map.find(var);
+    if (it == value_map.end()) {
+      std::ostringstream os;
+      os << "output of traced region did not have observable "
+         << "data dependence with trace inputs; this probably indicates your "
+            "program "
+         << "cannot be understood by the tracer.";
+      throw std::runtime_error(os.str());
+    }
+    return it->second;
   } else if (iv.isTuple()) {
     const auto& elems = iv.toTuple()->elements();
     auto tuple_node =
-        state->graph->createTuple(fmap(elems, [&state](const IValue& ival) {
-          return getNestedOutputTrace(state, ival);
+        graph->createTuple(fmap(elems, [&](const IValue& ival) {
+          return getOutput(ival);
         }));
-    state->graph->insertNode(tuple_node);
+    graph->insertNode(tuple_node);
     return tuple_node->output();
   } else {
     AT_ERROR(
@@ -213,12 +221,11 @@ static IValue addInput(const std::shared_ptr<TracingState> & state, const IValue
   if (type->isSubtypeOf(TensorType::get())) {
     auto input_tensor = input.toTensor();
     auto name = Variable(input_tensor).name();
-    auto& value_map = state->env_stack.back().value_map;
-    if (value_map.find(input_tensor) != value_map.end()) {
+    if (state->hasValue(input)) {
       input_tensor = input_tensor.view(input_tensor.sizes());
     }
     value->setUniqueName(name);
-    value_map[input_tensor] = value;
+    state->setValue(input_tensor, value);
     return input_tensor;
   } else if (auto tuple_type = type->cast<TupleType>()) {
     auto unpack_node =
@@ -330,7 +337,7 @@ void exit(const Stack& outputs) {
   auto& state = getTracingState();
   size_t i = 0;
   for (auto& output : outputs) {
-    state->graph->registerOutput(getNestedOutputTrace(state, output));
+    state->graph->registerOutput(state->getOutput(output));
     i++;
   }
   setTracingState(nullptr);
@@ -342,36 +349,36 @@ void abandon() {
 }
 
 void setValueTrace(const IValue& v, Value* value) {
+  return getTracingState()->setValue(v, value);
+}
+void TracingState::setValue(const IValue& v, Value* value) {
   if (v.isTensor()) {
     auto var = v.toTensor();
     AT_ASSERT(var.defined());
-    getTracingState()->env_stack.back().value_map[var] = value;
+    env_stack.back().value_map[var] = value;
   } else if (v.isTensorList()) {
     auto& outputs = v.toTensorList()->elements();
-    auto graph = getTracingState()->graph;
     Node* unpack_node =
         graph->insertNode(graph->createListUnpack(value, outputs.size()));
     for (size_t i = 0; i < outputs.size(); ++i) {
-      setValueTrace(outputs[i], unpack_node->outputs()[i]);
+      setValue(outputs[i], unpack_node->outputs()[i]);
     }
   } else if (v.isTuple()) {
     auto& outputs = v.toTuple()->elements();
-    auto graph = getTracingState()->graph;
     Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
     for (size_t i = 0; i < outputs.size(); ++i) {
-      setValueTrace(outputs[i], unpack_node->outputs()[i]);
+      setValue(outputs[i], unpack_node->outputs()[i]);
     }
   } else if (v.isGenericList()) {
     auto elements = v.toGenericListRef();
-    auto graph = getTracingState()->graph;
     Node* unpack_node =
         graph->insertNode(graph->createListUnpack(value, elements.size()));
     for (size_t i = 0; i < elements.size(); ++i) {
-      setValueTrace(elements[i], unpack_node->outputs()[i]);
+      setValue(elements[i], unpack_node->outputs()[i]);
     }
   } else if (v.isFuture()) {
     auto fut = v.toFuture();
-    getTracingState()->env_stack.back().future_map[fut] = value;
+    env_stack.back().future_map[fut] = value;
   } else {
     std::ostringstream os;
     os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
@@ -560,7 +567,7 @@ void setTracingState(std::shared_ptr<TracingState> state) {
 }
 
 TracingState::TracingState()
-    : env_stack{TracingEnvironmentFrame()}, graph(new Graph()) {}
+    : graph(new Graph()), env_stack{Frame()} {}
 
 TracingState::~TracingState() = default;
 
@@ -52,6 +52,27 @@ struct TORCH_API TracingState
   TracingState();
   ~TracingState();
 
+  std::shared_ptr<Graph> graph;
+  bool warn = true;
+  bool force_outplace = false;
+  std::function<std::string(const Variable& var)> lookup_var_name_fn =
+      [](const Variable& var) { return ""; };
+
+  void enterFrame() {
+    env_stack.emplace_back();
+  }
+
+  void leaveFrame() {
+    env_stack.pop_back();
+  }
+
+  void setValue(const IValue& v, Value* value);
+  void delValue(const IValue& var);
+  Value* getValue(const IValue& var);
+  Value* getOutput(const IValue& var);
+  bool hasValue(const IValue& var) const;
+
+ private:
   using WeakTensor = at::WeakTensor;
 
   struct WeakTensorHasher {
@@ -66,7 +87,7 @@ struct TORCH_API TracingState
     }
   };
 
-  struct TracingEnvironmentFrame {
+  struct Frame {
     std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
         value_map;
     // TODO weak refcount
@@ -74,14 +95,8 @@ struct TORCH_API TracingState
         future_map;
   };
 
-  using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
+  std::vector<Frame> env_stack;
 
-  TracingEnvironmentStack env_stack;
-  std::shared_ptr<Graph> graph;
-  bool warn = true;
-  bool force_outplace = false;
-  std::function<std::string(const Variable& var)> lookup_var_name_fn =
-      [](const Variable& var) { return ""; };
 };
 
 // This is meant to be used as a thread local place, where we can store extra
@@ -182,11 +197,11 @@ struct TORCH_API NoWarn {
 
 struct WithNestedTracingFrame {
   WithNestedTracingFrame() {
-    getTracingState()->env_stack.emplace_back();
+    getTracingState()->enterFrame();
   }
 
   ~WithNestedTracingFrame() {
-    getTracingState()->env_stack.pop_back();
+    getTracingState()->leaveFrame();
   }
 };
 TORCH_API void recordSourceLocation(Node* n);
@@ -197,20 +212,12 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
 // involving this variable know which node in the IR to reference.
 TORCH_API void setValueTrace(const IValue& v, Value* value);
 
-TORCH_API void delValueTrace(const Variable& var);
+TORCH_API void delValueTrace(const IValue& var);
 
 TORCH_API std::function<void()> pauseTracing();
 
 TORCH_API Value* getValueTrace(const IValue& var);
 
-TORCH_API Value* getOutputTrace(
-    const std::shared_ptr<TracingState>& state,
-    const Variable& var);
-
-TORCH_API Value* getNestedOutputTrace(
-    const std::shared_ptr<TracingState>& state,
-    const IValue& iv);
-
 struct TypedStack : public std::pair<Stack, TupleTypePtr>
 {
   using pair::pair;
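For readers skimming the header changes above: env_stack becomes a std::vector<Frame> looked up only through the new member functions, getValue walks the frames from the innermost outward, and WithNestedTracingFrame pushes and pops frames via enterFrame()/leaveFrame(). A self-contained toy sketch of that frame-stack pattern follows; it is not PyTorch code, and every name in it (ToyState, WithToyFrame, etc.) is hypothetical:

// Toy sketch only: a stack of frames with innermost-first lookup and an RAII
// guard, in the spirit of TracingState::env_stack and WithNestedTracingFrame.
// All names here are hypothetical, not PyTorch APIs.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyState {
  struct Frame {
    std::unordered_map<std::string, int> value_map;
  };
  std::vector<Frame> env_stack{Frame()};

  void enterFrame() { env_stack.emplace_back(); }
  void leaveFrame() { env_stack.pop_back(); }

  void setValue(const std::string& k, int v) {
    env_stack.back().value_map[k] = v;
  }

  // Search frames innermost-first, the way TracingState::getValue walks
  // env_stack in the diff above.
  bool hasValue(const std::string& k) const {
    for (auto it = env_stack.rbegin(); it != env_stack.rend(); ++it) {
      if (it->value_map.count(k)) {
        return true;
      }
    }
    return false;
  }
};

// RAII frame guard, analogous to WithNestedTracingFrame.
struct WithToyFrame {
  explicit WithToyFrame(ToyState& s) : state(s) { state.enterFrame(); }
  ~WithToyFrame() { state.leaveFrame(); }
  ToyState& state;
};

int main() {
  ToyState state;
  state.setValue("outer", 1);
  {
    WithToyFrame guard(state);
    state.setValue("inner", 2);
    assert(state.hasValue("outer"));  // falls through to the outer frame
    assert(state.hasValue("inner"));
  }
  assert(!state.hasValue("inner"));   // inner frame popped by the guard
  return 0;
}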