Introduce scopes during tracing (#3016)

Luca Antiga 2017-12-04 18:19:06 +01:00 committed by Adam Paszke
parent 7ddcb91c7f
commit 4eb8e12765
11 changed files with 242 additions and 11 deletions

View File

@ -1,6 +1,6 @@
graph(%0 : Double(20, 16, 50, 40)
%1 : Double(13, 16, 3, 3)) {
- %2 : UNKNOWN_TYPE = Undefined()
- %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2)
+ %2 : UNKNOWN_TYPE = Undefined(), scope: Conv2d
+ %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2), scope: Conv2d
return (%3);
}

View File

@ -1,4 +1,4 @@
graph(%0 : Double(2, 2)) {
- %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0)
+ %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
return (%1);
}

View File

@ -0,0 +1,8 @@
graph(%0 : Double(1)
%1 : Double(1)) {
%2 : Double(1) = add[alpha={1}](%0, %1)
%3 : Double(1) = mul(%0, %2), scope: Foo
%4 : Double(1) = tanh(%3), scope: Foo/Bar
%5 : Double(1) = sigmoid(%4), scope: Foo
return (%5);
}
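This new expected output comes from the test_scopes test added to the test suite further down in this diff; reindented for readability, the traced function is:

    def f(x, y):
        out = x + y                      # recorded with no scope
        with torch.jit.scope('Foo', out):
            out = x * out                # scope: Foo
            with torch.jit.scope('Bar', out):
                out = torch.tanh(out)    # scope: Foo/Bar
            out = torch.sigmoid(out)     # scope: Foo
        return out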

View File

@ -69,6 +69,23 @@ class TestJit(TestCase):
        torch._C._jit_pass_lint(trace)
        self.assertExpected(str(trace))

    def test_scopes(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        def f(x, y):
            out = x + y
            with torch.jit.scope('Foo', out):
                out = x * out
                with torch.jit.scope('Bar', out):
                    out = torch.tanh(out)
                out = torch.sigmoid(out)
            return out

        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
        torch._C._jit_pass_lint(trace)
        self.assertExpected(str(trace))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_fusion(self):
        input = Variable(torch.randn(3, 10).float().cuda())

View File

@ -239,7 +239,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
printAttributes(out,n);
}
IR_END()
out << "(" << n->inputs() << ")\n";
out << "(" << n->inputs() << ")";
std::string scopeName = n->scopeName();
if (scopeName.empty()) {
out << "\n";
}
else {
out << ", ";
out << "scope: " << scopeName << "\n";
}
return out;
}
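In other words, a node now prints with a trailing scope annotation only when its scope path is non-empty; compare these two lines from the expected outputs above:

    %2 : Double(1) = add[alpha={1}](%0, %1)
    %4 : Double(1) = tanh(%3), scope: Foo/Bar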

View File

@ -70,6 +70,64 @@ struct SourceLocation {
std::string python_traceback;
};
// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink; it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope {
private:
  Scope* parent_;
  Symbol name_;
  std::vector<std::unique_ptr<Scope> > children_;
public:
  Scope() {
    name_ = stringToSymbol("");
    parent_ = NULL;
  }
  Scope(Scope* parent, Symbol name) {
    name_ = name;
    parent_ = parent;
  }
  Scope* push(Symbol name) {
    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
    return children_.back().get();
  }
  Scope* parent() {
    if (parent_ == NULL) {
      throw std::runtime_error("Cannot get parent from Scope with no parent");
    }
    return parent_;
  }
  bool isRoot() {
    return parent_ == NULL;
  }
  Scope* getRoot() {
    Scope* current = this;
    while (current->parent_) {
      current = current->parent_;
    }
    return current;
  }
  Symbol name() {
    return name_;
  }
  std::string namesFromRoot(const std::string& separator="/") {
    std::string out = std::string(symbolToString(this->name_));
    if (this->isRoot()) {
      return out;
    }
    Scope* parent = this->parent_;
    while (!parent->isRoot()) {
      out = std::string(symbolToString(parent->name_)) + separator + out;
      parent = parent->parent_;
    }
    return out;
  }
};
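To make the trie semantics concrete, here is a minimal Python sketch that mirrors the Scope class above (purely illustrative; the names follow the C++ members and nothing below is part of the commit):

    class PyScope(object):
        """Toy mirror of torch::jit::Scope: one node in a trie of nested scope names."""
        def __init__(self, parent=None, name=""):
            self.parent = parent
            self.name = name
            self.children = []  # owned children; the trie only ever grows

        def push(self, name):
            child = PyScope(self, name)
            self.children.append(child)
            return child

        def is_root(self):
            return self.parent is None

        def names_from_root(self, separator="/"):
            out = self.name
            if self.is_root():
                return out
            parent = self.parent
            while not parent.is_root():
                out = parent.name + separator + out
                parent = parent.parent
            return out

    root = PyScope()
    bar = root.push("Foo").push("Bar")   # like graph.push_scope("Foo"); graph.push_scope("Bar")
    assert bar.names_from_root() == "Foo/Bar"

Because the root scope is unnamed, namesFromRoot() on the root yields an empty string, which is why nodes created outside any scope print with no scope suffix at all (see the printNode change above).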
// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
using node_list = std::vector<Node*>;
@ -139,6 +197,9 @@ public:
const Node * node() const {
return node_;
}
Scope* scope();
void setScope(Scope* scope);
std::string scopeName() const;
Graph * owningGraph();
const Graph * owningGraph() const;
// TODO: make this more const correct
@ -197,6 +258,7 @@ private:
Graph* graph_;
std::shared_ptr<SourceLocation> source_location_;
size_t stage_;
Scope* scope_;
protected:
Node(Graph * graph_, NodeKind kind_); //defined after graph
public:
@ -223,6 +285,18 @@ public:
stage_ = s;
return this;
}
Scope* scope() {
return scope_;
}
void setScope(Scope* scope) {
scope_ = scope;
}
std::string scopeName() const {
if (scope_ == NULL) {
return "";
}
return scope_->namesFromRoot();
}
// NB: This returns an ArrayRef; that means that it will
// get invalidated if you resize inputs (e.g., using addInput)
// We can't return a std::vector<Node*>& because there's no
@ -534,6 +608,7 @@ protected:
// if you are going to preserve it.
virtual void cloneFrom(Node * s) {
setSourceLocation(s->getSourceLocation());
scope_ = s->scope_;
copyAttributes(*s);
}
};
@ -556,6 +631,9 @@ private:
size_t new_node_stage_;
std::shared_ptr<Scope> scope_root_;
Scope * current_scope_;
// holds outputs in a way that can be reflected
// as a Use object
// also used as the beginning/end of the circular node list to avoid
@ -564,11 +642,17 @@ private:
Node * const input_;
public:
- Graph()
+ Graph(std::shared_ptr<Scope> scope_root)
  : next_unique_(0)
  , new_node_stage_(0)
+ , scope_root_(scope_root)
+ , current_scope_(scope_root_.get())
  , output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}
+ Graph()
+ : Graph( std::make_shared<Scope>()) {}
at::ArrayRef<Value*> inputs() {
return input_->outputs();
}
@ -621,6 +705,18 @@ public:
const Node * return_node() const {
return output_;
}
void push_scope(const std::string& scope_name) {
current_scope_ = current_scope_->push(stringToSymbol(scope_name));
}
void pop_scope() {
current_scope_ = current_scope_->parent();
}
Scope * current_scope() {
return current_scope_;
}
std::shared_ptr<Scope> scope_root() {
return scope_root_;
}
Value * addInput(std::string name="") {
Value * v = input_->addOutput();
if (name != "") v->setUniqueName(name);
@ -676,7 +772,8 @@ public:
}
Node * createFusionGroup() {
auto n = create(kFusionGroup, 0);
- n->g_(kSubgraph,std::make_shared<Graph>());
+ auto subgraph = std::make_shared<Graph>(scope_root_);
+ n->g_(kSubgraph, subgraph);
return n;
}
Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
node_->graph_->all_values.emplace(this);
}
inline Scope* Value::scope() {
return node()->scope();
}
inline void Value::setScope(Scope* scope) {
node()->setScope(scope);
}
inline std::string Value::scopeName() const {
return node()->scopeName();
}
inline Graph * Value::owningGraph() {
return node()->owningGraph();
}
@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
inline Node::Node(Graph * graph_, NodeKind kind_) :
kind_(kind_),
graph_(graph_),
stage_(graph_->new_node_stage_) {
stage_(graph_->new_node_stage_),
scope_(graph_->current_scope_) {
graph_->all_nodes.emplace(this);
}

View File

@ -31,7 +31,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
throw std::logic_error("ToONNX: tracing state is expired");
}
- auto new_graph = std::make_shared<Graph>();
+ auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
std::unordered_map<void*, Value*> new_buffer_map;
torch::autograd::SymbolicContext ctx;
@ -137,6 +137,10 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
throw std::runtime_error(ss.str());
}
for (auto& el: outputs) {
el->setScope(n->scope());
}
setOutputs(op_name, n, outputs);
};
@ -208,6 +212,9 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
IR_IFM(node, CppOp)
if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn), node->getSourceLocation());
for (auto& el: outputs) {
el->setScope(node->scope());
}
setOutputs(value->name(), node, outputs);
} else {
cloneNode(node);

View File

@ -19,7 +19,7 @@ namespace torch { namespace jit {
void initPythonTracerBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
- py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
+ py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
// NB: no constructor; you have to get it from C++ code
.def("__repr__", [](const TracingState& s) {
std::ostringstream ss;
@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
ss << *s.graph;
return ss.str();
})
.def("push_scope", [](TracingState& s, const std::string& scope_name) {
ASSERT_UNEXPIRED("push_scope");
s.push_scope(scope_name);
})
.def("pop_scope", [](TracingState& s) {
ASSERT_UNEXPIRED("pop_scope");
s.pop_scope();
})
.def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
ASSERT_UNEXPIRED("export");
return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);
});
m.def("_get_tracing_state", [](const variable_list& vars) {
return getTracingState(vars);
});
m.def("_is_tracing", [](const variable_list& vars) {
return isTracing(vars);
});
}
}} // namespace torch::jit
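These bindings are the raw plumbing behind the Python helpers added to torch/jit/__init__.py further down: given variables that are currently being traced, you can fetch the TracingState and bracket a region of the computation in a named scope. A rough sketch of how they compose (illustrative only; run_in_scope is a hypothetical helper, and torch.jit.scope below wraps exactly this pattern):

    import torch

    def run_in_scope(scope_name, variables, fn):
        # Cheap check first, then the slower state lookup (mirrors get_tracing_state below).
        state = None
        if torch._C._is_tracing(variables):
            state = torch._C._get_tracing_state(variables)
        if state is not None:
            state.push_scope(scope_name)   # nodes traced from here on carry this scope
        try:
            return fn()
        finally:
            if state is not None:
                state.pop_scope()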

View File

@ -74,6 +74,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
bool is_complete() const {
return !is_expired() && graph->stage() == num_stages - 1;
}
void push_scope(const std::string& scope_name) {
graph->push_scope(scope_name);
}
void pop_scope() {
graph->pop_scope();
}
};
struct ValueTracingStateElem {

View File

@ -19,6 +19,30 @@ import copy
_flatten = torch._C._jit_flatten
# This global variable is set when we are tracing a *forwards* computation.
# It is intended to be a cheap way to test if tracing has occurred, before
# doing the slower path using `get_tracing_state` (below.)
_tracing = False
def get_tracing_state(args):
    if not torch._C._is_tracing(args):
        return None
    return torch._C._get_tracing_state(args)


@contextlib.contextmanager
def scope(scope_name, *vars):
    tracing_state = get_tracing_state(vars)
    if tracing_state:
        tracing_state.push_scope(scope_name)
    try:
        yield
    finally:
        if tracing_state:
            tracing_state.pop_scope()
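Because get_tracing_state returns None when the inputs are not being traced, the context manager degrades to a no-op in ordinary eager execution, so it can be left in model code permanently. A small usage sketch (illustrative; the scope name 'Block' is arbitrary):

    def forward_block(x, y):
        out = x + y
        with torch.jit.scope('Block', out):   # no-op unless `out` is being traced
            out = torch.sigmoid(x * out)
        return out

Nesting scope blocks produces slash-separated paths such as Foo/Bar in the printed graph, as exercised by test_scopes above.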
def compile(arg=None, nderivs=1, optimize=True, enabled=True):
"""
Decorator which marks a function or module class as eligible for
@ -237,13 +261,16 @@ class TracedModule(Module):
        self.nderivs = nderivs

    def forward(self, *args):
        global _tracing
        in_vars = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(self.state_dict(keep_vars=True).values())
        trace = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
        _tracing = True
        out = self.inner(*args)
        out_vars = _flatten(out)
        _tracing = False
        torch._C._tracer_exit(out_vars)
        return trace, out

View File

@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
import functools
import torch
@ -319,10 +319,42 @@ class Module(object):
        self._forward_hooks[handle.id] = hook
        return handle

    def _tracing_name(self, tracing_state):
        if not tracing_state._traced_module_stack:
            return None
        module = tracing_state._traced_module_stack[-1]
        for name, child in module.named_children():
            if child is self:
                return name
        return None

    def _slow_forward(self, *input, **kwargs):
        input_vars = tuple(torch.autograd.function._iter_variables(input))
        tracing_state = torch.jit.get_tracing_state(input_vars)
        if not tracing_state:
            return self.forward(*input, **kwargs)
        if not hasattr(tracing_state, '_traced_module_stack'):
            tracing_state._traced_module_stack = []
        name = self._tracing_name(tracing_state)
        if name:
            tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
        else:
            tracing_state.push_scope(self.__class__.__name__)
        tracing_state._traced_module_stack.append(self)
        try:
            result = self.forward(*input, **kwargs)
        finally:
            tracing_state.pop_scope()
            tracing_state._traced_module_stack.pop()
        return result
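Together with the __call__ dispatch just below, this gives module-based tracing automatic scope names: each submodule pushes ClassName[attr_name] when it can be located among its parent's named_children(), and falls back to the bare class name otherwise. A hedged sketch of the end-to-end effect (it assumes torch.jit.trace routes an nn.Module through the TracedModule path shown earlier, which is what sets torch.jit._tracing; the exact scope strings are indicative, not taken from an actual run):

    import torch
    from torch.autograd import Variable

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Tanh(),
    )
    x = Variable(torch.randn(2, 4))
    # Assumption: trace() accepts a Module here the same way it accepts a function in test_scopes.
    trace, out = torch.jit.trace(model, (x,), nderivs=0)
    print(trace)   # nodes should carry scopes along the lines of
                   # "Sequential/Linear[0]" and "Sequential/Tanh[1]"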
    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
-        result = self.forward(*input, **kwargs)
+        if torch.jit._tracing:
+            result = self._slow_forward(*input, **kwargs)
+        else:
+            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None: