pytorch/torch/csrc/jit/script/compiler.h
Michael Suo f9820e55af initializing class value (#17585)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17585

Create a sugared value that represents a class during initialization. This is
so that assignments to attributes correctly define attributes in __init__ but
raise an error elsewhere.

Reviewed By: shannonzhu

Differential Revision: D14263403

fbshipit-source-id: 09b2feeb272302f00a79c2a0302fbdf5483aed6a
2019-03-11 19:13:52 -07:00

76 lines
2.3 KiB
C++

#pragma once
#include <functional>
#include <memory>
#include <string>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/sugared_value.h>
#include <torch/csrc/jit/script/tree_views.h>
namespace torch {
namespace jit {
namespace script {
using Resolver = std::function<std::shared_ptr<
SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
inline std::shared_ptr<SugaredValue> nativeResolver(
const std::string& name,
Method& m,
const SourceRange& loc) {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");
}
return nullptr;
}
// Represents the `self` argument to a method. This wrapper class is necessary
// because sometimes `self` sometimes is first class and sometimes not.
//
// `self` is first class when it refers to a ClassType. It will be bound as a
// graph input argument.
// `self` is sugared when it refers to a ModuleValue.
class Self {
public:
explicit Self(std::shared_ptr<SugaredValue> sugared)
: sugared_(std::move(sugared)) {}
explicit Self(ClassTypePtr type) : firstClass_(std::move(type)) {}
ClassTypePtr asFirstClass() const {
return firstClass_;
}
std::shared_ptr<SugaredValue> asSugared() const {
return sugared_;
}
private:
// Used when `self` is not first-class and so we don't represent it in the
// graph. This is only ModuleValue.
std::shared_ptr<SugaredValue> sugared_ = nullptr;
// Used when `self` is a first-class type
ClassTypePtr firstClass_ = nullptr;
};
TORCH_API void defineMethodsInModule(
const std::shared_ptr<Module>& m,
const std::vector<Def>& definitions,
const std::vector<Resolver>& resolvers, /* determines how we handle free
variables in each definition*/
// if non-null, the first argument to each def, is bound to this value
const c10::optional<Self>& self);
// same as above but parse the definitions from source
TORCH_API void defineMethodsInModule(
const std::shared_ptr<Module>& m,
const std::string& source,
const Resolver& resolver,
const c10::optional<Self>& self);
TORCH_API void lambdaLiftFork(Node* fork_node);
} // namespace script
} // namespace jit
} // namespace torch