mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52659 **Summary** This commit adds `torch._C.ScriptDict`, a dictionary type that has reference semantics across the Python/TorchScript boundary. That is, modifications made to instances of `torch._C.ScriptDict` in TorchScript are visible in Python even when it is not returned from the function. Instances can be constructed by passing an instance of a Python dictionary to `torch.jit.script`. In the case of an empty dictionary, its type is assumed to be `Dict[str, Tensor]` to be consistent with the handling of empty dictionaries in TorchScript source code. `torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if they were regular Python `dict`s. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValues` before calling TorchScript code). **Test Plan** This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests that check that all of the common dictionary operations are supported and that instances have reference semantics across the Python/TorchScript boundary. Differential Revision: D27211605 D27211605 Test Plan: Imported from OSS Reviewed By: gmagogsfm Pulled By: SplitInfinity fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
202 lines
6.4 KiB
C++
202 lines
6.4 KiB
C++
#include <ATen/core/ivalue.h>
|
|
#include <pybind11/cast.h>
|
|
#include <pybind11/detail/common.h>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
#include <torch/csrc/jit/python/python_dict.h>
|
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Yields the entry the iterator currently points at and steps past it.
//
// Returns the current (key, value) pair packed into a tuple IValue.
// Throws py::stop_iteration once the iterator has reached end_, which is how
// Python is told that iteration is finished.
IValue ScriptDictIterator::next() {
  const bool exhausted = (iter_ == end_);
  if (exhausted) {
    throw py::stop_iteration();
  }

  // This iterator backs .items(), so each step produces a (key, value) tuple.
  // The tuple is built before stepping so that it reflects the current entry.
  IValue item = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});

  // Move on so the following call sees the next entry.
  ++iter_;

  return item;
}
|
|
|
|
// Yields the key the iterator currently points at and steps past it.
//
// Returns the current key as an IValue.
// Throws py::stop_iteration once the iterator has reached end_, which is how
// Python is told that iteration is finished.
IValue ScriptDictKeyIterator::next() {
  const bool exhausted = (iter_ == end_);
  if (exhausted) {
    throw py::stop_iteration();
  }

  // This iterator backs .keys() and __iter__(), so only the key is produced.
  // It is captured before stepping so that it reflects the current entry.
  IValue key = iter_->key();

  // Move on so the following call sees the next entry.
  ++iter_;

  return key;
}
|
|
|
|
// Registers the Python-facing classes "ScriptDictKeyIterator",
// "ScriptDictIterator" and "ScriptDict" on the given module.
//
// `module` is the raw Python module object the classes are attached to.
//
// The ScriptDict binding exposes the usual mapping protocol (__repr__,
// __bool__, __len__, __contains__, __getitem__, __setitem__, __delitem__,
// __iter__, items, keys). Python keys and values are converted to IValues
// using the key/value types reported by the dictionary's own type; conversion
// failures (py::cast_error) are translated into the Python exception noted on
// each method below.
void initScriptDictBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
      .def(
          "__next__",
          [](ScriptDictKeyIterator& iter) {
            auto result = iter.next();
            return toPyObject(result);
          })
      // The iterator protocol expects __iter__ on an iterator to return the
      // iterator itself. Returning by reference (rather than by value) avoids
      // creating a copy: a copy would be a separate Python object that
      // advances independently and that is not covered by the keep_alive
      // relationship set up when the iterator was obtained from a ScriptDict.
      .def(
          "__iter__",
          [](ScriptDictKeyIterator& iter) -> ScriptDictKeyIterator& {
            return iter;
          });

  py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
      .def(
          "__next__",
          [](ScriptDictIterator& iter) {
            auto result = iter.next();
            return toPyObject(result);
          })
      // Same reasoning as for ScriptDictKeyIterator: hand back the iterator
      // itself instead of a detached copy.
      .def(
          "__iter__",
          [](ScriptDictIterator& iter) -> ScriptDictIterator& { return iter; });

  py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
      // Construct a ScriptDict from a Python dict. Raises JITException if the
      // type of a nonempty dictionary cannot be inferred.
      .def(py::init([](py::dict dict) {
        TypePtr type = nullptr;

        if (dict.size() > 0) {
          // If the source dictionary is nonempty, try to infer its type.
          auto inferred_type = tryToInferType(dict);

          if (!inferred_type.success()) {
            std::stringstream ss;
            ss << "Unable to infer type of dictionary: "
               << inferred_type.reason();
            throw JITException(ss.str());
          }

          type = inferred_type.type();
        } else {
          // If is empty, assume the type is Dict[str, Tensor] as is done in
          // TorchScript code.
          type = DictType::create(StringType::get(), TensorType::getInferred());
        }

        auto data = toIValue(std::move(dict), type);
        return std::make_shared<ScriptDict>(data);
      }))
      .def(
          "__repr__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->repr());
          })
      .def(
          "__bool__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->toBool());
          })
      .def(
          "__len__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->len());
          })
      // Membership test. A key that cannot be converted to the dictionary's
      // key type is reported as KeyError.
      .def(
          "__contains__",
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            try {
              return toPyObject(self->contains(
                  toIValue(std::move(key), self->type()->getKeyType())));
            } catch (const py::cast_error&) {
              throw py::key_error();
            }
          })
      // Lookup. Raises KeyError both when the key cannot be converted to the
      // dictionary's key type and when the key is not present.
      .def(
          "__getitem__",
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            IValue value;

            // Convert key to IValue.
            try {
              value = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              // It would be nice to throw py::type_error here but py::key_error
              // needs to be thrown for parity with eager mode.
              throw py::key_error();
            }

            // Call getItem on self. `value` holds the converted key going in
            // and is overwritten with the looked-up value.
            try {
              value = self->getItem(value);
            } catch (const std::out_of_range&) { // Key doesn't exist.
              throw py::key_error();
            }

            return toPyObject(std::move(value));
          },
          py::return_value_policy::
              reference_internal) // Return value is a reference to an object
                                  // that resides in the ScriptDict
      // Assignment. Raises TypeError if either the key or the value cannot be
      // converted to the dictionary's key/value type.
      .def(
          "__setitem__",
          [](const std::shared_ptr<ScriptDict>& self,
             py::object key,
             py::object value) {
            IValue key_ivalue, value_ivalue;

            // Try to convert the key to an IValue.
            try {
              key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            // Try to convert the value to an IValue.
            try {
              value_ivalue =
                  toIValue(std::move(value), self->type()->getValueType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            self->setItem(key_ivalue, value_ivalue);
          })
      // Deletion. Raises TypeError if the key cannot be converted to the
      // dictionary's key type and KeyError if the key is not present.
      .def(
          "__delitem__",
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            IValue key_ivalue;

            // Try to convert the key to an IValue.
            try {
              key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            // If removed = false, that means the key didn't exist in the
            // dictionary.
            bool removed = self->delItem(key_ivalue);

            if (!removed) {
              throw py::key_error();
            }
          })
      .def(
          "__iter__",
          [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
          py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
                                  // long as the iterator
      .def(
          "items",
          [](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
          py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
                                  // long as the iterator
      .def(
          "keys",
          [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
          py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
                                   // long as the iterator
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|