mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable.
Currently, something like this works:
```cpp
struct Foo {
int x, y;
Foo(): x(2), y(5){}
Foo(int x_, int y_) : x(x_), y(y_) {}
void display() {
cout<<"x: "<<x<<' '<<"y: "<<y<<endl;
}
int64_t add(int64_t z) {
return (x+y)*z;
}
};
static auto test = torch::jit::class_<Foo>("Foo")
.def(torch::jit::init<int64_t, int64_t>())
.def("display", &Foo::display)
.def("add", &Foo::add)
.def("combine", &Foo::combine);
```
with
```py
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
val.display()
print(val.add(3))
```
results in
```
x: 5 y: 3
24
```
Current issues:
- [x] The Python class created by TorchScript doesn't interact properly with the surrounding code.
```
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
return val
```
- [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe).
```cpp
void combine(Foo x) {
```
- [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
```py
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
val2 = torch._C.Foo(100, 0)
val.display()
print(val.add(3))
```
- [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
- [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())`
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into Torchscript doesn't work.
- [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We pass back into Python by value, currently. There's no way of passing by reference.
- [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.
Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177. We currently use some structures that will be refactored by that PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098
Differential Revision: D16634872
Pulled By: Chillee
fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
783 lines
25 KiB
C++
783 lines
25 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <torch/csrc/Device.h>
|
|
#include <torch/csrc/Dtype.h>
|
|
#include <torch/csrc/Layout.h>
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
#include <torch/csrc/jit/operator.h>
|
|
#include <torch/csrc/jit/script/module.h>
|
|
#include <torch/csrc/jit/tracer.h>
|
|
#include <torch/csrc/utils/auto_gil.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/six.h>
|
|
|
|
#include <ATen/core/function_schema.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <algorithm>
|
|
#include <cstddef>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
// The visibility attribute is to avoid a warning about storing a field in the
|
|
// struct that has a different visibility (from pybind) than the struct.
|
|
#ifdef _WIN32
|
|
#define VISIBILITY_HIDDEN
|
|
#else
|
|
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
|
|
#endif
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// error reporting: when reporting user-caused errors, these functions should
|
|
// not use AT_ERROR macros, since these macros add stack trace information
|
|
// that is confusing to display to the end user since it always reports
|
|
// locations in libtorch code rather than user code.
|
|
|
|
using tracer::TypedStack;
|
|
|
|
inline std::shared_ptr<script::CompilationUnit> get_python_cu() {
|
|
return py::module::import("torch.jit")
|
|
.attr("_python_cu")
|
|
.cast<std::shared_ptr<script::CompilationUnit>>();
|
|
}
|
|
|
|
struct TypedIValue : public std::pair<IValue, TypePtr> {
|
|
using pair::pair;
|
|
|
|
IValue& ivalue() {
|
|
return this->first;
|
|
}
|
|
TypePtr& type() {
|
|
return this->second;
|
|
}
|
|
};
|
|
|
|
// Converts a Python dictionary key into an IValue paired with its type.
// str, int and float keys are recognised (tested in that order); any other
// key raises through AT_ERROR.
inline TypedIValue toDictKeyIValue(py::handle key) {
  if (py::isinstance<py::str>(key)) {
    return TypedIValue(
        ConstantString::create(py::cast<std::string>(key)),
        StringType::create());
  }
  if (py::isinstance<py::int_>(key)) {
    return TypedIValue(py::cast<int64_t>(key), IntType::create());
  }
  if (py::isinstance<py::float_>(key)) {
    return TypedIValue(py::cast<double>(key), FloatType::create());
  }
  AT_ERROR("Dictionary inputs may only have string, int, or float keys");
}
|
|
|
|
// Folds `unify` into the running type `accum`. While `accum` is still null
// the result is simply `unify`; otherwise the result is whatever
// unifyTypes(accum, unify) yields.
inline c10::optional<TypePtr> unifyOrInitializeType(
    TypePtr accum,
    TypePtr unify) {
  if (accum) {
    return unifyTypes(accum, unify);
  }
  // Nothing accumulated yet: the first type seen becomes the accumulator.
  return unify;
}
|
|
|
|
MatchTypeReturn tryToInferContainerType(py::handle input);
|
|
|
|
// Try to infer the type of a Python object
|
|
// The type cannot be inferred if:
|
|
// input is a None
|
|
// input is an empty container (list, dict)
|
|
// input is an list with element types that cannot be unified
|
|
// input is an dict with key or value types that cannot be unified
|
|
inline MatchTypeReturn tryToInferType(py::handle input) {
|
|
// Try tensor types
|
|
if (THPVariable_Check(input.ptr())) {
|
|
auto tensor = py::cast<at::Tensor>(input);
|
|
if (tensor.is_sparse()) {
|
|
return MatchTypeReturn("Sparse tensors not supported");
|
|
}
|
|
if (tensor.is_mkldnn()) {
|
|
// mkldnn tensor as opaque tensor doesn't have strides, so we can
|
|
// not create a CompleteTensorType
|
|
return MatchTypeReturn(DimensionedTensorType::create(tensor));
|
|
}
|
|
|
|
// TODO: maybe unshape this type if this is used for script instead of
|
|
// tracing
|
|
return MatchTypeReturn(CompleteTensorType::create(tensor));
|
|
}
|
|
|
|
if (input.is(py::none())) {
|
|
return MatchTypeReturn("Cannot infer type of a None value");
|
|
}
|
|
|
|
// Try basic types first
|
|
if (py::isinstance<py::bool_>(input)) {
|
|
return MatchTypeReturn(BoolType::get());
|
|
} else if (py::isinstance<py::int_>(input)) {
|
|
return MatchTypeReturn(IntType::get());
|
|
} else if (py::isinstance<py::float_>(input)) {
|
|
return MatchTypeReturn(FloatType::get());
|
|
} else if (py::isinstance<py::str>(input)) {
|
|
return MatchTypeReturn(StringType::get());
|
|
} else if (THPLayout_Check(input.ptr())) {
|
|
return MatchTypeReturn(IntType::get());
|
|
} else if (THPDevice_Check(input.ptr())) {
|
|
return MatchTypeReturn(DeviceObjType::get());
|
|
} else if (THPDtype_Check(input.ptr())) {
|
|
return MatchTypeReturn(IntType::get());
|
|
}
|
|
|
|
// Try container types
|
|
return tryToInferContainerType(input);
|
|
}
|
|
|
|
// Infers the JIT type of a Python tuple, dict or list by inferring the type
// of every contained element with tryToInferType.
//
// - Tuples yield a TupleType of the per-element types; the first element
//   whose type cannot be inferred forwards its error message.
// - Dicts must be non-empty, and all keys (and, separately, all values) must
//   unify to a single type.
// - Lists must be non-empty and all elements must unify to a single type.
// - Any other object yields an error message naming its Python type and repr.
inline MatchTypeReturn tryToInferContainerType(py::handle input) {
  if (six::isTuple(input)) {
    py::tuple tuple = py::cast<py::tuple>(input);
    std::vector<TypePtr> element_types;
    element_types.reserve(tuple.size());

    for (py::handle elem : tuple) {
      auto type_match = tryToInferType(elem);
      if (type_match.type) {
        element_types.push_back(*type_match.type);
      } else {
        // Forward error message along
        return type_match.errMsg;
      }
    }
    return MatchTypeReturn(TupleType::create(element_types));
  } else if (PyDict_Check(input.ptr())) {
    // Check to make sure we can generate useful input/output types
    auto dict = py::cast<py::dict>(input);
    size_t len = py::len(dict);
    if (!len) {
      return MatchTypeReturn("Dictionary inputs must have entries");
    }

    // Both start out null; unifyOrInitializeType adopts the first type seen.
    TypePtr key_type = nullptr;
    TypePtr value_type = nullptr;

    for (auto entry : dict) {
      // Try to infer the key type and unify it with the existing one
      auto entry_key_type_match = tryToInferType(entry.first);
      if (!entry_key_type_match.type) {
        return entry_key_type_match.errMsg;
      }
      auto unified_key =
          unifyOrInitializeType(key_type, *entry_key_type_match.type);
      if (!unified_key) {
        // Unification can only fail once key_type is non-null, so
        // dereferencing it here is safe.
        return MatchTypeReturn(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            key_type->python_str(),
            " and ",
            (*entry_key_type_match.type)->python_str()));
      }

      // Try to infer the value type and unify it with the existing one
      auto entry_value_type_match = tryToInferType(entry.second);
      if (!entry_value_type_match.type) {
        return entry_value_type_match.errMsg;
      }
      auto unified_value =
          unifyOrInitializeType(value_type, *entry_value_type_match.type);
      if (!unified_value) {
        return MatchTypeReturn(c10::str(
            "Dictionary inputs to traced functions must have consistent type. Found ",
            value_type->python_str(),
            " and ",
            (*entry_value_type_match.type)->python_str()));
      }

      key_type = *unified_key;
      value_type = *unified_value;
    }
    return MatchTypeReturn(DictType::create(key_type, value_type));
  } else if (PyList_Check(input.ptr())) {
    auto list = py::cast<py::list>(input);
    size_t len = py::len(list);
    if (!len) {
      return MatchTypeReturn("List trace inputs must have elements");
    }

    TypePtr element_type = nullptr;
    for (auto elem : list) {
      auto element_type_match = tryToInferType(elem);
      if (!element_type_match.type) {
        return MatchTypeReturn(c10::str(
            "Could not infer type of list element: ",
            element_type_match.errMsg));
      }
      auto unified_type =
          unifyOrInitializeType(element_type, *element_type_match.type);
      if (!unified_type) {
        return MatchTypeReturn(c10::str(
            "List inputs to traced functions must have consistent element type. Found ",
            element_type->python_str(),
            " and ",
            (*element_type_match.type)->python_str()));
      }
      element_type = *unified_type;
    }
    return MatchTypeReturn(ListType::create(element_type));
  } else {
    // The pieces are concatenated verbatim, so each fragment must carry its
    // own separating space (previously this read "dictsare supported").
    return MatchTypeReturn(c10::str(
        "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ",
        "are supported ",
        "as inputs or outputs of traced functions",
        ", but instead got value of type ",
        py::str(input.get_type().attr("__name__")),
        ".",
        "\nValue: ",
        py::repr(input)));
  }
}
|
|
|
|
inline IValue toIValue(
|
|
py::handle obj,
|
|
const TypePtr& type,
|
|
c10::optional<int32_t> N = c10::nullopt);
|
|
|
|
// Reports whether `type` is traceable: a tensor type, a list whose element
// type is traceable, a tuple whose element types are all traceable, or a dict
// whose value type is traceable. Dict key types are not inspected.
inline bool isTraceableType(TypePtr type) {
  if (type->isSubtypeOf(TensorType::get())) {
    return true;
  }

  if (auto as_list = type->cast<ListType>()) {
    return isTraceableType(as_list->getElementType());
  }

  if (auto as_tuple = type->cast<TupleType>()) {
    // Every member must be traceable; an empty tuple counts as traceable.
    for (const auto& member_type : as_tuple->elements()) {
      if (!isTraceableType(member_type)) {
        return false;
      }
    }
    return true;
  }

  if (auto as_dict = type->cast<DictType>()) {
    return isTraceableType(as_dict->getValueType());
  }

  return false;
}
|
|
|
|
// Infers the type of `input`, verifies that the type is traceable, and
// returns the converted IValue together with that type. Raises through
// AT_ERROR when the type cannot be inferred or is not traceable.
inline TypedIValue toTraceableIValue(py::handle input) {
  auto inferred = tryToInferType(input);
  if (!inferred.type) {
    AT_ERROR(
        "Tracer cannot infer type of ",
        py::str(input),
        "\n:",
        inferred.errMsg);
  }

  auto traced_type = *inferred.type;
  if (!isTraceableType(traced_type)) {
    AT_ERROR(
        "Type '",
        traced_type->python_str(),
        "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
        " Tuples of Tensors can be traced");
  }

  return TypedIValue(toIValue(input, traced_type), traced_type);
}
|
|
|
|
// Single-argument convenience overload: infers the type of `input` via
// toTraceableIValue and returns only the value part of the result.
inline IValue toIValue(py::handle input) {
  TypedIValue typed = toTraceableIValue(input);
  return typed.ivalue();
}
|
|
|
|
// Converts a Python tuple of inputs into a Stack holding one IValue per
// tuple element.
inline Stack toStack(const py::tuple& inputs) {
  auto as_tuple = toIValue(inputs).toTuple();
  return as_tuple->elements();
}
|
|
|
|
// Like toStack, but also carries the inferred TupleType of the inputs
// alongside the converted elements.
inline TypedStack toTypedStack(const py::tuple& inputs) {
  auto traced = toTraceableIValue(inputs);
  auto tuple_value = traced.ivalue().toTuple();
  auto tuple_type = traced.type()->expect<TupleType>();
  return TypedStack(tuple_value->elements(), std::move(tuple_type));
}
|
|
|
|
// Builds a generic list IValue by iterating `obj` and converting every item
// to `elem_type` with toIValue.
inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
  c10::impl::GenericList converted(elem_type);
  for (py::handle item : obj) {
    converted.push_back(toIValue(item, elem_type));
  }
  return IValue(std::move(converted));
}
|
|
|
|
// Builds a generic dict IValue from `obj`: iterating `obj` yields its keys,
// each key is converted to `key_type` and the value found under `obj[key]`
// is converted to `value_type`.
inline IValue createGenericDict(
    py::handle obj,
    const TypePtr& key_type,
    const TypePtr& value_type) {
  c10::impl::GenericDict converted(key_type, value_type);
  converted.reserve(py::len(obj));
  for (py::handle py_key : obj) {
    IValue key_ivalue = toIValue(py_key, key_type);
    IValue value_ivalue = toIValue(obj[py_key], value_type);
    converted.insert(std::move(key_ivalue), std::move(value_ivalue));
  }
  return IValue(std::move(converted));
}
|
|
|
|
// Converts the Python object `obj` into an IValue of the requested JIT `type`.
//
// `N`, when set, enables scalar broadcasting for list types: if `type` is a
// list of int (resp. float) and `obj` is a bare Python int (resp. float), the
// scalar is repeated N times instead of being cast as a sequence.
//
// Conversion failures surface as py::cast_error (thrown by the pybind casts
// or by the explicit throws below); unsupported type kinds raise through
// AT_ERROR.
inline IValue toIValue(
    py::handle obj,
    const TypePtr& type,
    c10::optional<int32_t> N) {
  switch (type->kind()) {
    case TypeKind::TensorType:
    case TypeKind::AutogradZeroTensorType:
    case TypeKind::DimensionedTensorType:
    case TypeKind::ProfiledTensorType:
    case TypeKind::CompleteTensorType: {
      auto var = py::cast<autograd::Variable>(obj);
      if (var.is_sparse()) {
        AT_ERROR("sparse tensors not supported");
      }
      return var;
    }
    case TypeKind::FloatType:
      return py::cast<double>(obj);
    case TypeKind::IntType:
      return py::cast<int64_t>(obj);
    case TypeKind::NoneType:
      if (!obj.is_none()) {
        throw py::cast_error(
            c10::str("Cannot cast ", py::str(obj), " to None"));
      }
      return {};
    case TypeKind::BoolType:
      return py::cast<bool>(obj);
    case TypeKind::TupleType: {
      py::tuple tuple = py::cast<py::tuple>(obj);
      size_t tuple_size = tuple.size();
      auto tuple_type = type->cast<TupleType>();
      const auto& elem_types = tuple_type->elements();
      if (elem_types.size() != tuple_size) {
        throw py::cast_error(c10::str(
            "Object ",
            py::str(obj),
            " had a different number of elements than type ",
            type->python_str()));
      }
      std::vector<IValue> values;
      values.reserve(tuple_size);
      for (size_t i = 0; i < tuple_size; ++i) {
        values.push_back(toIValue(tuple[i], elem_types[i]));
      }
      return c10::ivalue::Tuple::create(std::move(values), tuple_type);
    }
    case TypeKind::StringType:
      return ConstantString::create(py::cast<std::string>(obj));
    case TypeKind::DeviceObjType: {
      auto device = reinterpret_cast<THPDevice*>(obj.ptr());
      return device->device;
    }
    case TypeKind::ListType: {
      const auto& elem_type = type->expect<ListType>()->getElementType();
      switch (elem_type->kind()) {
        // allows single int/float to be broadcasted to a fixed size list
        case TypeKind::IntType:
          if (!N || !py::isinstance<py::int_>(obj)) {
            return c10::impl::toList(py::cast<std::vector<int64_t>>(obj));
          } else {
            // The repeated list must be an int list; building a
            // c10::List<double> here would yield a double-list IValue for an
            // int[] argument.
            const int64_t value = py::cast<int64_t>(obj);
            c10::List<int64_t> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::FloatType:
          if (!N || !py::isinstance<py::float_>(obj)) {
            return c10::impl::toList(py::cast<std::vector<double>>(obj));
          } else {
            const double value = py::cast<double>(obj);
            c10::List<double> repeated;
            repeated.reserve(*N);
            for (int i = 0; i < *N; ++i) {
              repeated.push_back(value);
            }
            return repeated;
          }
        case TypeKind::DimensionedTensorType:
        case TypeKind::ProfiledTensorType:
        case TypeKind::TensorType:
          return c10::impl::toList(py::cast<std::vector<at::Tensor>>(obj));
        default:
          return createGenericList(obj, elem_type);
      }
    }
    case TypeKind::DictType: {
      const auto& dict_type = type->expect<DictType>();
      return createGenericDict(
          obj, dict_type->getKeyType(), dict_type->getValueType());
    }
    case TypeKind::OptionalType: {
      // Optional accepts None: return an empty IValue() to denote a NoneType.
      if (obj.is_none()) {
        return {};
      }
      return toIValue(obj, type->expect<OptionalType>()->getElementType());
    }
    case TypeKind::ClassType: {
      auto classType = type->expect<ClassType>();
      // 1. create a bare ivalue
      const size_t numAttrs = classType->numAttributes();
      auto cu = classType->compilation_unit();
      auto userObj = c10::ivalue::Object::create(
          c10::StrongTypePtr(cu, classType), numAttrs);

      // 2. copy all the contained types
      for (size_t slot = 0; slot < numAttrs; slot++) {
        const auto& attrType = classType->getAttribute(slot);
        const auto& attrName = classType->getAttributeName(slot);

        const auto& contained = py::getattr(obj, attrName.c_str());
        userObj->setSlot(slot, toIValue(contained, attrType));
      }
      return userObj;
    }
    case TypeKind::NumberType:
      if (py::isinstance<py::int_>(obj)) {
        return py::cast<int64_t>(obj);
      } else if (py::isinstance<py::float_>(obj)) {
        return py::cast<double>(obj);
      }
      // Neither int nor float: leave the switch and report the error below.
      break;
    case TypeKind::GeneratorType:
    case TypeKind::VarType:
    case TypeKind::FutureType:
      break;
    case TypeKind::FunctionType:
      AT_ERROR("Function Values aren't yet supported");
    case TypeKind::CapsuleType:
      AT_ERROR("Capsule Values aren't supported");
  }
  AT_ERROR(
      "Missing cases in toIValue for type: ",
      type->str(),
      "! File a bug report.");
}
|
|
|
|
// Small wrapper around getting the type name string from Python to make
|
|
// types easier to interpret, e.g. give the structural type for a NamedTuple
|
|
inline std::string friendlyTypeName(py::handle obj) {
|
|
if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
|
|
auto field_names =
|
|
py::cast<std::vector<std::string>>(py::getattr(obj, "_fields"));
|
|
std::stringstream ss;
|
|
ss << py::str(obj.get_type().attr("__name__"));
|
|
ss << " (aka NamedTuple(";
|
|
bool first = true;
|
|
for (auto& field_name : field_names) {
|
|
if (!first) {
|
|
ss << ", ";
|
|
}
|
|
ss << field_name;
|
|
first = false;
|
|
}
|
|
ss << "))";
|
|
return ss.str();
|
|
} else {
|
|
return py::str(obj.get_type().attr("__name__"));
|
|
}
|
|
}
|
|
|
|
// Converts `object` to the type declared for the schema argument at
// `argumentPosition`. A py::cast_error raised during conversion is rethrown
// as a std::runtime_error carrying the schema's type-mismatch message.
inline IValue argumentToIValue(
    const FunctionSchema& schema,
    size_t argumentPosition,
    py::handle object) {
  const auto& expected = schema.arguments().at(argumentPosition);
  try {
    return toIValue(object, expected.type(), expected.N());
  } catch (const py::cast_error&) {
    const auto mismatch = schema.formatTypeMismatchMsg(
        expected,
        friendlyTypeName(object),
        argumentPosition,
        py::repr(object));
    throw std::runtime_error(mismatch);
  }
}
|
|
|
|
// Converts a Python return value `object` to `type`. A py::cast_error raised
// during conversion is rethrown as a std::runtime_error that names the
// expected type, the actual Python type and the value's repr.
inline IValue returnToIValue(const TypePtr& type, py::handle object) {
  try {
    return toIValue(object, type);
  } catch (const py::cast_error&) {
    const auto message = c10::str(
        " expected value of type ",
        type->str(),
        " for return value but instead got value of type ",
        py::str(object.get_type().attr("__name__")),
        ".",
        "\nValue: ",
        py::repr(object));
    throw std::runtime_error(message);
  }
}
|
|
|
|
// If the object's name starts with "__torch__.torch.classes", converts it to
// a Python object using the converter registered under that name in
// c10::getClassConverter(); otherwise returns c10::nullopt.
inline c10::optional<py::object> tryToConvertToCustomClass(
    const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
  if (obj->name().find("__torch__.torch.classes") != 0) {
    return c10::nullopt;
  }

  // Slot 0 holds the capsule. release() yields the raw pointer without
  // dropping the reference.
  // NOTE(review): presumably ownership passes to the converted Python
  // object — confirm against the class converter implementation.
  auto capsule = obj->getSlot(0).toCapsule();
  auto rawInstance = (void*)capsule.release();

  auto converter = c10::getClassConverter()[obj->name()];
  py::handle rawPyObj = converter(rawInstance);
  auto converted = py::reinterpret_steal<py::object>(rawPyObj);
  return converted;
}
|
|
// Converts an IValue into the corresponding Python object. The tag checks
// run in a fixed order (None, tensor, double, int, bool, string, the typed
// lists, generic list, tuple, device, generic dict, object); a tag matching
// none of them raises through AT_ERROR.
inline py::object toPyObject(IValue&& ivalue) {
  if (ivalue.isNone()) {
    return py::none();
  }

  if (ivalue.isTensor()) {
    auto tensor = std::move(ivalue).toTensor();
    if (tensor.is_sparse()) {
      AT_ERROR("sparse tensors not supported");
    }
    return py::cast(autograd::Variable(std::move(tensor)));
  }

  // Scalars and strings.
  if (ivalue.isDouble()) {
    return py::cast(std::move(ivalue).toDouble());
  }
  if (ivalue.isInt()) {
    return py::cast(std::move(ivalue).toInt());
  }
  if (ivalue.isBool()) {
    return py::cast(std::move(ivalue).toBool());
  }
  if (ivalue.isString()) {
    return py::cast(std::move(ivalue).toStringRef());
  }

  // Typed lists are copied into std::vector and cast in one go.
  if (ivalue.isIntList()) {
    return py::cast(c10::impl::toVector(std::move(ivalue).toIntList()));
  }
  if (ivalue.isDoubleList()) {
    return py::cast(c10::impl::toVector(std::move(ivalue).toDoubleList()));
  }
  if (ivalue.isBoolList()) {
    return py::cast(c10::impl::toVector(std::move(ivalue).toBoolList()));
  }
  if (ivalue.isTensorList()) {
    return py::cast(c10::impl::toVector(std::move(ivalue).toTensorList()));
  }

  if (ivalue.isGenericList()) {
    auto generic_list = std::move(ivalue).toGenericList();
    py::list py_list{generic_list.size()};
    for (size_t i = 0; i < generic_list.size(); ++i) {
      py_list[i] = toPyObject(IValue{generic_list.get(i)});
    }
    return std::move(py_list);
  }

  if (ivalue.isTuple()) {
    auto tuple = std::move(ivalue).toTuple();
    const auto& elements = tuple->elements();
    py::tuple py_tuple{elements.size()};
    for (size_t i = 0; i < elements.size(); ++i) {
      py_tuple[i] = toPyObject(IValue{elements.at(i)});
    }

    // A tuple whose type carries a schema with a non-empty name is rebuilt as
    // a Python named tuple; every other tuple is returned as-is.
    const bool is_named = tuple->type && tuple->type->schema() &&
        tuple->type->schema()->name() != "";
    if (!is_named) {
      return std::move(py_tuple);
    }
    auto unqualName = tuple->type->basename();
    auto fieldNames = fmap(
        tuple->type->schema()->arguments(),
        [](const Argument& arg) { return arg.name(); });
    return py::module::import("torch.jit")
        .attr("_create_named_tuple")(py_tuple, unqualName, fieldNames);
  }

  if (ivalue.isDevice()) {
    return py::cast<py::object>(THPDevice_New(std::move(ivalue).toDevice()));
  }

  if (ivalue.isGenericDict()) {
    auto dict = std::move(ivalue).toGenericDict();
    py::dict py_dict;
    for (auto& pair : dict) {
      py_dict[toPyObject(IValue{pair.key()})] =
          toPyObject(IValue{pair.value()});
    }
    return std::move(py_dict);
  }

  if (ivalue.isObject()) {
    const auto obj = std::move(ivalue).toObject();
    auto pyCu = get_python_cu();

    // Custom (torchbind) classes take a dedicated conversion path.
    auto custom = tryToConvertToCustomClass(obj);
    if (custom.has_value()) {
      return custom.value();
    }

    const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
    AT_ASSERT(classType);

    // Create a bare instance of the Python-side script class, then copy every
    // attribute slot onto it.
    auto pyClass =
        py::module::import("torch.jit").attr("_get_script_class")(obj->name());
    auto pyObj = pyClass.attr("__new__")(pyClass);

    const auto numAttrs = classType->numAttributes();
    for (size_t slot = 0; slot < numAttrs; slot++) {
      const auto& attrName = classType->getAttributeName(slot);
      IValue slotValue = obj->getSlot(slot);
      py::setattr(pyObj, attrName.c_str(), toPyObject(std::move(slotValue)));
    }
    return pyObj;
  }

  AT_ERROR(
      "Missing cases in 'toPyObject'! Can't convert ",
      ivalue.tagKind(),
      " to a Python object");
}
|
|
|
|
struct VISIBILITY_HIDDEN tuple_slice {
|
|
/*implicit*/ tuple_slice(py::tuple tup_)
|
|
: tup(std::move(tup_)), b(0), e(tup.size()) {}
|
|
tuple_slice(py::tuple tup_, int64_t b_)
|
|
: tup(std::move(tup_)), b(b_), e(tup.size()) {}
|
|
tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
|
|
: tup(std::move(tup_)), b(b_), e(e_) {}
|
|
py::detail::tuple_iterator begin() const {
|
|
return {tup, static_cast<pybind11::ssize_t>(b)};
|
|
}
|
|
py::detail::tuple_iterator end() const {
|
|
return {tup, static_cast<pybind11::ssize_t>(e)};
|
|
}
|
|
size_t size() const {
|
|
return e - b;
|
|
}
|
|
py::detail::tuple_accessor operator[](size_t index) const {
|
|
return {tup, static_cast<size_t>(b + index)};
|
|
}
|
|
|
|
private:
|
|
py::tuple tup;
|
|
int64_t b;
|
|
int64_t e;
|
|
};
|
|
|
|
// Builds the interpreter stack for a call described by `schema`:
//   1. `self`, when provided, is pushed first;
//   2. positional `args` follow, each converted to its declared type;
//   3. every remaining schema argument is taken from `kwargs` or, failing
//      that, from its default value.
// Throws std::runtime_error when too many arguments are supplied or a
// required argument is missing. If some kwargs were left unconsumed,
// schema.findErrorInKwargs is invoked with the kwarg names.
inline Stack createStackForSchema(
    const FunctionSchema& schema,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    c10::optional<IValue> self) {
  const size_t provided = (self ? 1 : 0) + args.size() + kwargs.size();
  if (provided > schema.arguments().size()) {
    throw std::runtime_error(c10::str(
        schema.name(),
        "() expected at most ",
        schema.arguments().size(),
        " argument(s) but received ",
        provided,
        " argument(s). Declaration: ",
        schema));
  }

  Stack stack;
  stack.reserve(schema.arguments().size());

  if (self) {
    push(stack, std::move(*self));
  }

  // Positional arguments: the schema position is the current stack depth,
  // which already accounts for `self`.
  for (size_t pos = 0; pos < args.size(); ++pos) {
    push(stack, argumentToIValue(schema, stack.size(), args[pos]));
  }

  // Remaining schema arguments: kwargs first, then defaults.
  size_t kwargs_used = 0;
  for (size_t idx = stack.size(); idx < schema.arguments().size(); ++idx) {
    const auto& arg = schema.arguments()[idx];
    if (kwargs.contains(arg.name().c_str())) {
      push(stack, argumentToIValue(schema, idx, kwargs[arg.name().c_str()]));
      ++kwargs_used;
      continue;
    }
    if (arg.default_value()) {
      push(stack, *arg.default_value());
      continue;
    }
    throw std::runtime_error(c10::str(
        schema.name(),
        "() is missing value for argument '",
        arg.name(),
        "'. Declaration: ",
        schema));
  }

  // Some kwargs matched no remaining schema argument: let the schema
  // diagnose them.
  if (kwargs_used != kwargs.size()) {
    std::vector<std::string> kwarg_names;
    for (const auto& kwarg : kwargs) {
      kwarg_names.emplace_back(py::cast<std::string>(kwarg.first));
    }
    schema.findErrorInKwargs(kwarg_names);
  }

  return stack;
}
|
|
|
|
// Turns the values left on `stack` into a Python return value: None for an
// empty stack, the bare converted value for a single entry (not a one-element
// tuple), and a tuple of converted values otherwise.
inline py::object createPyObjectForStack(Stack&& stack) {
  const size_t num_returns = stack.size();

  if (num_returns == 0) {
    return py::none();
  }
  if (num_returns == 1) {
    return toPyObject(std::move(stack[0]));
  }

  // Several return values: convert each one into a slot of a py::tuple.
  py::tuple packed(num_returns);
  for (size_t i = 0; i < num_returns; ++i) {
    packed[i] = toPyObject(std::move(stack[i]));
  }
  return std::move(packed);
}
|
|
|
|
// TODO: Remove once we clean up the GraphExecutor usage.
|
|
inline Stack evilDeprecatedBadCreateStackDoNotUse(
|
|
const py::tuple& tuple,
|
|
at::ArrayRef<Value*> inputs,
|
|
size_t reserve_extra_space = 0) {
|
|
if (tuple.size() != inputs.size()) {
|
|
AT_ERROR(
|
|
"expected " + std::to_string(inputs.size()) + " inputs, but got " +
|
|
std::to_string(tuple.size()));
|
|
}
|
|
Stack result;
|
|
result.reserve(tuple.size() + reserve_extra_space);
|
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type()));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Calls the script function `callee` with Python arguments: builds the stack
// from the callee's schema, runs the function while an AutoNoGIL guard is
// alive, and converts the value on top of the stack back to Python.
inline py::object invokeScriptFunctionFromPython(
    Function& callee,
    tuple_slice args,
    py::kwargs kwargs,
    c10::optional<IValue> self = c10::nullopt) {
  Stack stack = createStackForSchema(
      callee.getSchema(), std::move(args), std::move(kwargs), std::move(self));

  // The guard is scoped to the run only; it is destroyed before any Python
  // object is created below.
  {
    AutoNoGIL no_gil_guard;
    callee.run(stack);
  }

  TORCH_CHECK(
      stack.size() > 0,
      "Expected values in the stack after execution but found none");
  IValue& result = stack.back();
  return toPyObject(std::move(result));
}
|
|
|
|
// Calls a script method from Python by forwarding to
// invokeScriptFunctionFromPython, passing the owning module's object as
// `self`.
inline py::object invokeScriptMethodFromPython(
    script::Method& callee,
    tuple_slice args,
    py::kwargs kwargs) {
  auto& function = callee.function();
  auto self = callee.owner().module_object();
  return invokeScriptFunctionFromPython(
      function, std::move(args), std::move(kwargs), std::move(self));
}
|
|
// Calls a registered operator from Python: packs `args`/`kwargs` into a stack
// according to the operator's schema, runs the operation on that stack, and
// converts whatever it leaves behind into a Python return value.
inline py::object invokeOperatorFromPython(
    const Operator& op,
    py::args args,
    py::kwargs kwargs) {
  // No `self` for operators, hence c10::nullopt.
  Stack stack = createStackForSchema(
      op.schema(), std::move(args), std::move(kwargs), c10::nullopt);

  // The operation replaces its arguments on the stack with its results.
  auto operation = op.getOperation();
  operation(stack);

  return createPyObjectForStack(std::move(stack));
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|