Introduce scopes during tracing (#3016)
This commit is contained in:
  parent 7ddcb91c7f
  commit 4eb8e12765
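
In brief: while a trace is being recorded, IR nodes are now annotated with the name of the enclosing scope. Scopes can be pushed explicitly with the new torch.jit.scope context manager, or implicitly for every nn.Module call (see the Module.__call__ changes at the end of the diff). A condensed sketch of the explicit form, adapted from the test added below (variable and scope names follow that test):

import torch
from torch.autograd import Variable

def f(x, y):
    out = x + y                        # recorded with no scope
    with torch.jit.scope('Foo', out):  # nodes created here are tagged "scope: Foo"
        out = x * out
        with torch.jit.scope('Bar', out):
            out = torch.tanh(out)      # nested scope prints as "Foo/Bar"
        out = torch.sigmoid(out)
    return out

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
trace, z = torch.jit.trace(f, (x, y), nderivs=0)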

@@ -1,6 +1,6 @@
 graph(%0 : Double(20, 16, 50, 40)
       %1 : Double(13, 16, 3, 3)) {
-  %2 : UNKNOWN_TYPE = Undefined()
-  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2)
+  %2 : UNKNOWN_TYPE = Undefined(), scope: Conv2d
+  %3 : Double(20, 13, 48, 38), %4 : Handle = CppOp[ConvForward](%0, %1, %2), scope: Conv2d
   return (%3);
 }
@@ -1,4 +1,4 @@
 graph(%0 : Double(2, 2)) {
-  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0)
+  %1 : Double(2, 2), %2 : Handle = ^Dropout(0.6, True, False)(%0), scope: Dropout
   return (%1);
 }

test/expect/TestJit.test_scopes.expect (new file, 8 lines)
@@ -0,0 +1,8 @@
+graph(%0 : Double(1)
+      %1 : Double(1)) {
+  %2 : Double(1) = add[alpha={1}](%0, %1)
+  %3 : Double(1) = mul(%0, %2), scope: Foo
+  %4 : Double(1) = tanh(%3), scope: Foo/Bar
+  %5 : Double(1) = sigmoid(%4), scope: Foo
+  return (%5);
+}
@@ -69,6 +69,23 @@ class TestJit(TestCase):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))

+    def test_scopes(self):
+        x = Variable(torch.Tensor([0.4]), requires_grad=True)
+        y = Variable(torch.Tensor([0.7]), requires_grad=True)
+
+        def f(x, y):
+            out = x + y
+            with torch.jit.scope('Foo', out):
+                out = x * out
+                with torch.jit.scope('Bar', out):
+                    out = torch.tanh(out)
+                out = torch.sigmoid(out)
+            return out
+
+        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        torch._C._jit_pass_lint(trace)
+        self.assertExpected(str(trace))
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_lstm_fusion(self):
         input = Variable(torch.randn(3, 10).float().cuda())
@@ -239,7 +239,15 @@ std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const No
     printAttributes(out,n);
   }
   IR_END()
-  out << "(" << n->inputs() << ")\n";
+  out << "(" << n->inputs() << ")";
+  std::string scopeName = n->scopeName();
+  if (scopeName.empty()) {
+    out << "\n";
+  }
+  else {
+    out << ", ";
+    out << "scope: " << scopeName << "\n";
+  }
   return out;
 }
@@ -70,6 +70,64 @@ struct SourceLocation {
   std::string python_traceback;
 };

+// Scope is a node of a trie that represents the tree of nested scopes.
+// Individual scopes are pushed and popped from Graph, which holds a
+// pointer to the current scope. Each Node in Graph holds a pointer
+// to the scope that was current when the node was created.
+// The trie never needs to shrink, it only grows until it is disposed
+// of when Graph is deallocated. Hence, pointers to scopes held by nodes
+// will always be valid as long as Graph is alive.
+struct Scope {
+private:
+  Scope* parent_;
+  Symbol name_;
+  std::vector<std::unique_ptr<Scope> > children_;
+public:
+  Scope() {
+    name_ = stringToSymbol("");
+    parent_ = NULL;
+  }
+  Scope(Scope* parent, Symbol name) {
+    name_ = name;
+    parent_ = parent;
+  }
+  Scope* push(Symbol name) {
+    children_.push_back(std::unique_ptr<Scope>(new Scope(this, name)));
+    return children_.back().get();
+  }
+  Scope* parent() {
+    if (parent_ == NULL) {
+      throw std::runtime_error("Cannot get parent from Scope with no parent");
+    }
+    return parent_;
+  }
+  bool isRoot() {
+    return parent_ == NULL;
+  }
+  Scope* getRoot() {
+    Scope* current = this;
+    while (current->parent_) {
+      current = current->parent_;
+    }
+    return current;
+  }
+  Symbol name() {
+    return name_;
+  }
+  std::string namesFromRoot(const std::string& separator="/") {
+    std::string out = std::string(symbolToString(this->name_));
+    if (this->isRoot()) {
+      return out;
+    }
+    Scope* parent = this->parent_;
+    while (!parent->isRoot()) {
+      out = std::string(symbolToString(parent->name_)) + separator + out;
+      parent = parent->parent_;
+    }
+    return out;
+  }
+};
+
 // the list types are intentionally simple, but we type-def
 // them here so if we need to change them, refactoring will be easier
 using node_list = std::vector<Node*>;
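
For intuition, here is a small Python model of the Scope trie described in the comment above. The PyScope class is illustration only, not part of the diff; it mirrors push() and namesFromRoot() to show how the "Foo/Bar" strings in the new expect file are formed:

class PyScope:
    """Toy model of the C++ Scope trie (illustration only)."""
    def __init__(self, parent=None, name=""):
        self.parent = parent          # None marks the unnamed root scope
        self.name = name
        self.children = []            # keeps children alive, like the unique_ptr vector

    def push(self, name):
        child = PyScope(self, name)   # new scope nested under the current one
        self.children.append(child)
        return child

    def names_from_root(self, separator="/"):
        # Walk up to (but not including) the root, collecting names outermost-first.
        names = []
        node = self
        while node.parent is not None:
            names.append(node.name)
            node = node.parent
        return separator.join(reversed(names))

root = PyScope()
foo = root.push("Foo")          # Graph::push_scope("Foo")
bar = foo.push("Bar")           # Graph::push_scope("Bar")
print(bar.names_from_root())    # -> "Foo/Bar"
print(foo.names_from_root())    # -> "Foo"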
@@ -139,6 +197,9 @@ public:
   const Node * node() const {
     return node_;
   }
+  Scope* scope();
+  void setScope(Scope* scope);
+  std::string scopeName() const;
   Graph * owningGraph();
   const Graph * owningGraph() const;
   // TODO: make this more const correct
@@ -197,6 +258,7 @@ private:
   Graph* graph_;
   std::shared_ptr<SourceLocation> source_location_;
   size_t stage_;
+  Scope* scope_;
 protected:
   Node(Graph * graph_, NodeKind kind_); //defined after graph
 public:
@@ -223,6 +285,18 @@ public:
     stage_ = s;
     return this;
   }
+  Scope* scope() {
+    return scope_;
+  }
+  void setScope(Scope* scope) {
+    scope_ = scope;
+  }
+  std::string scopeName() const {
+    if (scope_ == NULL) {
+      return "";
+    }
+    return scope_->namesFromRoot();
+  }
   // NB: This returns an ArrayRef; that means that it will
   // get invalidated if you resize inputs (e.g., using addInput)
   // We can't return a std::vector<Node*>& because there's no
@@ -534,6 +608,7 @@ protected:
   // if you are going to preserve it.
   virtual void cloneFrom(Node * s) {
     setSourceLocation(s->getSourceLocation());
+    scope_ = s->scope_;
     copyAttributes(*s);
   }
 };
@@ -556,6 +631,9 @@ private:

   size_t new_node_stage_;

+  std::shared_ptr<Scope> scope_root_;
+  Scope * current_scope_;
+
   // holds outputs in a way that can be reflected
   // as a Use object
   // also used as the beginning/end of the circular node list to avoid
@@ -564,11 +642,17 @@ private:
   Node * const input_;

 public:
-  Graph()
+  Graph(std::shared_ptr<Scope> scope_root)
   : next_unique_(0)
   , new_node_stage_(0)
+  , scope_root_(scope_root)
+  , current_scope_(scope_root_.get())
   , output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}

+  Graph()
+  : Graph( std::make_shared<Scope>()) {}
+
   at::ArrayRef<Value*> inputs() {
     return input_->outputs();
   }
@@ -621,6 +705,18 @@ public:
   const Node * return_node() const {
     return output_;
   }
+  void push_scope(const std::string& scope_name) {
+    current_scope_ = current_scope_->push(stringToSymbol(scope_name));
+  }
+  void pop_scope() {
+    current_scope_ = current_scope_->parent();
+  }
+  Scope * current_scope() {
+    return current_scope_;
+  }
+  std::shared_ptr<Scope> scope_root() {
+    return scope_root_;
+  }
   Value * addInput(std::string name="") {
     Value * v = input_->addOutput();
     if (name != "") v->setUniqueName(name);
@@ -676,7 +772,8 @@ public:
   }
   Node * createFusionGroup() {
     auto n = create(kFusionGroup, 0);
-    n->g_(kSubgraph,std::make_shared<Graph>());
+    auto subgraph = std::make_shared<Graph>(scope_root_);
+    n->g_(kSubgraph, subgraph);
     return n;
   }
   Node * createPythonOp(THPObjectPtr&& pyobj, const std::string & cconv, bool is_legacy, std::vector<VariableFlags> && var_flags, pyobj_list&& scalar_args);
@@ -759,6 +856,18 @@ inline Value::Value(Node * node_, size_t offset_)
   node_->graph_->all_values.emplace(this);
 }

+inline Scope* Value::scope() {
+  return node()->scope();
+}
+
+inline void Value::setScope(Scope* scope) {
+  node()->setScope(scope);
+}
+
+inline std::string Value::scopeName() const {
+  return node()->scopeName();
+}
+
 inline Graph * Value::owningGraph() {
   return node()->owningGraph();
 }
@@ -779,7 +888,8 @@ inline void Value::replaceAllUsesWith(Value * newValue) {
 inline Node::Node(Graph * graph_, NodeKind kind_) :
   kind_(kind_),
   graph_(graph_),
-  stage_(graph_->new_node_stage_) {
+  stage_(graph_->new_node_stage_),
+  scope_(graph_->current_scope_) {
   graph_->all_nodes.emplace(this);
 }

@@ -31,7 +31,7 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     throw std::logic_error("ToONNX: tracing state is expired");
   }

-  auto new_graph = std::make_shared<Graph>();
+  auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
   std::unordered_map<void*, Value*> new_buffer_map;

   torch::autograd::SymbolicContext ctx;
@@ -137,6 +137,10 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
       throw std::runtime_error(ss.str());
     }

+    for (auto& el: outputs) {
+      el->setScope(n->scope());
+    }
+
     setOutputs(op_name, n, outputs);
   };

@@ -208,6 +212,9 @@ void ToONNX(std::shared_ptr<tracer::TracingState>& state) {
     IR_IFM(node, CppOp)
       if (auto fn = std::dynamic_pointer_cast<autograd::HasSymbolic>(value->fn)) {
         auto outputs = fn->symbolic(&ctx, fmap(node->inputs(), envFn), node->getSourceLocation());
+        for (auto& el: outputs) {
+          el->setScope(node->scope());
+        }
         setOutputs(value->name(), node, outputs);
       } else {
         cloneNode(node);
@@ -19,7 +19,7 @@ namespace torch { namespace jit {

 void initPythonTracerBindings(PyObject* module_) {
   auto m = py::handle(module_).cast<py::module>();
-  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState")
+  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
     // NB: no constructor; you have to get it from C++ code
     .def("__repr__", [](const TracingState& s) {
       std::ostringstream ss;
@@ -32,6 +32,14 @@ void initPythonTracerBindings(PyObject* module_) {
       ss << *s.graph;
       return ss.str();
     })
+    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
+      ASSERT_UNEXPIRED("push_scope");
+      s.push_scope(scope_name);
+    })
+    .def("pop_scope", [](TracingState& s) {
+      ASSERT_UNEXPIRED("pop_scope");
+      s.pop_scope();
+    })
     .def("export", [](TracingState& s, const std::vector<at::Tensor>& initializers, int64_t onnx_opset_version) {
       ASSERT_UNEXPIRED("export");
       return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
@@ -52,6 +60,12 @@ void initPythonTracerBindings(PyObject* module_) {
   m.def("_tracer_exit", [](variable_list var_outputs) {
     tracer::exit(var_outputs);
   });
+  m.def("_get_tracing_state", [](const variable_list& vars) {
+    return getTracingState(vars);
+  });
+  m.def("_is_tracing", [](const variable_list& vars) {
+    return isTracing(vars);
+  });
 }

 }} // namespace torch::jit
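
These bindings are what the Python helpers later in the diff build on. A rough sketch of driving them directly — only meaningful while a trace is active; normally you would go through torch.jit.scope instead, and the scope name 'MyScope' is purely illustrative:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3), requires_grad=True)

# _is_tracing / _get_tracing_state / push_scope / pop_scope mirror the bindings above.
# Outside of an active trace, _is_tracing returns False and the block is skipped.
if torch._C._is_tracing((x,)):
    state = torch._C._get_tracing_state((x,))
    state.push_scope('MyScope')        # nodes recorded from here on carry this scope
    try:
        y = torch.tanh(x)
    finally:
        state.pop_scope()              # keep pushes and pops balanced, as torch.jit.scope does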
@@ -74,6 +74,14 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
   bool is_complete() const {
     return !is_expired() && graph->stage() == num_stages - 1;
   }
+
+  void push_scope(const std::string& scope_name) {
+    graph->push_scope(scope_name);
+  }
+
+  void pop_scope() {
+    graph->pop_scope();
+  }
 };

 struct ValueTracingStateElem {
@@ -19,6 +19,30 @@ import copy
 _flatten = torch._C._jit_flatten


+# This global variable is set when we are tracing a *forwards* computation.
+# It is intended to be a cheap way to test if tracing has occurred, before
+# doing the slower path using `get_tracing_state` (below.)
+_tracing = False
+
+
+def get_tracing_state(args):
+    if not torch._C._is_tracing(args):
+        return None
+    return torch._C._get_tracing_state(args)
+
+
+@contextlib.contextmanager
+def scope(scope_name, *vars):
+    tracing_state = get_tracing_state(vars)
+    if tracing_state:
+        tracing_state.push_scope(scope_name)
+    try:
+        yield
+    finally:
+        if tracing_state:
+            tracing_state.pop_scope()
+
+
 def compile(arg=None, nderivs=1, optimize=True, enabled=True):
     """
     Decorator which marks a function or module class as eligible for
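
Because scope() looks up the tracing state from the variables it is given, it degrades to a no-op when nothing is being traced. A small illustration of both paths — the function f and the scope name 'Body' below are hypothetical, not part of the diff:

import torch
from torch.autograd import Variable

def f(x):
    out = x * 2
    with torch.jit.scope('Body', out):   # annotates nodes only while tracing
        out = torch.tanh(out)
    return out

x = Variable(torch.Tensor([1.0]), requires_grad=True)

y = f(x)                                  # eager call: scope() finds no tracing
                                          # state and simply yields
trace, y = torch.jit.trace(f, (x,), nderivs=0)
print(trace)                              # traced call: the tanh node should be
                                          # printed with ", scope: Body"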
@@ -237,13 +261,16 @@ class TracedModule(Module):
         self.nderivs = nderivs

     def forward(self, *args):
+        global _tracing
         in_vars = _flatten(args)
         # NOTE: use full state, because we need it for BatchNorm export
         # This differs from the compiler path, which doesn't support it at the moment.
         module_state = list(self.state_dict(keep_vars=True).values())
         trace = torch._C._tracer_enter(in_vars + module_state, self.nderivs)
+        _tracing = True
         out = self.inner(*args)
         out_vars = _flatten(out)
+        _tracing = False
         torch._C._tracer_exit(out_vars)
         return trace, out

@@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
 import functools

 import torch
@@ -319,10 +319,42 @@ class Module(object):
         self._forward_hooks[handle.id] = hook
         return handle

+    def _tracing_name(self, tracing_state):
+        if not tracing_state._traced_module_stack:
+            return None
+        module = tracing_state._traced_module_stack[-1]
+        for name, child in module.named_children():
+            if child is self:
+                return name
+        return None
+
+    def _slow_forward(self, *input, **kwargs):
+        input_vars = tuple(torch.autograd.function._iter_variables(input))
+        tracing_state = torch.jit.get_tracing_state(input_vars)
+        if not tracing_state:
+            return self.forward(*input, **kwargs)
+        if not hasattr(tracing_state, '_traced_module_stack'):
+            tracing_state._traced_module_stack = []
+        name = self._tracing_name(tracing_state)
+        if name:
+            tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
+        else:
+            tracing_state.push_scope(self.__class__.__name__)
+        tracing_state._traced_module_stack.append(self)
+        try:
+            result = self.forward(*input, **kwargs)
+        finally:
+            tracing_state.pop_scope()
+            tracing_state._traced_module_stack.pop()
+        return result
+
     def __call__(self, *input, **kwargs):
         for hook in self._forward_pre_hooks.values():
             hook(self, input)
-        result = self.forward(*input, **kwargs)
+        if torch.jit._tracing:
+            result = self._slow_forward(*input, **kwargs)
+        else:
+            result = self.forward(*input, **kwargs)
         for hook in self._forward_hooks.values():
             hook_result = hook(self, input, result)
             if hook_result is not None:
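
With this hook in place, tracing a module hierarchy annotates nodes with nested scope names automatically, without any explicit torch.jit.scope calls. A hedged sketch for a small hypothetical model — the class Net, the attribute name fc1, and the wrapper f are illustrative, and this assumes torch.jit.trace drives the TracedModule.forward path above, which sets torch.jit._tracing; the ClassName[attribute] format comes from the push_scope call in _slow_forward:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):               # hypothetical model for illustration
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 2)  # the parent knows this child as 'fc1'

    def forward(self, x):
        return torch.tanh(self.fc1(x))

model = Net()

def f(x):
    # Module.__call__ routes through _slow_forward while torch.jit._tracing is set,
    # so each module call pushes a scope automatically.
    return model(x)

x = Variable(torch.randn(1, 4))
trace, out = torch.jit.trace(f, (x,), nderivs=0)

# Nodes recorded inside self.fc1 should carry a scope path along the lines of
# "Net/Linear[fc1]", while the outer tanh sits under "Net".
print(trace)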