mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25721 Context: I am starting to work on removing the BUILD_NAMEDTENSOR flag. Here is the approach: - Move the macro out of header areas - Include a new `enable_namedtensor.h` header that does a `#ifndef BUILD_NAMEDTENSOR #define BUILD_NAMEDTENSOR`. - Include `enable_namedtensor.h` where necessary. This only really needs to happen in two files (c10/TensorImpl.h, ATen/Dimname.h). - Incrementally delete usages of the BUILD_NAMEDTENSOR macro later. The alternative is to straight up delete all instances of BUILD_NAMEDTENSOR. This alternative could be disruptive, lead to merge conflicts, and isn't incremental. Along with the above, some work needs to be done on feature flagging named tensors, merging the namedtensor CI with the regular CI, and communicating with devs. This work will also be done incrementally. Test Plan - [namedtensor ci] Test Plan: Imported from OSS Differential Revision: D17210913 Pulled By: zou3519 fbshipit-source-id: c73f128b976bb90212639e8f2a3ad2a6a52b8e0c
103 lines
3.2 KiB
C++
103 lines
3.2 KiB
C++
#include <torch/csrc/python_dimname.h>
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/utils/python_strings.h>
|
|
#include <c10/util/flat_hash_map.h>
|
|
|
|
#ifdef BUILD_NAMEDTENSOR
|
|
namespace torch {
|
|
|
|
struct InternedStringsTable {
|
|
InternedStringsTable() = default;
|
|
~InternedStringsTable();
|
|
InternedStringsTable(const InternedStringsTable &) = delete;
|
|
InternedStringsTable& operator =(InternedStringsTable const&) = delete;
|
|
InternedStringsTable(InternedStringsTable&&) = delete;
|
|
InternedStringsTable& operator=(InternedStringsTable&&) = delete;
|
|
|
|
at::optional<at::Dimname> lookup(PyObject* obj);
|
|
// Precondition: obj is an interned python string.
|
|
void addMapping(PyObject* obj, at::Dimname dimname);
|
|
private:
|
|
ska::flat_hash_map<PyObject*,at::Dimname> py_interned_string_to_dimname_;
|
|
};
|
|
|
|
InternedStringsTable kPyInternedStringToDimname;
|
|
|
|
InternedStringsTable::~InternedStringsTable() {
|
|
for (auto it = py_interned_string_to_dimname_.begin();
|
|
it != py_interned_string_to_dimname_.end(); ++it) {
|
|
// See Note [References to python interned strings]
|
|
Py_DECREF(it->first);
|
|
}
|
|
}
|
|
|
|
at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
|
|
auto it = py_interned_string_to_dimname_.find(obj);
|
|
if (it == py_interned_string_to_dimname_.end()) {
|
|
return at::nullopt;
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
|
|
void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
|
|
// Note [References to python interned strings]
|
|
// If a Python interned string has no references to it, then it gets
|
|
// deallocated, invalidating this mapping. Let's immortalize the string by
|
|
// holding a refcount to it and releasing it in the destructor
|
|
Py_INCREF(obj);
|
|
py_interned_string_to_dimname_.emplace(obj, dimname);
|
|
}
|
|
|
|
} // namespace torch
|
|
|
|
// Returns true if obj can be parsed as a single Dimname: either None (which
// denotes the wildcard name) or a Python string.
bool THPUtils_checkDimname(PyObject* obj) {
  if (obj == Py_None) {
    return true;
  }
  return THPUtils_checkString(obj);
}
|
|
|
|
// To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if
|
|
// it is a list or tuple and its first elt is a Dimname
|
|
bool THPUtils_checkDimnameList(PyObject* obj) {
|
|
auto tuple = PyTuple_Check(obj);
|
|
if (!tuple && !PyList_Check(obj)) {
|
|
return false;
|
|
}
|
|
auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
|
|
if (size == 0) {
|
|
return true;
|
|
}
|
|
PyObject* first_elt = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
|
|
return THPUtils_checkDimname(first_elt);
|
|
}
|
|
|
|
// Converts a Python object into an at::Dimname.
//
// None maps to the wildcard Dimname. Any other non-string raises
// torch::TypeError. A string is interned first so that equal names share one
// PyObject*, which lets torch::kPyInternedStringToDimname cache results keyed
// on pointer identity.
at::Dimname THPDimname_parse(PyObject* obj) {
  if (obj == Py_None) {
    return at::Dimname::wildcard();
  }
  if (!THPUtils_checkString(obj)) {
    throw torch::TypeError(
        "expected None or string for Dimname but got %s",
        Py_TYPE(obj)->tp_name);
  }

  if (!THPUtils_isInterned(obj)) {
    // internStringInPlace decrefs the object passed in and returns a new
    // reference to the interned string, which may be a different object. We
    // only want that pointer, not ownership of it, so balance the refcounts
    // around the call.
    // See https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
    Py_INCREF(obj);
    THPUtils_internStringInPlace(&obj);
    Py_DECREF(obj);
  }

  auto& cache = torch::kPyInternedStringToDimname;
  if (auto cached = cache.lookup(obj)) {
    return *cached;
  }

  // Cache miss: build the Dimname from the string's text and remember it.
  at::Dimname fresh =
      at::Dimname::fromSymbol(at::Symbol::dimname(THPUtils_unpackString(obj)));
  cache.addMapping(obj, fresh);
  return fresh;
}
|
|
|
|
#endif
|