mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52659 **Summary** This commit adds `torch._C.ScriptDict`, a dictionary type that has reference semantics across the Python/TorchScript boundary. That is, modifications made to instances of `torch._C.ScriptDict` in TorchScript are visible in Python even when the instance is not returned from the function. Instances can be constructed by passing an instance of a Python dictionary to `torch.jit.script`. In the case of an empty dictionary, its type is assumed to be `Dict[str, Tensor]` to be consistent with the handling of empty dictionaries in TorchScript source code. `torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if they were regular Python `dict`s. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValues` before calling TorchScript code). **Test Plan** This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests that check that all of the common dictionary operations are supported and that instances have reference semantics across the Python/TorchScript boundary. Differential Revision: D27211605 D27211605 Test Plan: Imported from OSS Reviewed By: gmagogsfm Pulled By: SplitInfinity fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
129 lines
3.3 KiB
C++
129 lines
3.3 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/Dict.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
/// Declaration only; the definition is not part of this header.
/// NOTE(review): judging by the name, this presumably registers the ScriptDict
/// Python bindings on `module` -- confirm against the definition.
void initScriptDictBindings(PyObject* module);
|
|
|
|
/// An iterator over the keys of ScriptDict. This is used to support
|
|
/// .keys() and iteration.
|
|
class ScriptDictKeyIterator final {
|
|
public:
|
|
ScriptDictKeyIterator(
|
|
c10::impl::GenericDict::iterator iter,
|
|
c10::impl::GenericDict::iterator end)
|
|
: iter_(std::move(iter)), end_(std::move(end)) {}
|
|
IValue next();
|
|
|
|
private:
|
|
c10::impl::GenericDict::iterator iter_;
|
|
c10::impl::GenericDict::iterator end_;
|
|
};
|
|
|
|
/// An iterator over the key-value pairs of ScriptDict. This is used to support
|
|
/// .items().
|
|
class ScriptDictIterator final {
|
|
public:
|
|
ScriptDictIterator(
|
|
c10::impl::GenericDict::iterator iter,
|
|
c10::impl::GenericDict::iterator end)
|
|
: iter_(std::move(iter)), end_(std::move(end)) {}
|
|
IValue next();
|
|
|
|
private:
|
|
c10::impl::GenericDict::iterator iter_;
|
|
c10::impl::GenericDict::iterator end_;
|
|
};
|
|
|
|
/// A wrapper around c10::Dict that can be exposed in Python via pybind
|
|
/// with an API identical to the Python dictionary class. This allows
|
|
/// dictionaries to have reference semantics across the Python/TorchScript
|
|
/// boundary.
|
|
class ScriptDict final {
|
|
public:
|
|
// Constructor.
|
|
ScriptDict(IValue data) : dict_(AnyType::get(), AnyType::get()) {
|
|
TORCH_INTERNAL_ASSERT(data.isGenericDict());
|
|
dict_ = data.toGenericDict();
|
|
}
|
|
|
|
// Get the type of the dictionary.
|
|
DictTypePtr type() const {
|
|
return DictType::create(dict_.keyType(), dict_.valueType());
|
|
}
|
|
|
|
// Return a string representation that can be used
|
|
// to reconstruct the instance.
|
|
std::string repr() const {
|
|
std::ostringstream s;
|
|
s << '{';
|
|
bool f = false;
|
|
for (auto const& kv : dict_) {
|
|
if (f) {
|
|
s << ", ";
|
|
}
|
|
s << kv.key() << ": " << kv.value();
|
|
f = true;
|
|
}
|
|
s << '}';
|
|
return s.str();
|
|
}
|
|
|
|
// Return an iterator over the keys of the dictionary.
|
|
ScriptDictKeyIterator iter() const {
|
|
auto begin = dict_.begin();
|
|
auto end = dict_.end();
|
|
return ScriptDictKeyIterator(begin, end);
|
|
}
|
|
|
|
// Return an iterator over the key-value pairs of the dictionary.
|
|
ScriptDictIterator items() const {
|
|
auto begin = dict_.begin();
|
|
auto end = dict_.end();
|
|
return ScriptDictIterator(begin, end);
|
|
}
|
|
|
|
// Interpret the dictionary as a boolean; empty means false, non-empty means
|
|
// true.
|
|
bool toBool() const {
|
|
return !(dict_.empty());
|
|
}
|
|
|
|
// Get the value for the given key. Throws std::out_of_range if the key does
|
|
// not exist.
|
|
IValue getItem(const IValue& key) {
|
|
return dict_.at(key);
|
|
};
|
|
|
|
// Set the value for the given key.
|
|
void setItem(const IValue& key, const IValue& value) {
|
|
dict_.insert_or_assign(key, value);
|
|
};
|
|
|
|
// Check whether the dictionary contains the given key.
|
|
bool contains(const IValue& key) {
|
|
return dict_.contains(key);
|
|
}
|
|
|
|
// Delete the given key from the dictionary.
|
|
bool delItem(const IValue& key) {
|
|
return dict_.erase(key);
|
|
}
|
|
|
|
// Get the size of the dictionary.
|
|
int64_t len() const {
|
|
return dict_.size();
|
|
}
|
|
|
|
// A c10::Dict instance that holds the actual data.
|
|
c10::impl::GenericDict dict_;
|
|
};
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|