pytorch/torch/csrc/jit/python/python_dict.h
Meghan Lele b14c3205fd [JIT] Add torch._C.ScriptDict (#52659)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52659

**Summary**
This commit adds `torch._C.ScriptDict`, a dictionary type that has reference
semantics across the Python/TorchScript boundary. That is, modifications
made to instances of `torch._C.ScriptDict` in TorchScript are visible in
Python even when the dictionary is not returned from the function. Instances can be
constructed by passing an instance of a Python dictionary to
`torch.jit.script`. In the case of an empty dictionary, its type is
assumed to be `Dict[str, Tensor]` to be consistent with the handling of
empty dictionaries in TorchScript source code.

`torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, which are wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if they were regular Python `dict`s. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValue`s before calling TorchScript code).

**Test Plan**
This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests
that check that all of the common dictionary operations are supported
and that instances have reference semantics across the
Python/TorchScript boundary.

Differential Revision:
D27211605

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Pulled By: SplitInfinity

fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
2021-05-27 10:25:30 -07:00

129 lines
3.3 KiB
C++

#pragma once
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>
#include <string>
#include <utility>
namespace torch {
namespace jit {
/// Only declared in this header; the definition lives elsewhere.
/// NOTE(review): presumably registers the Python bindings for ScriptDict,
/// ScriptDictIterator and ScriptDictKeyIterator on `module` — confirm against
/// the definition.
void initScriptDictBindings(PyObject* module);
/// An iterator over the keys of ScriptDict. This is used to support
/// .keys() and iteration.
class ScriptDictKeyIterator final {
public:
ScriptDictKeyIterator(
c10::impl::GenericDict::iterator iter,
c10::impl::GenericDict::iterator end)
: iter_(std::move(iter)), end_(std::move(end)) {}
IValue next();
private:
c10::impl::GenericDict::iterator iter_;
c10::impl::GenericDict::iterator end_;
};
/// An iterator over the key-value pairs of ScriptDict. This is used to support
/// .items().
class ScriptDictIterator final {
public:
ScriptDictIterator(
c10::impl::GenericDict::iterator iter,
c10::impl::GenericDict::iterator end)
: iter_(std::move(iter)), end_(std::move(end)) {}
IValue next();
private:
c10::impl::GenericDict::iterator iter_;
c10::impl::GenericDict::iterator end_;
};
/// A wrapper around c10::Dict that can be exposed in Python via pybind
/// with an API identical to the Python dictionary class. This allows
/// dictionaries to have reference semantics across the Python/TorchScript
/// boundary: the wrapper keeps the c10::impl::GenericDict handle extracted
/// from the IValue it was constructed from, rather than a deep copy of it.
class ScriptDict final {
 public:
  // Construct from an IValue that must hold a generic dict. This is checked
  // with TORCH_INTERNAL_ASSERT. The dict handle is taken directly in the
  // member-init list, so no placeholder Any->Any dict is allocated first.
  ScriptDict(IValue data) : dict_(toCheckedGenericDict(std::move(data))) {}

  // Get the type of the dictionary, built from its key and value types.
  DictTypePtr type() const {
    return DictType::create(dict_.keyType(), dict_.valueType());
  }

  // Return a string of the form "{k1: v1, k2: v2}" in iteration order.
  // Keys and values are formatted with IValue's operator<<. The result is
  // not guaranteed to be parseable back into an equivalent dictionary.
  std::string repr() const {
    std::ostringstream s;
    s << '{';
    bool first = true;
    for (const auto& kv : dict_) {
      if (!first) {
        s << ", ";
      }
      s << kv.key() << ": " << kv.value();
      first = false;
    }
    s << '}';
    return s.str();
  }

  // Return an iterator over the keys of the dictionary.
  ScriptDictKeyIterator iter() const {
    return ScriptDictKeyIterator(dict_.begin(), dict_.end());
  }

  // Return an iterator over the key-value pairs of the dictionary.
  ScriptDictIterator items() const {
    return ScriptDictIterator(dict_.begin(), dict_.end());
  }

  // Interpret the dictionary as a boolean: empty means false, non-empty
  // means true (mirrors Python's truthiness for dict).
  bool toBool() const {
    return !dict_.empty();
  }

  // Get the value for the given key. Throws std::out_of_range if the key does
  // not exist.
  IValue getItem(const IValue& key) const {
    return dict_.at(key);
  }

  // Set the value for the given key, inserting it if absent and overwriting
  // it otherwise.
  void setItem(const IValue& key, const IValue& value) {
    dict_.insert_or_assign(key, value);
  }

  // Check whether the dictionary contains the given key.
  bool contains(const IValue& key) const {
    return dict_.contains(key);
  }

  // Delete the given key from the dictionary. Returns true if an entry was
  // removed, false if the key was not present.
  bool delItem(const IValue& key) {
    return dict_.erase(key) != 0;
  }

  // Get the number of entries in the dictionary.
  int64_t len() const {
    return static_cast<int64_t>(dict_.size());
  }

  // A c10::Dict instance that holds the actual data.
  c10::impl::GenericDict dict_;

 private:
  // Assert that `data` holds a generic dict and move its handle out.
  // This is a helper so the check can run inside the member-init list.
  static c10::impl::GenericDict toCheckedGenericDict(IValue data) {
    TORCH_INTERNAL_ASSERT(data.isGenericDict());
    return std::move(data).toGenericDict();
  }
};
} // namespace jit
} // namespace torch