Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Serialize first-class version of functions (#19723)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19723
ghimport-source-id: 7f7ec6200c3b42d19046a3e228a3d82212697f14

Reviewed By: jamesr66a
Differential Revision: D15078533
Pulled By: zdevito
fbshipit-source-id: fe421afab9607ee942f6d200f04bb6335fc0aa97
This commit is contained in:
parent 6cb1b994d8
commit 330990d878
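For orientation, here is a minimal sketch of the user-visible effect exercised by the expect-file updates below (assuming a TorchScript build that includes this change; the loop body is illustrative, only the printed header format comes from the diff): a scripted function now pretty-prints under its own name and signature, exposed as .code, instead of the old "def graph(self, ...)" form.

import torch
from torch import Tensor

@torch.jit.script
def while_test(a: Tensor, i: Tensor) -> Tensor:
    # illustrative body; the expect files below only change the printed header
    while bool(i < 3):
        a = a * a
        i = i + 1
    return a

# Previously this printed "def graph(self, a: Tensor, ...)"; after this change it
# prints "def while_test(a: Tensor, i: Tensor) -> Tensor:" and the tests compare
# against .code rather than .graph.pretty_print().
print(while_test.code)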
@@ -511,7 +511,7 @@ ClassTypePtr ClassType::create(
 }

 ClassTypePtr ClassType::createModuleType(std::shared_ptr<CompilationUnit> cu) {
-  return ClassTypePtr(new ClassType("Module", std::move(cu)));
+  return ClassTypePtr(new ClassType("$Module", std::move(cu)));
 }

 ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
@@ -1,4 +1,3 @@
-def graph(self,
-    a: Tensor) -> None:
+def foo(a: Tensor) -> None:
   print("hi\016")
   return None
@@ -1,5 +1,4 @@
-def graph(self,
-    x: Tensor,
+def foo(x: Tensor,
     y: Tensor) -> Tensor:
   _0 = torch.add(torch.mul(x, 2), y, alpha=1)
   return _0
@@ -1,3 +1,2 @@
-def graph(self,
-    y: Tensor) -> List[float]:
+def empty_float_list_test(y: Tensor) -> List[float]:
   return [1., 2., 3.]
@@ -1,4 +1,3 @@
-def graph(self,
-    y: Tensor) -> int:
+def empty_int_list_test(y: Tensor) -> int:
   x = annotate(List[int], [])
   return torch.select(x, 0)
@@ -1,5 +1,4 @@
-def graph(self,
-    a: Tensor,
+def if_one(a: Tensor,
     b: Tensor) -> Tensor:
   if bool(torch.lt(a, b)):
     c = a
@@ -1,5 +1,4 @@
-def graph(self,
-    a: Tensor,
+def if_test(a: Tensor,
     b: Tensor) -> Tensor:
   if bool(torch.lt(a, b)):
     c = b
@@ -1,5 +1,4 @@
-def graph(self,
-    y: Tensor) -> Tuple[Tensor, Tensor]:
+def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
   x = torch.add(y, 1, 1)
   z = torch.add(x, 5, 1)
   y0, z0 = y, z
@@ -1,4 +1,3 @@
-def graph(self,
-    y: Tensor) -> None:
+def print_weird_test(y: Tensor) -> None:
   print("hi\016")
   return None
@@ -1,3 +1,2 @@
-def graph(self,
-    y: Tensor) -> Tensor:
+def python_op_name_test(y: Tensor) -> Tensor:
   return ^python_fn()(y)
@@ -1,5 +1,4 @@
-def graph(self,
-    a: Tensor,
+def while_if_test(a: Tensor,
     b: Tensor) -> Tensor:
   a0, b0, c = a, b, 0
   _0 = bool(torch.lt(a, 10))
@@ -1,5 +1,4 @@
-def graph(self,
-    a: Tensor,
+def while_test(a: Tensor,
     i: Tensor) -> Tensor:
   a0, i0 = a, i
   _0 = bool(torch.lt(i, 3))
@@ -317,7 +317,8 @@ class JitTestCase(TestCase):
         try:
             src, constants = _jit_python_print(func)
             cu = torch.jit.CompilationUnit()._import(src, constants)
-            src2, constants2 = _jit_python_print(getattr(cu, func.name))
+            func2 = getattr(cu, func.name)
+            src2, constants2 = _jit_python_print(func2)
             self.assertMultiLineEqual(src, src2)
         except RuntimeError as e:
             se = str(e)
@@ -2768,22 +2769,22 @@ graph(%x : Tensor,
        def print_weird_test(y):
            print("hi\016")

-        self.assertExpected(if_test.graph.pretty_print(), "if_test")
-        self.assertExpected(if_one.graph.pretty_print(), "if_one")
-        self.assertExpected(while_test.graph.pretty_print(), "while_test")
-        self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
-        self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
-        self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
-        self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
-        self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
-        self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
+        self.assertExpected(if_test.code, "if_test")
+        self.assertExpected(if_one.code, "if_one")
+        self.assertExpected(while_test.code, "while_test")
+        self.assertExpected(while_if_test.code, "while_if_test")
+        self.assertExpected(loop_use_test.code, "loop_use_test")
+        self.assertExpected(python_op_name_test.code, "python_op_name_test")
+        self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
+        self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
+        self.assertExpected(print_weird_test.code, "print_weird_test")

    def test_cu_escaped_number(self):
        cu = torch.jit.CompilationUnit('''
            def foo(a):
                print("hi\016")
        ''')
-        self.assertExpected(cu.foo.graph.pretty_print())
+        self.assertExpected(cu.foo.code)

    def test_import_method(self):
        @torch.jit.script
@@ -2792,7 +2793,7 @@ graph(%x : Tensor,

        r, _ = _jit_python_print(foo)
        cu = torch.jit.CompilationUnit()._import(r, [])
-        self.assertExpected(cu.foo.graph.pretty_print())
+        self.assertExpected(cu.foo.code)

    def test_function_default_values(self):
        outer_var = torch.tensor(20)
@@ -120,20 +120,20 @@ class EncoderBase {
 protected:
  // Using std::map instead of std::unordered_map for initializers
  // in EncodeGraph cosntructor so that the order in which initializers
  // get written to the ONNX graph is always the deterministic and
  // predictable. While this is not a ONNX requirement, it is needed
  // for testing purposes in tests that use _export_to_pretty_string()
  // for validating ONNX graphs.
  void EncodeGraph(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  void EncodeBlock(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  virtual void EncodeTensor(
@@ -767,7 +767,8 @@ void ScriptModuleSerializer::convertModule(
  methods << "op_version_set = " << op_version_set << "\n";
  PythonPrint(
      methods,
-      module,
+      module.class_compilation_unit(),
+      /*is_method=*/true,
      tensor_table_,
      class_table_,
      /*enforce_importable=*/true);
@@ -288,19 +288,6 @@ std::ostream& operator<<(std::ostream& out, const Graph& g) {
  return out;
 }

-std::ostream& Graph::prettyPrint(std::ostream& out) {
-  std::vector<at::Tensor> tensor_table;
-  std::vector<ClassTypePtr> class_table;
-  PythonPrint(out, *this, tensor_table, class_table);
-  return out;
-}
-
-void Graph::dumpPretty() {
-  std::vector<at::Tensor> tensor_table;
-  std::vector<ClassTypePtr> class_table;
-  PythonPrint(std::cout, *this, tensor_table, class_table);
-}
-
static void checkSameDevice(const Node* node) {
  bool has_device = false;
  c10::optional<at::Device> device = c10::nullopt;
@@ -1165,9 +1165,6 @@ struct Graph {

  friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);

-  TORCH_API std::ostream& prettyPrint(std::ostream& out);
-  TORCH_API void dumpPretty();
-
  TORCH_API std::shared_ptr<Graph> copy();
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);
@@ -127,22 +127,6 @@ struct QualifiedName : c10::intrusive_ptr_target {
  }
};

-void createTensorToParameterNameMap(
-    const script::Module& module,
-    const QualifiedNamePtr& prefix,
-    std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
-  for (const auto& param : module.get_parameters()) {
-    result[param] = QualifiedName::create(prefix, param.name());
-  }
-  for (const auto& param : module.get_attributes()) {
-    result[param] = QualifiedName::create(prefix, param.name());
-  }
-  for (const auto& elem : module.get_modules()) {
-    createTensorToParameterNameMap(
-        *elem, QualifiedName::create(prefix, elem->name()), result);
-  }
-}
-
// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
@@ -156,7 +140,6 @@ const static std::unordered_set<std::string> reserved_names = {
    "inf",
    "nan",
    "ops",
    "self",
    // the python keywords
    "and",
    "as",
@@ -207,6 +190,12 @@ struct PythonPrintPass {
  std::vector<ClassTypePtr>& class_table_;
  // Helper to avoid duplicating class types
  void addToClassTable(const ClassTypePtr& classType) {
+    // we serialize module classes separately.
+    // Including them in the class table as well will cause the code
+    // to get imported twice.
+    if (classType->name() == "$Module") {
+      return;
+    }
    if (std::find(class_table_.cbegin(), class_table_.cend(), classType) ==
        class_table_.cend()) {
      class_table_.push_back(classType);
@@ -221,16 +210,21 @@ struct PythonPrintPass {
  // not be able to be reparsed?
  bool enforce_importable_;

+  // are funcitons being printed considered methods
+  // either of a class or some module?
+  // If true, this will surpress type annotation on their
+  // first (self) argument. And forked functions will
+  // be emitted as method calls (self.__fork...) rather
+  // than as method calls
+  bool is_method_;
+
  // what valid identifiers are in use for the current function
  std::unordered_set<std::string> used_names_;

  // used method names
  std::unordered_set<std::string> used_method_names_;

  // for fork,
  // subgraphs get added to the worklist, and will be printed later
  std::vector<std::function<void(void)>> worklist;

  // scanValue, scanNode, scanBlock:
  // decide if it is safe to omit the output of a temporary variable,
  // and inline the expression into its use
@@ -275,6 +269,11 @@ struct PythonPrintPass {
    // w.r.t. to it
    if (use.user->kind() == prim::Loop && use.offset >= 2)
      return false;

+    // subgraph may use this more than once, so disable inlining
+    if (use.user->kind() == prim::fork)
+      return false;
+
    return true;
  }

@@ -726,6 +725,41 @@ struct PythonPrintPass {
        body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
              << "\n";
      } break;
+      case prim::fork: {
+        // the subgraph gets emitted as another function
+        auto name = genName("__forked_function");
+        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
+        indent();
+        body_ << "def " << name << "():\n";
+        for (size_t i = 0; i < node->inputs().size(); ++i) {
+          assignValue(graph->inputs().at(i), node->inputs().at(i));
+        }
+        printBody(graph->block());
+        std::stringstream ss;
+        ss << "fork(" << name << ")";
+        printOutputDefinition(node, ss.str());
+      } break;
+      case prim::Function: {
+        if (enforce_importable_) {
+          throw script::ErrorReport(node->getSourceLocation())
+              << "closures are not exportable";
+        }
+        assignValuesToTheirUniqueNames(node->outputs());
+        auto name = useOf(node->output());
+        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
+        indent();
+        body_ << "def " << name << "(";
+        assignValuesToTheirUniqueNames(graph->inputs());
+        for (size_t i = 0; i < graph->inputs().size(); ++i) {
+          Value* v = graph->inputs().at(i);
+          if (i > 0) {
+            body_ << ", ";
+          }
+          body_ << useOf(v) << ": " << v->type()->python_str();
+        }
+        body_ << "):\n";
+        printBody(graph->block());
+      } break;
      default:
        std::stringstream ss;
        printRHS(ss, node);
@@ -899,30 +933,6 @@ struct PythonPrintPass {
        stmt << "(" << useOf(node->inputs().at(0)) << ")["
             << useOf(node->inputs().at(1)) << "]";
      } break;
-      case prim::fork: {
-        // the subgraph gets emitted as another function
-        auto name = genMethodName("__forked_function");
-        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
-        worklist.emplace_back(
-            [graph, name, this] { printFunctionDefinition(*graph, name); });
-        // and we put a call to fork which invokes that function.
-        stmt << "fork(self." << name;
-        for (Value* v : node->inputs()) {
-          stmt << ", " << useOf(v);
-        }
-        stmt << ")";
-      } break;
-      case prim::Function: {
-        if (enforce_importable_) {
-          throw script::ErrorReport(node->getSourceLocation())
-              << "closures are not exportable";
-        }
-        auto name = genMethodName("__lambda");
-        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
-        worklist.emplace_back(
-            [graph, name, this] { printFunctionDefinition(*graph, name); });
-        stmt << "self." << name;
-      } break;
      case prim::CreateObject: {
        const auto classType = node->output()->type()->expect<ClassType>();
        stmt << classType->name() << ".__new__(" << classType->name() << ")";
@@ -931,7 +941,13 @@ struct PythonPrintPass {
        const auto obj = node->inputs().at(0);
        const auto classType = obj->type()->expect<ClassType>();
        const auto& field = node->s(attr::name);
-        stmt << useOf(obj) << "." << field;
+        if (isValidIdentifier(field)) {
+          stmt << useOf(obj) << "." << field;
+        } else {
+          stmt << "getattr(" << useOf(obj) << ", ";
+          printQuotedString(stmt, field);
+          stmt << ")";
+        }
      } break;
      default: {
        Symbol kind = node->kind();
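As an aside, a small Python re-expression of the branch added just above may help: attribute access is printed with plain dot syntax only when the field is a valid identifier, and falls back to getattr(...) with a quoted name otherwise. The helper name and the identifier test here are illustrative, not part of the diff.

import keyword

def emit_attr_access(obj_expr, field):
    # Mirrors the printer branch above: use "obj.field" only when field is a
    # valid, non-keyword identifier; otherwise emit getattr(obj, "field").
    if field.isidentifier() and not keyword.iskeyword(field):
        return "{}.{}".format(obj_expr, field)
    return 'getattr({}, "{}")'.format(obj_expr, field)

# emit_attr_access("self.sub", "weight") -> "self.sub.weight"
# emit_attr_access("self.sub", "0")      -> 'getattr(self.sub, "0")'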
@@ -999,69 +1015,14 @@ struct PythonPrintPass {
    stmt << "=";
    printConstant(stmt, value);
  }
-  void printFunctionDefinition(
-      Graph& graph,
-      const std::string& name,
-      bool is_class = false,
-      const std::vector<c10::optional<IValue>>& defaults = {},
-      const std::vector<std::string>& param_names = {}) {
-    used_names_.clear(); // each graph can reuse local names
-
+  void printBody(Block* body) {
    // we always print constants at the top of the function, in the order
    // in which they are used.
    std::vector<Node*> constants;
-    buildConstantList(graph.block(), constants);
+    buildConstantList(body, constants);

-    // current graph is used to de-dup names within a single graph
-    scanBlock(graph.block());
-
-    // last param_names.size() arguments to the graph are parameters and not
-    // actual inputs, we will print these as, e.g. self.foo.bar
-    // while we print the true_inputs out as parameters
-    auto true_inputs =
-        graph.inputs().slice(0, graph.inputs().size() - param_names.size());
-    auto param_names_it = param_names.begin();
-    for (auto param : graph.inputs().slice(true_inputs.size())) {
-      assignValue(param, *param_names_it++);
-    }
-    assignValuesToTheirUniqueNames(true_inputs);
-    auto defaults_offset = defaults.begin();
-
-    indent();
-    body_ << "def " << name << "(";
-
-    auto input_iter = true_inputs.begin();
-    // Print the `self` argument
-    if (is_class) {
-      // If this is a class, print the self var without a type annotation,
-      // following Python convention
-      AT_ASSERT(true_inputs.size() > 0);
-      body_ << useOf(*input_iter);
-      ++input_iter;
-
-      AT_ASSERT(!defaults_offset->has_value());
-      ++defaults_offset;
-    } else {
-      // If this is not a class, then we need to insert a "self".
-      body_ << "self";
-    }
-
-    // Print the rest of the arguments
-    for (; input_iter != true_inputs.end(); ++input_iter) {
-      auto input = *input_iter;
-      body_ << ",\n " << useOf(input) << ": " << input->type()->python_str();
-      if (defaults_offset != defaults.end()) {
-        const c10::optional<IValue>& def = *defaults_offset++;
-        if (def) {
-          printDefaultValue(input->type(), body_, *def);
-        }
-      }
-    }
-
-    // have we use all the provided defaults?
-    AT_ASSERT(defaults_offset == defaults.end());
-
-    body_ << ") -> " << resultType(graph)->python_str() << ":\n";
+    scanBlock(body);
    {
      auto guard = WithIndented();
      // Print initial constant table (most are just inlined into their use,
@@ -1071,19 +1032,52 @@ struct PythonPrintPass {
    }
    // Print body
    printBlock(
-        graph.block(), graph.block()->return_node()->inputs().size() > 0);
-    printNode(graph.block()->return_node(), /*print_const=*/false);
+        body, body->return_node()->inputs().size() > 0);
+    printNode(body->return_node(), /*print_const=*/false);
    }
  }

 public:
+  void printFunction(script::Function& func) {
+    const FunctionSchema& schema = func.getSchema();
+    Graph& graph = *func.graph();
+    used_names_.clear(); // each graph can reuse local names
+
+    indent();
+    body_ << "def " << func.name() << "(";
+    auto param_it = graph.inputs().begin();
+    for (const Argument& arg : schema.arguments()) {
+      std::string arg_name = genName(arg.name());
+      if (param_it == graph.inputs().begin()) {
+        // the first argument may omit its type when it is implied by context
+        // the flag is_method_ determines when to do this
+        body_ << arg_name;
+        if (!is_method_) {
+          body_ << ": " << arg.type()->python_str();
+        }
+      } else {
+        body_ << ",\n " << arg_name << ": " << arg.type()->python_str();
+      }
+      if (arg.default_value()) {
+        printDefaultValue(arg.type(), body_, *arg.default_value());
+      }
+      assignValue(*param_it++, arg_name);
+    }
+
+    body_ << ") -> " << resultType(graph)->python_str() << ":\n";
+    printBody(graph.block());
+  }

  PythonPrintPass(
      std::vector<at::Tensor>& tensor_table,
      std::vector<ClassTypePtr>& class_table,
-      bool enforce_importable)
+      bool enforce_importable,
+      bool is_method)
      : tensor_table_(tensor_table),
        class_table_(class_table),
-        enforce_importable_(enforce_importable) {}
+        enforce_importable_(enforce_importable),
+        is_method_(is_method) {}

  // TODO: we should consider forcing functions to return a single value
  // instead of handling this tuple logic both in the compiler and the printer
@@ -1096,63 +1090,9 @@ struct PythonPrintPass {
    }
  }

-  void printFunction(
-      Graph& graph,
-      const std::string& name,
-      bool is_class,
-      const std::vector<c10::optional<IValue>>& defaults = {},
-      const std::vector<std::string>& param_names = {}) {
-    printFunctionDefinition(graph, name, is_class, defaults, param_names);
-    while (!worklist.empty()) {
-      body_ << "\n\n";
-      auto work = worklist.back();
-      worklist.pop_back();
-      work();
-    }
-  }
-  void printMethod(script::Method& method) {
-    std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
-    createTensorToParameterNameMap(
-        method.owner(), QualifiedName::create("self"), extra_ivalue_names);
-    printMethod(method, /*is_class=*/false, extra_ivalue_names);
-  }
-  void printMethod(
-      script::Method& method,
-      bool is_class,
-      const std::unordered_map<script::Slot, QualifiedNamePtr>&
-          extra_ivalue_names) {
-    std::vector<std::string> ivalue_names =
-        fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
-          return extra_ivalue_names.at(slot)->str();
-        });
-    const std::string& name = method.name();
-    Graph& graph = *method.graph();
-    auto defaults = fmap(
-        method.getSchema().arguments(),
-        [](const Argument& arg) { return arg.default_value(); });
-    printFunction(graph, name, is_class, defaults, ivalue_names);
-  }
-  void printFunction(script::Function& method, bool is_class) {
-    const std::string& name = method.name();
-    Graph& graph = *method.graph();
-    auto defaults = fmap(
-        method.getSchema().arguments(),
-        [](const Argument& arg) { return arg.default_value(); });
-    printFunction(graph, name, is_class, defaults, {});
-  }
-  void printModule(script::Module& module) {
-    std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
-    createTensorToParameterNameMap(
-        module, QualifiedName::create("self"), extra_ivalue_names);
-    for (auto& method : module.get_methods()) {
-      const std::string& name = method->name();
-      // we skip __forked_functions because they actually get inlined into their
-      // callers, exporting them again will lead to more code generated on each
-      // export
-      if (name.find("__forked_function") == 0) {
-        continue;
-      }
-      printMethod(*method, /*is_class=*/false, extra_ivalue_names);
+  void printCompilationUnit(script::CompilationUnit& cu) {
+    for (auto& func : cu.get_functions()) {
+      printFunction(*func);
    }
  }

@@ -1161,7 +1101,7 @@ struct PythonPrintPass {
    {
      const auto guard = WithIndented();
      for (auto& method : classType->methods()) {
-        printFunction(*method, /*is_class=*/true);
+        printFunction(*method);
      }
    }
  }
@@ -1173,49 +1113,27 @@ struct PythonPrintPass {

TORCH_API void PythonPrint(
    std::ostream& out,
-    const Graph& graph,
+    const script::Function& func,
+    bool is_method,
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
    bool enforce_importable) {
-  PythonPrintPass pp(tensor_table, class_table, enforce_importable);
+  PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method);
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  pp.printFunction(const_cast<Graph&>(graph), "graph", /*is_class=*/false);
+  pp.printFunction(const_cast<script::Function&>(func));
  pp.print(out);
}

TORCH_API void PythonPrint(
    std::ostream& out,
-    const script::Method& method,
+    const script::CompilationUnit& cu,
+    bool is_method,
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
    bool enforce_importable) {
-  PythonPrintPass pp(tensor_table, class_table, enforce_importable);
+  PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method);
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  pp.printMethod(const_cast<script::Method&>(method));
  pp.print(out);
}

-TORCH_API void PythonPrint(
-    std::ostream& out,
-    const script::Function& callee,
-    std::vector<at::Tensor>& tensor_table,
-    std::vector<ClassTypePtr>& class_table,
-    bool enforce_importable) {
-  PythonPrintPass pp(tensor_table, class_table, enforce_importable);
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  pp.printFunction(const_cast<script::Function&>(callee), /*is_class=*/false);
-  pp.print(out);
-}
-
TORCH_API void PythonPrint(
    std::ostream& out,
-    const script::Module& module,
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
    bool enforce_importable) {
-  PythonPrintPass pp(tensor_table, class_table, enforce_importable);
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  pp.printModule(const_cast<script::Module&>(module));
+  pp.printCompilationUnit(const_cast<script::CompilationUnit&>(cu));
  pp.print(out);
}

@@ -1225,7 +1143,7 @@ TORCH_API void PythonPrint(
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
    bool enforce_importable) {
-  PythonPrintPass pp(tensor_table, class_table, enforce_importable);
+  PythonPrintPass pp(tensor_table, class_table, enforce_importable, true);
  pp.printClass(classType);
  pp.print(out);
}
@@ -12,33 +12,21 @@ struct Method;
struct Module;
} // namespace script

-TORCH_API void PythonPrint(
-    std::ostream& out,
-    const Graph& graph,
-    std::vector<at::Tensor>& tensor_table,
-    std::vector<ClassTypePtr>& class_table,
-    bool enforce_importable = false);
-
-TORCH_API void PythonPrint(
-    std::ostream& out,
-    const script::Method& graph,
-    std::vector<at::Tensor>& tensor_table,
-    std::vector<ClassTypePtr>& class_table,
-    bool enforce_importable = false);
-
TORCH_API void PythonPrint(
    std::ostream& out,
    const script::Function& callee,
+    bool is_method,
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
    bool enforce_importable = false);

TORCH_API void PythonPrint(
    std::ostream& out,
-    const script::Module& module,
+    const script::CompilationUnit& cu,
+    bool is_method,
    std::vector<at::Tensor>& tensor_table,
    std::vector<ClassTypePtr>& class_table,
-    bool enforce_importable = false);
+    bool enforce_importable);

TORCH_API void PythonPrint(
    std::ostream& out,
@@ -343,13 +343,6 @@ void initPythonIRBindings(PyObject* module_) {
      })
      .def("param_node", [](Graph& g) { return g.block()->param_node(); })
      .def("return_node", [](Graph& g) { return g.block()->return_node(); })
-      .def(
-          "pretty_print",
-          [](Graph& g) {
-            std::ostringstream oss;
-            g.prettyPrint(oss);
-            return oss.str();
-          })
      .GS(createFusionGroup)
      .def(
          "createClone",
@@ -2263,7 +2263,7 @@ struct to_ir {
    }
    // Lambda lift block(0) into attr::Subgraph
    lambdaLiftFork(fork_node);

+    runCleanupPasses(fork_node->g(attr::Subgraph));
    return std::make_shared<SimpleValue>(node_output);
  }

@@ -2871,11 +2871,13 @@ void lambdaLiftFork(Node* fork_node) {
  auto env = [&](Value* v) -> Value* {
    if (!uncaptures_map.count(v)) {
      // Capture values for both graphs
-      uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
+      uncaptures_map[v] =
+          forked_graph->addInput()->copyMetadata(v);
      fork_node->addInput(v);
    }
    return uncaptures_map[v];
  };

  forked_graph->block()->cloneFrom(body_block, env);

  // Separate the subgraph and clean up the orignal one
@@ -12,8 +12,14 @@ struct ReturnInfo {
};

void checkNoReturn(const TreeRef& ref) {
-  if (ref->kind() == TK_RETURN)
+  if (ref->kind() == TK_RETURN) {
    throw ErrorReport(ref) << "return is not allowed from a loop.";
+  }
+  // do not search into first-class functions
+  if (ref->kind() == TK_DEF) {
+    return;
+  }
+
  for (const TreeRef& child : ref->trees()) {
    checkNoReturn(child);
  }
@@ -939,7 +939,13 @@ void initJitScriptBindings(PyObject* module) {
            std::ostringstream ss;
            std::vector<at::Tensor> tensors;
            std::vector<ClassTypePtr> classes;
-            PythonPrint(ss, self, tensors, classes, false);
+            PythonPrint(
+                ss,
+                self.class_compilation_unit(),
+                true,
+                tensors,
+                classes,
+                false);
            return ss.str();
          })
      .def("apply", &Module::apply)
@@ -989,7 +995,7 @@ void initJitScriptBindings(PyObject* module) {
            std::ostringstream ss;
            std::vector<at::Tensor> tensors;
            std::vector<ClassTypePtr> classes;
-            PythonPrint(ss, self, tensors, classes, false);
+            PythonPrint(ss, self, false, tensors, classes, false);
            return ss.str();
          })
      .def(
@@ -1021,7 +1027,7 @@ void initJitScriptBindings(PyObject* module) {
            std::ostringstream ss;
            std::vector<at::Tensor> tensors;
            std::vector<ClassTypePtr> classes;
-            PythonPrint(ss, self, tensors, classes, false);
+            PythonPrint(ss, self.function(), true, tensors, classes, false);
            return ss.str();
          });

@@ -1119,12 +1125,12 @@ void initJitScriptBindings(PyObject* module) {
        std::vector<at::Tensor> constants;
        std::vector<ClassTypePtr> classes;
        if (auto self = as_module(obj)) {
-          PythonPrint(ss, *self, constants, classes, true);
+          PythonPrint(ss, self->class_compilation_unit(), true, constants, classes, true);
        } else if (auto self = as_function(obj)) {
-          PythonPrint(ss, *self, constants, classes, true);
+          PythonPrint(ss, *self, false, constants, classes, true);
        } else {
          auto& m = py::cast<Method&>(obj);
-          PythonPrint(ss, m, constants, classes, true);
+          PythonPrint(ss, m.function(), true, constants, classes, true);
        }
        return std::make_pair(ss.str(), std::move(constants));
      });
@@ -185,7 +185,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
    }
    Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
-      if (c->name() == "Module") {
+      if (c->name() == "$Module") {
        auto obj = slot.value().toObject();
        for (Use use : e.n->output()->uses()) {
          to_scan.emplace_back(ToScan{obj, use.user, use.offset});
@@ -118,7 +118,6 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
  if (auto method = classType->getMethod(field)) {
    return std::make_shared<MethodValue>(getValue(), method);
  }

  if (!classType->hasAttribute(field)) {
    throw ErrorReport(loc)
        << "Tried to access to nonexistent attribute " << field
@@ -215,6 +214,35 @@ void SimpleValue::setAttr(
  g.insertNode(g.createSetAttr(value_, field, newValue));
}

+std::shared_ptr<SugaredValue> SimpleValue::call(
+    const SourceRange& loc,
+    Function& m,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+  // allow our 'fake' closures to be called, used for fork serialization
+  // at the moment, but can be expanded later
+  Node* self = getValue()->node();
+  if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
+      self->inputs().at(0)->node()->kind() == prim::Function) {
+    std::shared_ptr<Graph> graph = self->inputs().at(0)->node()->g(attr::Subgraph);
+    Value* context = self->inputs().at(1);
+    AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
+
+    // fork nodes are emitted in their own block but we do not simplify
+    // tuple construction across blocks. To ensure we clean up the tuple
+    // construct create another copy of the tuple construct in the fork block
+    Value* close_context =
+        m.graph()
+            ->insertNode(m.graph()->createTuple(context->node()->inputs()))
+            ->output();
+    auto fn = CompilationUnit().create_function("anon", graph);
+    return MethodValue(close_context, fn).call(loc, m, inputs, attributes, n_binders);
+  }
+  return SugaredValue::call(loc, m, inputs, attributes, n_binders);
+}
+
std::shared_ptr<SugaredValue> ClassValue::call(
    const SourceRange& loc,
    Function& m,
@@ -129,6 +129,14 @@ struct TORCH_API SimpleValue : public SugaredValue {
      const std::string& field,
      Value* newValue) override;

+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Function& m,
+      // note: names for args will be 'argument 0', 'argument 1', etc..
+      at::ArrayRef<NamedValue> inputs_,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override;
+
  Value* getValue() const {
    return value_;
  }
@@ -707,15 +707,6 @@ class CompilationUnit(object):
    def _import(self, src, constants):
        """ test import logic for single function, use only for testing """
        src = "op_version_set = 0\n{}".format(src)
-        # HACK: export always adds a self argument
-        # even if the thing is a pure function,
-        # we just delete it here
-        # Once the serialization outputs the first-class
-        # functions and not the lowered functions,
-        # we can change this behavior
-        src = (src.replace('self,\n', '')
-               .replace('(self)', '()')
-               .replace('self.__forked_function', '__forked_function'))
        torch._C._jit_import_functions(self._c, src, constants, None)
        return self
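Taken together, the testing flow that the JitTestCase and CompilationUnit._import changes above rely on looks roughly like the following sketch (torch._C._jit_python_print and CompilationUnit._import are internal, test-only helpers shown in this diff; the function body is illustrative).

import torch
from torch import Tensor

@torch.jit.script
def foo(x: Tensor, y: Tensor) -> Tensor:
    return 2 * x + y

# Print the first-class function to TorchScript source plus its tensor constants,
# re-import it into a fresh compilation unit, and check that the print is stable.
src, constants = torch._C._jit_python_print(foo)
cu = torch.jit.CompilationUnit()._import(src, constants)
src2, constants2 = torch._C._jit_python_print(getattr(cu, foo.name))
assert src == src2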