Serialize first-class version of functions (#19723)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19723
ghimport-source-id: 7f7ec6200c3b42d19046a3e228a3d82212697f14

Reviewed By: jamesr66a

Differential Revision: D15078533

Pulled By: zdevito

fbshipit-source-id: fe421afab9607ee942f6d200f04bb6335fc0aa97
This commit is contained in:
Zachary DeVito 2019-04-25 15:43:53 -07:00 committed by Facebook Github Bot
parent 6cb1b994d8
commit 330990d878
26 changed files with 212 additions and 297 deletions

View File

@ -511,7 +511,7 @@ ClassTypePtr ClassType::create(
}
ClassTypePtr ClassType::createModuleType(std::shared_ptr<CompilationUnit> cu) {
return ClassTypePtr(new ClassType("Module", std::move(cu)));
return ClassTypePtr(new ClassType("$Module", std::move(cu)));
}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {

View File

@ -1,4 +1,3 @@
def graph(self,
a: Tensor) -> None:
def foo(a: Tensor) -> None:
print("hi\016")
return None

View File

@ -1,5 +1,4 @@
def graph(self,
x: Tensor,
def foo(x: Tensor,
y: Tensor) -> Tensor:
_0 = torch.add(torch.mul(x, 2), y, alpha=1)
return _0

View File

@ -1,3 +1,2 @@
def graph(self,
y: Tensor) -> List[float]:
def empty_float_list_test(y: Tensor) -> List[float]:
return [1., 2., 3.]

View File

@ -1,4 +1,3 @@
def graph(self,
y: Tensor) -> int:
def empty_int_list_test(y: Tensor) -> int:
x = annotate(List[int], [])
return torch.select(x, 0)

View File

@ -1,5 +1,4 @@
def graph(self,
a: Tensor,
def if_one(a: Tensor,
b: Tensor) -> Tensor:
if bool(torch.lt(a, b)):
c = a

View File

@ -1,5 +1,4 @@
def graph(self,
a: Tensor,
def if_test(a: Tensor,
b: Tensor) -> Tensor:
if bool(torch.lt(a, b)):
c = b

View File

@ -1,5 +1,4 @@
def graph(self,
y: Tensor) -> Tuple[Tensor, Tensor]:
def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.add(y, 1, 1)
z = torch.add(x, 5, 1)
y0, z0 = y, z

View File

@ -1,4 +1,3 @@
def graph(self,
y: Tensor) -> None:
def print_weird_test(y: Tensor) -> None:
print("hi\016")
return None

View File

@ -1,3 +1,2 @@
def graph(self,
y: Tensor) -> Tensor:
def python_op_name_test(y: Tensor) -> Tensor:
return ^python_fn()(y)

View File

@ -1,5 +1,4 @@
def graph(self,
a: Tensor,
def while_if_test(a: Tensor,
b: Tensor) -> Tensor:
a0, b0, c = a, b, 0
_0 = bool(torch.lt(a, 10))

View File

@ -1,5 +1,4 @@
def graph(self,
a: Tensor,
def while_test(a: Tensor,
i: Tensor) -> Tensor:
a0, i0 = a, i
_0 = bool(torch.lt(i, 3))

View File

@ -317,7 +317,8 @@ class JitTestCase(TestCase):
try:
src, constants = _jit_python_print(func)
cu = torch.jit.CompilationUnit()._import(src, constants)
src2, constants2 = _jit_python_print(getattr(cu, func.name))
func2 = getattr(cu, func.name)
src2, constants2 = _jit_python_print(func2)
self.assertMultiLineEqual(src, src2)
except RuntimeError as e:
se = str(e)
@ -2768,22 +2769,22 @@ graph(%x : Tensor,
def print_weird_test(y):
print("hi\016")
self.assertExpected(if_test.graph.pretty_print(), "if_test")
self.assertExpected(if_one.graph.pretty_print(), "if_one")
self.assertExpected(while_test.graph.pretty_print(), "while_test")
self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
self.assertExpected(if_test.code, "if_test")
self.assertExpected(if_one.code, "if_one")
self.assertExpected(while_test.code, "while_test")
self.assertExpected(while_if_test.code, "while_if_test")
self.assertExpected(loop_use_test.code, "loop_use_test")
self.assertExpected(python_op_name_test.code, "python_op_name_test")
self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
self.assertExpected(print_weird_test.code, "print_weird_test")
def test_cu_escaped_number(self):
cu = torch.jit.CompilationUnit('''
def foo(a):
print("hi\016")
''')
self.assertExpected(cu.foo.graph.pretty_print())
self.assertExpected(cu.foo.code)
def test_import_method(self):
@torch.jit.script
@ -2792,7 +2793,7 @@ graph(%x : Tensor,
r, _ = _jit_python_print(foo)
cu = torch.jit.CompilationUnit()._import(r, [])
self.assertExpected(cu.foo.graph.pretty_print())
self.assertExpected(cu.foo.code)
def test_function_default_values(self):
outer_var = torch.tensor(20)

View File

@ -120,20 +120,20 @@ class EncoderBase {
protected:
// Using std::map instead of std::unordered_map for initializers
// in EncodeGraph cosntructor so that the order in which initializers
// get written to the ONNX graph is always the deterministic and
// predictable. While this is not a ONNX requirement, it is needed
// get written to the ONNX graph is always the deterministic and
// predictable. While this is not a ONNX requirement, it is needed
// for testing purposes in tests that use _export_to_pretty_string()
// for validating ONNX graphs.
void EncodeGraph(
onnx::GraphProto* graph_proto,
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers =
const std::map<std::string, at::Tensor>& initializers =
std::map<std::string, at::Tensor>());
void EncodeBlock(
onnx::GraphProto* graph_proto,
const Block* block,
const std::map<std::string, at::Tensor>& initializers =
const std::map<std::string, at::Tensor>& initializers =
std::map<std::string, at::Tensor>());
virtual void EncodeTensor(
@ -767,7 +767,8 @@ void ScriptModuleSerializer::convertModule(
methods << "op_version_set = " << op_version_set << "\n";
PythonPrint(
methods,
module,
module.class_compilation_unit(),
/*is_method=*/true,
tensor_table_,
class_table_,
/*enforce_importable=*/true);

View File

@ -288,19 +288,6 @@ std::ostream& operator<<(std::ostream& out, const Graph& g) {
return out;
}
std::ostream& Graph::prettyPrint(std::ostream& out) {
std::vector<at::Tensor> tensor_table;
std::vector<ClassTypePtr> class_table;
PythonPrint(out, *this, tensor_table, class_table);
return out;
}
void Graph::dumpPretty() {
std::vector<at::Tensor> tensor_table;
std::vector<ClassTypePtr> class_table;
PythonPrint(std::cout, *this, tensor_table, class_table);
}
static void checkSameDevice(const Node* node) {
bool has_device = false;
c10::optional<at::Device> device = c10::nullopt;

View File

@ -1165,9 +1165,6 @@ struct Graph {
friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
TORCH_API std::ostream& prettyPrint(std::ostream& out);
TORCH_API void dumpPretty();
TORCH_API std::shared_ptr<Graph> copy();
TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

View File

@ -127,22 +127,6 @@ struct QualifiedName : c10::intrusive_ptr_target {
}
};
void createTensorToParameterNameMap(
const script::Module& module,
const QualifiedNamePtr& prefix,
std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
for (const auto& param : module.get_parameters()) {
result[param] = QualifiedName::create(prefix, param.name());
}
for (const auto& param : module.get_attributes()) {
result[param] = QualifiedName::create(prefix, param.name());
}
for (const auto& elem : module.get_modules()) {
createTensorToParameterNameMap(
*elem, QualifiedName::create(prefix, elem->name()), result);
}
}
// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
@ -156,7 +140,6 @@ const static std::unordered_set<std::string> reserved_names = {
"inf",
"nan",
"ops",
"self",
// the python keywords
"and",
"as",
@ -207,6 +190,12 @@ struct PythonPrintPass {
std::vector<ClassTypePtr>& class_table_;
// Helper to avoid duplicating class types
void addToClassTable(const ClassTypePtr& classType) {
// we serialize module classes separately.
// Including them in the class table as well will cause the code
// to get imported twice.
if (classType->name() == "$Module") {
return;
}
if (std::find(class_table_.cbegin(), class_table_.cend(), classType) ==
class_table_.cend()) {
class_table_.push_back(classType);
@ -221,16 +210,21 @@ struct PythonPrintPass {
// not be able to be reparsed?
bool enforce_importable_;
// are funcitons being printed considered methods
// either of a class or some module?
// If true, this will surpress type annotation on their
// first (self) argument. And forked functions will
// be emitted as method calls (self.__fork...) rather
// than as method calls
bool is_method_;
// what valid identifiers are in use for the current function
std::unordered_set<std::string> used_names_;
// used method names
std::unordered_set<std::string> used_method_names_;
// for fork,
// subgraphs get added to the worklist, and will be printed later
std::vector<std::function<void(void)>> worklist;
// scanValue, scanNode, scanBlock:
// decide if it is safe to omit the output of a temporary variable,
// and inline the expression into its use
@ -275,6 +269,11 @@ struct PythonPrintPass {
// w.r.t. to it
if (use.user->kind() == prim::Loop && use.offset >= 2)
return false;
// subgraph may use this more than once, so disable inlining
if (use.user->kind() == prim::fork)
return false;
return true;
}
@ -726,6 +725,41 @@ struct PythonPrintPass {
body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
<< "\n";
} break;
case prim::fork: {
// the subgraph gets emitted as another function
auto name = genName("__forked_function");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
indent();
body_ << "def " << name << "():\n";
for(size_t i = 0; i < node->inputs().size(); ++i) {
assignValue(graph->inputs().at(i), node->inputs().at(i));
}
printBody(graph->block());
std::stringstream ss;
ss << "fork(" << name << ")";
printOutputDefinition( node,ss.str());
} break;
case prim::Function: {
if (enforce_importable_) {
throw script::ErrorReport(node->getSourceLocation())
<< "closures are not exportable";
}
assignValuesToTheirUniqueNames(node->outputs());
auto name = useOf(node->output());
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
indent();
body_ << "def " << name << "(";
assignValuesToTheirUniqueNames(graph->inputs());
for(size_t i = 0; i < graph->inputs().size(); ++i) {
Value* v = graph->inputs().at(i);
if (i > 0) {
body_ << ", ";
}
body_ << useOf(v) << ": " << v->type()->python_str();
}
body_ << "):\n";
printBody(graph->block());
} break;
default:
std::stringstream ss;
printRHS(ss, node);
@ -899,30 +933,6 @@ struct PythonPrintPass {
stmt << "(" << useOf(node->inputs().at(0)) << ")["
<< useOf(node->inputs().at(1)) << "]";
} break;
case prim::fork: {
// the subgraph gets emitted as another function
auto name = genMethodName("__forked_function");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
worklist.emplace_back(
[graph, name, this] { printFunctionDefinition(*graph, name); });
// and we put a call to fork which invokes that function.
stmt << "fork(self." << name;
for (Value* v : node->inputs()) {
stmt << ", " << useOf(v);
}
stmt << ")";
} break;
case prim::Function: {
if (enforce_importable_) {
throw script::ErrorReport(node->getSourceLocation())
<< "closures are not exportable";
}
auto name = genMethodName("__lambda");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
worklist.emplace_back(
[graph, name, this] { printFunctionDefinition(*graph, name); });
stmt << "self." << name;
} break;
case prim::CreateObject: {
const auto classType = node->output()->type()->expect<ClassType>();
stmt << classType->name() << ".__new__(" << classType->name() << ")";
@ -931,7 +941,13 @@ struct PythonPrintPass {
const auto obj = node->inputs().at(0);
const auto classType = obj->type()->expect<ClassType>();
const auto& field = node->s(attr::name);
stmt << useOf(obj) << "." << field;
if (isValidIdentifier(field)) {
stmt << useOf(obj) << "." << field;
} else {
stmt << "getattr(" << useOf(obj) << ", ";
printQuotedString(stmt, field);
stmt << ")";
}
} break;
default: {
Symbol kind = node->kind();
@ -999,69 +1015,14 @@ struct PythonPrintPass {
stmt << "=";
printConstant(stmt, value);
}
void printFunctionDefinition(
Graph& graph,
const std::string& name,
bool is_class = false,
const std::vector<c10::optional<IValue>>& defaults = {},
const std::vector<std::string>& param_names = {}) {
used_names_.clear(); // each graph can reuse local names
void printBody(Block* body) {
// we always print constants at the top of the function, in the order
// in which they are used.
std::vector<Node*> constants;
buildConstantList(graph.block(), constants);
buildConstantList(body, constants);
// current graph is used to de-dup names within a single graph
scanBlock(graph.block());
// last param_names.size() arguments to the graph are parameters and not
// actual inputs, we will print these as, e.g. self.foo.bar
// while we print the true_inputs out as parameters
auto true_inputs =
graph.inputs().slice(0, graph.inputs().size() - param_names.size());
auto param_names_it = param_names.begin();
for (auto param : graph.inputs().slice(true_inputs.size())) {
assignValue(param, *param_names_it++);
}
assignValuesToTheirUniqueNames(true_inputs);
auto defaults_offset = defaults.begin();
indent();
body_ << "def " << name << "(";
auto input_iter = true_inputs.begin();
// Print the `self` argument
if (is_class) {
// If this is a class, print the self var without a type annotation,
// following Python convention
AT_ASSERT(true_inputs.size() > 0);
body_ << useOf(*input_iter);
++input_iter;
AT_ASSERT(!defaults_offset->has_value());
++defaults_offset;
} else {
// If this is not a class, then we need to insert a "self".
body_ << "self";
}
// Print the rest of the arguments
for (; input_iter != true_inputs.end(); ++input_iter) {
auto input = *input_iter;
body_ << ",\n " << useOf(input) << ": " << input->type()->python_str();
if (defaults_offset != defaults.end()) {
const c10::optional<IValue>& def = *defaults_offset++;
if (def) {
printDefaultValue(input->type(), body_, *def);
}
}
}
// have we use all the provided defaults?
AT_ASSERT(defaults_offset == defaults.end());
body_ << ") -> " << resultType(graph)->python_str() << ":\n";
scanBlock(body);
{
auto guard = WithIndented();
// Print initial constant table (most are just inlined into their use,
@ -1071,19 +1032,52 @@ struct PythonPrintPass {
}
// Print body
printBlock(
graph.block(), graph.block()->return_node()->inputs().size() > 0);
printNode(graph.block()->return_node(), /*print_const=*/false);
body, body->return_node()->inputs().size() > 0);
printNode(body->return_node(), /*print_const=*/false);
}
}
public:
public:
void printFunction(script::Function& func) {
const FunctionSchema& schema = func.getSchema();
Graph& graph = *func.graph();
used_names_.clear(); // each graph can reuse local names
indent();
body_ << "def " << func.name() << "(";
auto param_it = graph.inputs().begin();
for(const Argument& arg : schema.arguments()) {
std::string arg_name = genName(arg.name());
if (param_it == graph.inputs().begin()) {
// the first argument may omit its type when it is implied by context
// the flag is_method_ determines when to do this
body_ << arg_name;
if (!is_method_) {
body_ << ": " << arg.type()->python_str();
}
} else {
body_ << ",\n " << arg_name << ": " << arg.type()->python_str();
}
if (arg.default_value()) {
printDefaultValue(arg.type(), body_, *arg.default_value());
}
assignValue(*param_it++, arg_name);
}
body_ << ") -> " << resultType(graph)->python_str() << ":\n";
printBody(graph.block());
}
PythonPrintPass(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable)
bool enforce_importable,
bool is_method)
: tensor_table_(tensor_table),
class_table_(class_table),
enforce_importable_(enforce_importable) {}
enforce_importable_(enforce_importable),
is_method_(is_method) {}
// TODO: we should consider forcing functions to return a single value
// instead of handling this tuple logic both in the compiler and the printer
@ -1096,63 +1090,9 @@ struct PythonPrintPass {
}
}
void printFunction(
Graph& graph,
const std::string& name,
bool is_class,
const std::vector<c10::optional<IValue>>& defaults = {},
const std::vector<std::string>& param_names = {}) {
printFunctionDefinition(graph, name, is_class, defaults, param_names);
while (!worklist.empty()) {
body_ << "\n\n";
auto work = worklist.back();
worklist.pop_back();
work();
}
}
void printMethod(script::Method& method) {
std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
createTensorToParameterNameMap(
method.owner(), QualifiedName::create("self"), extra_ivalue_names);
printMethod(method, /*is_class=*/false, extra_ivalue_names);
}
void printMethod(
script::Method& method,
bool is_class,
const std::unordered_map<script::Slot, QualifiedNamePtr>&
extra_ivalue_names) {
std::vector<std::string> ivalue_names =
fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
return extra_ivalue_names.at(slot)->str();
});
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(
method.getSchema().arguments(),
[](const Argument& arg) { return arg.default_value(); });
printFunction(graph, name, is_class, defaults, ivalue_names);
}
void printFunction(script::Function& method, bool is_class) {
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(
method.getSchema().arguments(),
[](const Argument& arg) { return arg.default_value(); });
printFunction(graph, name, is_class, defaults, {});
}
void printModule(script::Module& module) {
std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
createTensorToParameterNameMap(
module, QualifiedName::create("self"), extra_ivalue_names);
for (auto& method : module.get_methods()) {
const std::string& name = method->name();
// we skip __forked_functions because they actually get inlined into their
// callers, exporting them again will lead to more code generated on each
// export
if (name.find("__forked_function") == 0) {
continue;
}
printMethod(*method, /*is_class=*/false, extra_ivalue_names);
void printCompilationUnit(script::CompilationUnit& cu) {
for (auto& func : cu.get_functions()) {
printFunction(*func);
}
}
@ -1161,7 +1101,7 @@ struct PythonPrintPass {
{
const auto guard = WithIndented();
for (auto& method : classType->methods()) {
printFunction(*method, /*is_class=*/true);
printFunction(*method);
}
}
}
@ -1173,49 +1113,27 @@ struct PythonPrintPass {
TORCH_API void PythonPrint(
std::ostream& out,
const Graph& graph,
const script::Function& func,
bool is_method,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
PythonPrintPass pp(tensor_table, class_table, enforce_importable);
PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printFunction(const_cast<Graph&>(graph), "graph", /*is_class=*/false);
pp.printFunction(const_cast<script::Function&>(func));
pp.print(out);
}
TORCH_API void PythonPrint(
std::ostream& out,
const script::Method& method,
const script::CompilationUnit& cu,
bool is_method,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
PythonPrintPass pp(tensor_table, class_table, enforce_importable);
PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printMethod(const_cast<script::Method&>(method));
pp.print(out);
}
TORCH_API void PythonPrint(
std::ostream& out,
const script::Function& callee,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
PythonPrintPass pp(tensor_table, class_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printFunction(const_cast<script::Function&>(callee), /*is_class=*/false);
pp.print(out);
}
TORCH_API void PythonPrint(
std::ostream& out,
const script::Module& module,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
PythonPrintPass pp(tensor_table, class_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printModule(const_cast<script::Module&>(module));
pp.printCompilationUnit(const_cast<script::CompilationUnit&>(cu));
pp.print(out);
}
@ -1225,7 +1143,7 @@ TORCH_API void PythonPrint(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
PythonPrintPass pp(tensor_table, class_table, enforce_importable);
PythonPrintPass pp(tensor_table, class_table, enforce_importable, true);
pp.printClass(classType);
pp.print(out);
}

View File

@ -12,33 +12,21 @@ struct Method;
struct Module;
} // namespace script
TORCH_API void PythonPrint(
std::ostream& out,
const Graph& graph,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable = false);
TORCH_API void PythonPrint(
std::ostream& out,
const script::Method& graph,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable = false);
TORCH_API void PythonPrint(
std::ostream& out,
const script::Function& callee,
bool is_method,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable = false);
TORCH_API void PythonPrint(
std::ostream& out,
const script::Module& module,
const script::CompilationUnit& cu,
bool is_method,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable = false);
bool enforce_importable);
TORCH_API void PythonPrint(
std::ostream& out,

View File

@ -343,13 +343,6 @@ void initPythonIRBindings(PyObject* module_) {
})
.def("param_node", [](Graph& g) { return g.block()->param_node(); })
.def("return_node", [](Graph& g) { return g.block()->return_node(); })
.def(
"pretty_print",
[](Graph& g) {
std::ostringstream oss;
g.prettyPrint(oss);
return oss.str();
})
.GS(createFusionGroup)
.def(
"createClone",

View File

@ -2263,7 +2263,7 @@ struct to_ir {
}
// Lambda lift block(0) into attr::Subgraph
lambdaLiftFork(fork_node);
runCleanupPasses(fork_node->g(attr::Subgraph));
return std::make_shared<SimpleValue>(node_output);
}
@ -2871,11 +2871,13 @@ void lambdaLiftFork(Node* fork_node) {
auto env = [&](Value* v) -> Value* {
if (!uncaptures_map.count(v)) {
// Capture values for both graphs
uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
uncaptures_map[v] =
forked_graph->addInput()->copyMetadata(v);
fork_node->addInput(v);
}
return uncaptures_map[v];
};
forked_graph->block()->cloneFrom(body_block, env);
// Separate the subgraph and clean up the orignal one

View File

@ -12,8 +12,14 @@ struct ReturnInfo {
};
void checkNoReturn(const TreeRef& ref) {
if (ref->kind() == TK_RETURN)
if (ref->kind() == TK_RETURN) {
throw ErrorReport(ref) << "return is not allowed from a loop.";
}
// do not search into first-class functions
if (ref->kind() == TK_DEF) {
return;
}
for (const TreeRef& child : ref->trees()) {
checkNoReturn(child);
}

View File

@ -939,7 +939,13 @@ void initJitScriptBindings(PyObject* module) {
std::ostringstream ss;
std::vector<at::Tensor> tensors;
std::vector<ClassTypePtr> classes;
PythonPrint(ss, self, tensors, classes, false);
PythonPrint(
ss,
self.class_compilation_unit(),
true,
tensors,
classes,
false);
return ss.str();
})
.def("apply", &Module::apply)
@ -989,7 +995,7 @@ void initJitScriptBindings(PyObject* module) {
std::ostringstream ss;
std::vector<at::Tensor> tensors;
std::vector<ClassTypePtr> classes;
PythonPrint(ss, self, tensors, classes, false);
PythonPrint(ss, self, false, tensors, classes, false);
return ss.str();
})
.def(
@ -1021,7 +1027,7 @@ void initJitScriptBindings(PyObject* module) {
std::ostringstream ss;
std::vector<at::Tensor> tensors;
std::vector<ClassTypePtr> classes;
PythonPrint(ss, self, tensors, classes, false);
PythonPrint(ss, self.function(), true, tensors, classes, false);
return ss.str();
});
@ -1119,12 +1125,12 @@ void initJitScriptBindings(PyObject* module) {
std::vector<at::Tensor> constants;
std::vector<ClassTypePtr> classes;
if (auto self = as_module(obj)) {
PythonPrint(ss, *self, constants, classes, true);
PythonPrint(ss, self->class_compilation_unit(), true, constants, classes, true);
} else if (auto self = as_function(obj)) {
PythonPrint(ss, *self, constants, classes, true);
PythonPrint(ss, *self, false, constants, classes, true);
} else {
auto& m = py::cast<Method&>(obj);
PythonPrint(ss, m, constants, classes, true);
PythonPrint(ss, m.function(), true, constants, classes, true);
}
return std::make_pair(ss.str(), std::move(constants));
});

View File

@ -185,7 +185,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
}
Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
if (c->name() == "Module") {
if (c->name() == "$Module") {
auto obj = slot.value().toObject();
for (Use use : e.n->output()->uses()) {
to_scan.emplace_back(ToScan{obj, use.user, use.offset});

View File

@ -118,7 +118,6 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
if (auto method = classType->getMethod(field)) {
return std::make_shared<MethodValue>(getValue(), method);
}
if (!classType->hasAttribute(field)) {
throw ErrorReport(loc)
<< "Tried to access to nonexistent attribute " << field
@ -215,6 +214,35 @@ void SimpleValue::setAttr(
g.insertNode(g.createSetAttr(value_, field, newValue));
}
std::shared_ptr<SugaredValue> SimpleValue::call(
const SourceRange& loc,
Function& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
// allow our 'fake' closures to be called, used for fork serialization
// at the moment, but can be expanded later
Node* self = getValue()->node();
if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
self->inputs().at(0)->node()->kind() == prim::Function) {
std::shared_ptr<Graph> graph = self->inputs().at(0)->node()->g(attr::Subgraph);
Value* context = self->inputs().at(1);
AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
// fork nodes are emitted in their own block but we do not simplify
// tuple construction across blocks. To ensure we clean up the tuple
// construct create another copy of the tuple construct in the fork block
Value* close_context =
m.graph()
->insertNode(m.graph()->createTuple(context->node()->inputs()))
->output();
auto fn = CompilationUnit().create_function("anon", graph);
return MethodValue(close_context, fn).call(loc, m, inputs, attributes, n_binders);
}
return SugaredValue::call(loc, m, inputs, attributes, n_binders);
}
std::shared_ptr<SugaredValue> ClassValue::call(
const SourceRange& loc,
Function& m,

View File

@ -129,6 +129,14 @@ struct TORCH_API SimpleValue : public SugaredValue {
const std::string& field,
Value* newValue) override;
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override;
Value* getValue() const {
return value_;
}

View File

@ -707,15 +707,6 @@ class CompilationUnit(object):
def _import(self, src, constants):
""" test import logic for single function, use only for testing """
src = "op_version_set = 0\n{}".format(src)
# HACK: export always adds a self argument
# even if the thing is a pure function,
# we just delete it here
# Once the serialization outputs the first-class
# functions and not the lowered functions,
# we can change this behavior
src = (src.replace('self,\n', '')
.replace('(self)', '()')
.replace('self.__forked_function', '__forked_function'))
torch._C._jit_import_functions(self._c, src, constants, None)
return self