mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186 it seems that the existing protections were insufficient. In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. However, today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being twice on the stack. This PR does a number of things: * It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error * It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from readding the mode on the stack * It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode) * `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway * I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week. Signed-off-by: Edward Z. 
Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091 Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
1670 lines
53 KiB
C++
1670 lines
53 KiB
C++
#include <torch/csrc/utils/python_arg_parser.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/ATen.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>

#include <c10/util/irange.h>

#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
|
namespace torch {
|
|
|
|
// Maps the type names that appear in declaration strings (the schema
// syntax from native_functions.yaml, e.g. "Tensor", "IntArrayRef") to the
// ParameterType tags that FunctionParameter::check dispatches on. Note
// that several spellings may map to the same tag (e.g. "std::string" and
// "c10::string_view" both parse as STRING).
static std::unordered_map<std::string, ParameterType> type_map = {
    {"Tensor", ParameterType::TENSOR},
    {"Scalar", ParameterType::SCALAR},
    {"int64_t", ParameterType::INT64},
    {"SymInt", ParameterType::SYM_INT},
    {"double", ParameterType::DOUBLE},
    {"complex", ParameterType::COMPLEX},
    {"TensorList", ParameterType::TENSOR_LIST},
    {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
    {"IntArrayRef", ParameterType::INT_LIST},
    {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
    {"ArrayRef<double>", ParameterType::FLOAT_LIST},
    {"Generator", ParameterType::GENERATOR},
    {"bool", ParameterType::BOOL},
    {"Storage", ParameterType::STORAGE},
    {"PyObject*", ParameterType::PYOBJECT},
    {"ScalarType", ParameterType::SCALARTYPE},
    {"Layout", ParameterType::LAYOUT},
    {"MemoryFormat", ParameterType::MEMORY_FORMAT},
    {"QScheme", ParameterType::QSCHEME},
    {"Device", ParameterType::DEVICE},
    {"Stream", ParameterType::STREAM},
    {"std::string", ParameterType::STRING},
    {"c10::string_view", ParameterType::STRING},
    {"Dimname", ParameterType::DIMNAME},
    {"DimnameList", ParameterType::DIMNAME_LIST},
    {"ScalarList", ParameterType::SCALAR_LIST},
};
|
|
|
|
// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
//
// Consumed by FunctionParameter's constructor below, which interns every
// alias listed for a parameter's canonical name.
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names = {
        {"dim", {"axis"}},
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
        {"other", {"x2"}},
};
|
|
|
|
// TODO: remove this. This is a temporary list of functions that allow Python
|
|
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
|
|
// overloads and binding to the Tensor overload with a number of a different
|
|
// type will trigger a type error.
|
|
//
|
|
// If you modify this, you will need to adjust the blocklist in
|
|
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
|
|
// functions.)
|
|
// Whether `name` is one of the operators whose Tensor overload may also
// bind plain Python numbers (see the TODO comment above for why this
// allowlist exists and what must be kept in sync in gen_pyi.py).
bool should_allow_numbers_as_tensors(const std::string& name) {
  static const std::unordered_set<std::string> numbers_as_tensors_allowlist = {
      "add",         "add_",         "add_out",
      "div",         "div_",         "div_out",
      "divide",      "divide_",      "divide_out", // alias of div
      "mul",         "mul_",         "mul_out",
      "multiply",    "multiply_",    "multiply_out", // alias of mul
      "sub",         "sub_",         "sub_out",
      "subtract",    "subtract_",    "subtract_out", // alias of sub
      "true_divide", "true_divide_", "true_divide_out",
      "to",          "_to_copy",     "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out"};
  return numbers_as_tensors_allowlist.count(name) != 0;
}
|
|
|
|
// Parses a single parameter declaration of the form
//   "Type[size]? name=default"
// where only "Type name" is mandatory:
//   - a trailing '?' on the type marks the parameter as accepting None,
//   - "[size]" declares a fixed-size list (e.g. "IntArrayRef[2]"),
//   - "=default" makes the parameter optional; the textual default is
//     parsed by set_default_str.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
    : optional(false),
      allow_none(false),
      keyword_only(keyword_only),
      size(0),
      default_scalar(0) {
  // The space separates the type spec from the name spec.
  auto space = fmt.find(' ');
  if (space == std::string::npos) {
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
  }

  auto type_str = fmt.substr(0, space);

  // '?' suffix: this parameter may be passed None.
  auto question = type_str.find('?');
  if (question != std::string::npos) {
    allow_none = true;
    type_str = type_str.substr(0, question);
  }

  // Parse and remove brackets from type_str
  auto bracket = type_str.find('[');
  if (bracket != std::string::npos) {
    auto size_str =
        type_str.substr(bracket + 1, type_str.length() - bracket - 2);
    size = atoi(size_str.c_str());
    type_str = type_str.substr(0, bracket);
  }

  auto name_str = fmt.substr(space + 1);
  auto it = type_map.find(type_str);
  if (it == type_map.end()) {
    throw std::runtime_error(
        "FunctionParameter(): invalid type string: " + type_str);
  }
  type_ = it->second;

  // "name=default": a default value makes the parameter optional.
  auto eq = name_str.find('=');
  if (eq != std::string::npos) {
    name = name_str.substr(0, eq);
    optional = true;
    set_default_str(name_str.substr(eq + 1));
  } else {
    name = name_str;
  }
  python_name = THPUtils_internString(name);
  // Also accept NumPy-style spellings for this argument name (see
  // numpy_compatibility_arg_names above).
  auto np_compat_it = numpy_compatibility_arg_names.find(name);
  if (np_compat_it != numpy_compatibility_arg_names.end()) {
    for (const auto& str : np_compat_it->second) {
      numpy_python_names.push_back(THPUtils_internString(str));
    }
  }
}
|
|
|
|
// Dispatches a Tensor property read through __torch_function__ by invoking
// the property descriptor's __get__. `module_name` is only used to build
// error messages (e.g. "torch.Tensor.grad").
auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject* {
  py::object property_obj = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  const std::string qualified_name = "torch.Tensor." + property_name;
  return handle_torch_function(
      (PyObject*)self,
      "__get__",
      nullptr,
      nullptr,
      property_obj.ptr(),
      qualified_name);
}
|
|
|
|
// Dispatches a Tensor property write (or deletion, when `value` is null)
// through __torch_function__ via the property descriptor's __set__ /
// __delete__. Always reports success (0); user errors surface as thrown
// Python exceptions from handle_torch_function.
auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int {
  py::object property_obj = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  const std::string qualified_name = "torch.Tensor." + property_name;
  if (value == nullptr) {
    // Deletion: __delete__ takes no extra positional arguments.
    handle_torch_function(
        (PyObject*)self,
        "__delete__",
        nullptr,
        nullptr,
        property_obj.ptr(),
        qualified_name);
  } else {
    py::tuple set_args = py::make_tuple(py::handle(value));
    handle_torch_function(
        (PyObject*)self,
        "__set__",
        set_args.ptr(),
        nullptr,
        property_obj.ptr(),
        qualified_name);
  }
  return 0;
}
|
|
|
|
// Combines self and args into one tuple.
|
|
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
|
|
if (args == nullptr) {
|
|
return py::make_tuple(py::handle(self));
|
|
} else if (self == nullptr) {
|
|
return py::reinterpret_borrow<py::tuple>(args);
|
|
}
|
|
|
|
auto py_args = py::reinterpret_borrow<py::tuple>(args);
|
|
size_t n = py_args.size();
|
|
auto args_ = py::tuple(n + 1);
|
|
args_[0] = py::handle(self);
|
|
for (const auto i : c10::irange(n)) {
|
|
args_[i + 1] = py_args[i];
|
|
}
|
|
return args_;
|
|
}
|
|
|
|
// TODO: I'm not sure if I should call this __torch_function__ or
// torch_function. The former makes it easier to take an existing
// Tensor-like __torch_function__ object and turn it into a mode;
// but in general modes don't have to be Tensor-like (and we will
// improperly accept mode objects as arguments when they shouldn't
// be passed around in this way).
//
// Name of the method looked up on torch-function mode objects.
const char* torch_function_mode_name = "__torch_function__";
|
|
|
// Dispatches a call on `self` through __torch_function__. The reference
// implementation is looked up as `torch_api.<func_name>`; `self` is
// prepended to `args` as the first positional argument. `module_name` is
// used only for error messages. Returns a new reference to the result
// (or throws if every handler returned NotImplemented).
auto handle_torch_function(
    PyObject* self,
    const std::string& func_name,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const std::string& module_name) -> PyObject* {
  py::object torch_api_function =
      PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
  TORCH_INTERNAL_ASSERT(
      torch_api_function.ptr() != nullptr, "torch API function must exist");
  py::tuple args_ = combine_self_args(self, args);
  // `self` is the only overloaded arg here; the generic dispatcher does
  // the actual mode/subclass resolution.
  return handle_torch_function_no_python_arg_parser(
      {py::handle(self)},
      args_.ptr(),
      kwargs,
      func_name.c_str(),
      torch_api_function.ptr(),
      module_name.c_str(),
      TorchFunctionName::TorchFunction);
}
|
|
|
|
// Note: [Overloaded args]
|
|
// An overloaded arg may be one of the following:
|
|
// - an instance of an object that has a __torch_function__ method
|
|
// - an instance of an object that has a __torch_dispatch__ classmethod
|
|
// - a class type that has a __torch_dispatch__ classmethod
|
|
//
|
|
// This function returns the type of the arg (if the arg is an instance),
|
|
// otherwise, it returns the arg.
|
|
// Returns `obj_or_type` itself when it is already a class (the
// __torch_dispatch__-on-a-type case from Note [Overloaded args]),
// otherwise the instance's type. Returns a borrowed reference.
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
  return PyType_Check(obj_or_type) ? obj_or_type
                                   : (PyObject*)Py_TYPE(obj_or_type);
}
|
|
|
|
// See Note: [Overloaded args] for what they hold
|
|
// Core __torch_function__ / __torch_dispatch__ dispatcher.
//
// Resolution order:
//   1. If a mode is active, call the innermost mode's handler with the
//      mode stashed (disabled) for the duration of the call.
//   2. If no mode ran, or the mode returned NotImplemented, try each
//      overloaded arg's handler in precedence order until one returns
//      something other than NotImplemented.
//   3. If everything returned NotImplemented, raise a TypeError listing
//      every handler that was consulted.
// Returns a new reference; throws python_error if a handler raised.
auto handle_torch_function_no_python_arg_parser(
    at::ArrayRef<py::handle> overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name) -> PyObject* {
  const char* torch_function_name_str = nullptr;
  switch (torch_function_name) {
    case TorchFunctionName::TorchFunction:
      torch_function_name_str = "__torch_function__";
      break;
    case TorchFunctionName::TorchDispatch:
      torch_function_name_str = "__torch_dispatch__";
      break;
    default:
      TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
  }
  // overloaded_args already all have unique types
  // nb: modes don't go in the overloaded types list, as they are not
  // necessarily types
  std::vector<py::object> overloaded_types;
  overloaded_types.reserve(overloaded_args.size());
  for (auto& arg : overloaded_args) {
    overloaded_types.push_back(py::reinterpret_borrow<py::object>(
        get_type_of_overloaded_arg(arg.ptr())));
  }
  py::tuple py_types = py::cast(overloaded_types);
  py::object ret;
  py::object mode_obj;

  const bool is_torch_function =
      torch_function_name == TorchFunctionName::TorchFunction;
  const auto is_mode_active = [&]() {
    return is_torch_function ? at::impl::torch_function_mode_enabled()
                             : c10::impl::dispatch_mode_enabled();
  };

  if (is_mode_active()) {
    // Disable mode on the inside; this makes for a more user-friendly
    // experience if you try to, e.g., print your tensors.
    at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
    at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
    // NB: We only really need keep the mode_obj live if the function call
    // fails for error reporting, but whatever, Python refcounts are cheap
    if (is_torch_function) {
      tf_g.emplace();
      mode_obj = py::reinterpret_borrow<py::object>(
          tf_g->get_cur_mode()->ptr(getPyInterpreter()));
    } else {
      td_g.emplace();
      mode_obj = py::reinterpret_borrow<py::object>(
          td_g->get_cur_mode()->ptr(getPyInterpreter()));
    }
    py::object torch_function =
        PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
    if (!torch_function) {
      TORCH_INTERNAL_ASSERT(0);
    }
    TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
    TORCH_INTERNAL_ASSERT(args != nullptr);

    TORCH_CHECK(
        PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
            .is(mode_obj),
        "Defining your mode's `",
        torch_function_name_str,
        "` as a classmethod is not supported, please make it a plain method");

    // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
    // because the nullptr terminates the argument list ick ick ick.
    if (kwargs == nullptr) {
      ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
          mode_obj.ptr(),
          torch_function_name_str,
          "OOO",
          torch_api_function,
          py_types.ptr(),
          args));
    } else {
      ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
          mode_obj.ptr(),
          torch_function_name_str,
          "OOOO",
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs));
    }
    if (ret.ptr() == nullptr) {
      throw python_error();
    }
  }
  // Either no mode was active, or the mode declined with NotImplemented:
  // fall back to per-argument handlers, in precedence order.
  if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
    for (auto& arg : overloaded_args) {
      py::object torch_function =
          PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str);
      if (!torch_function) {
        TORCH_INTERNAL_ASSERT(0);
      }

      // See https://github.com/pytorch/pytorch/issues/63767
      if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
              .is(arg) &&
          torch_function.ptr() != torch::disabled_torch_function_impl()) {
        TORCH_WARN(
            "Defining your `",
            torch_function_name_str,
            "` as a plain method is deprecated ",
            "and will be an error in future, please define it as a classmethod.");
      }

      ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
          torch_function.ptr(),
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs,
          NULL));
      if (ret.ptr() != Py_NotImplemented) {
        // Return the reference to the result. This also covers the case where
        // ret is NULL and __torch_function__/__torch_dispatch raised an
        // exception, which we throw below
        break;
      }
    }
  }
  if (ret.ptr() == nullptr) {
    // if an exception occurred in a user's implementation of
    // __torch_function__, throw it
    throw python_error();
  } else if (ret.ptr() == Py_NotImplemented) {
    // all __torch_function__ implementations in overloaded_args
    // returned NotImplemented, so we raise a TypeError.
    std::stringstream ss;
    ss << "Multiple dispatch failed for '";
    if (module_name && func_name) {
      ss << module_name << "." << func_name;
    } else {
      py::handle fn = torch_api_function;
      ss << py::str(fn.attr("__module__")) << "."
         << py::str(fn.attr("__name__"));
    }
    ss << "'; all " << torch_function_name_str
       << " handlers returned NotImplemented:\n\n";
    if (mode_obj) {
      ss << " - mode object " << py::repr(mode_obj) << "\n";
    }
    for (auto& arg : overloaded_args) {
      ss << " - tensor subclass "
         << py::repr(get_type_of_overloaded_arg(arg.ptr())) << "\n";
    }
    ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
    const std::string& tmp = ss.str();
    PyErr_SetString(PyExc_TypeError, tmp.c_str());
    throw python_error();
  }
  return ret.release().ptr();
}
|
|
|
|
// Dispatches through __torch_function__ using the overloaded args already
// collected by the python arg parser (`r.signature.overloaded_args`).
// The reference implementation is looked up as `torch_api.<name>`, where
// `name` is `func_name_override` if given, else the parsed function name.
// `self`, when non-null, is prepended to `args` as the first positional.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  py::object torch_api_function = PyObject_FastGetAttrString(
      torch_api,
      (char*)(func_name_override ? func_name_override : r.get_func_name().c_str()));
  TORCH_INTERNAL_ASSERT(
      torch_api_function.ptr() != nullptr, "torch API function must exist");
  py::tuple args_ = combine_self_args(self, args);
  // NB: the types tuple is built inside the callee from overloaded_args;
  // there is no need to precompute it here.
  return handle_torch_function_no_python_arg_parser(
      r.signature.overloaded_args,
      args_.ptr(),
      kwargs,
      r.get_func_name().c_str(),
      torch_api_function.ptr(),
      module_name);
}
|
|
|
|
// Convenience overload for calls without a distinguished `self` argument;
// simply forwards with self == nullptr.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  return handle_torch_function(
      r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
}
|
|
|
|
// Dispatches Tensor indexing through __torch_function__. `val == nullptr`
// selects __getitem__, otherwise __setitem__. Tensors found in `self`,
// in each element of the index expression, and in `val` all participate
// in overload resolution.
auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val) -> PyObject* {
  const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
  py::object index_tup;
  // Normalize the index to a tuple so its elements can be scanned
  // uniformly for overloaded tensors.
  if (PyTuple_Check(index)) {
    index_tup = py::reinterpret_borrow<py::object>(index);
  } else {
    index_tup = py::make_tuple(py::handle(index));
  }
  std::vector<py::handle> overridable_args;
  is_tensor_and_append_overloaded(self, &overridable_args);
  auto size = PyTuple_GET_SIZE(index_tup.ptr());
  for (auto i : c10::irange(size)) {
    auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
    is_tensor_and_append_overloaded(obj, &overridable_args);
  }
  if (val != nullptr) {
    is_tensor_and_append_overloaded(val, &overridable_args);
  }
  py::object func =
      PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
  // NB: the original (possibly non-tuple) index is forwarded to the
  // handler, not the normalized index_tup.
  py::object args = (val == nullptr)
      ? py::make_tuple(py::handle(self), py::handle(index))
      : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
  return handle_torch_function_no_python_arg_parser(
      overridable_args,
      args.ptr(),
      nullptr,
      func_name,
      func.ptr(),
      "torch.Tensor");
}
|
|
|
|
/*
|
|
* obj has a __torch_function__ implementation and may either be a
|
|
* subclass of Tensor or a Tensor-like duck type. We may need to
|
|
* append this object to the overloaded_args vector, which tracks all
|
|
* of the arguments with distinct __torch_function__ implementations
|
|
* we've seen so far.
|
|
*
|
|
* If this is the first argument we've seen with __torch_function__
|
|
* defined, we unconditionally add obj to the overloaded_args vector.
|
|
*
|
|
* If we've already seen arguments with __torch_function__ defined,
|
|
* then we first need to check if obj is the same type as any of the
|
|
* entries in overloaded_args. If so, we can ignore obj since we
|
|
* already have an entry in overloaded_args with the same
|
|
* __torch_function__ implementation.
|
|
*
|
|
* If it's a different type, we then need to check if it's a subclass
|
|
* of one of the types we've already seen. If so, we need to insert an
|
|
* entry in overloaded_args for this type with higher precedence than
|
|
* the superclass.
|
|
*
|
|
* See torch._overrides._get_overloaded_types_and_args for the equivalent
|
|
* function in the Python __torch_function__ implementation.
|
|
*
|
|
* The precedence-determining algorithm implemented in this function is
|
|
* described in NEP-0018:
|
|
* https://numpy.org/neps/nep-0018-array-function-protocol.html
|
|
*
|
|
* 'overloaded_args' is a raw pointer to a vector of pybind11 handles
|
|
* that have distinct __torch_function__ implementations, in order of calling
|
|
* precedence.
|
|
*
|
|
* 'obj' is an object to check for a __torch_function__ implementation
|
|
*
|
|
* If changing this file in a way that can affect the __torch_function__
|
|
* overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
|
|
* See the instructions in the 'README.md' in that directory.
|
|
*
|
|
*/
|
|
|
|
static void append_overloaded_arg(
|
|
std::vector<py::handle>* overloaded_args,
|
|
PyObject* obj,
|
|
bool obj_is_type) {
|
|
bool class_not_seen_yet = true;
|
|
PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
|
|
for (auto& arg : *overloaded_args) {
|
|
if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
|
|
// obj is the same type as another parameter we've seen in a prior
|
|
// iteration of the loop over parameters so we already have an entry
|
|
// with the proper __torch_function__ implementation to call, so skip
|
|
// this parameter
|
|
class_not_seen_yet = false;
|
|
break;
|
|
}
|
|
}
|
|
if (class_not_seen_yet) {
|
|
auto arg_index = overloaded_args->size();
|
|
for (const auto j : c10::irange(arg_index)) {
|
|
if (PyObject_IsSubclass(
|
|
obj_type,
|
|
(PyObject*)(get_type_of_overloaded_arg(
|
|
(*overloaded_args)[j].ptr())))) {
|
|
// obj is a subclass of another object we've seen already so its
|
|
// __torch_function__ should be called first, therefore we
|
|
// insert it into overloaded_args before the superclass
|
|
arg_index = j;
|
|
break;
|
|
}
|
|
}
|
|
// add object to overloaded_args. If it's a subclass of another class
|
|
// we've already seen it will be inserted before the superclass,
|
|
// otherwise it will be inserted at the end of the array
|
|
overloaded_args->insert(
|
|
overloaded_args->begin() + static_cast<long>(arg_index), obj);
|
|
}
|
|
}
|
|
|
|
// Records a tensor-like INSTANCE in overloaded_args; precedence rules are
// in append_overloaded_arg.
void append_overloaded_tensor(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
}
|
|
|
|
// Records a CLASS (e.g. one with a __torch_dispatch__ classmethod) in
// overloaded_args; precedence rules are in append_overloaded_arg.
void append_overloaded_type(
    std::vector<py::handle>* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
}
|
|
|
|
// Returns true if `obj` can bind to a Tensor parameter. As a side effect,
// objects carrying their own __torch_function__ (modes are ignored for
// this check) are recorded in overloaded_args for later dispatch.
bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<py::handle>* overloaded_args) {
  if (THPVariable_CheckExact(obj)) {
    // torch.Tensor instances (not subclasses, except for Parameter)
    return true;
  }

  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
    // tensor subclasses and unrelated objects with __torch_function__
    append_overloaded_tensor(overloaded_args, obj);
    return true;
  } else if (THPVariable_Check(obj)) {
    // tensor subclasses without __torch_function__
    return true;
  }

  return false;
}
|
|
|
|
// Whether `obj` is a (possibly empty) tuple or list whose every element
// passes THPUtils_checkScalar.
static bool is_scalar_list(PyObject* obj) {
  const auto tuple = six::isTuple(obj);
  if (!tuple && !PyList_Check(obj)) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (const auto idx : c10::irange(size)) {
    PyObject* element =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!THPUtils_checkScalar(element)) {
      return false;
    }
  }
  return true;
}
|
|
|
|
// Returns true if `obj` is a tuple/list whose every element binds as a
// Tensor (see is_tensor_and_append_overloaded, which also records
// __torch_function__ elements in overloaded_args). On the first
// non-tensor element: throws TypeError if `throw_error`, else returns
// false. `argnum` is only used in the error message.
bool is_tensor_list_and_append_overloaded(
    PyObject* obj,
    std::vector<py::handle>* overloaded_args,
    int argnum,
    bool throw_error) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  // Use c10::irange for consistency with the sibling list checkers
  // (is_scalar_list etc.).
  for (const auto idx : c10::irange(size)) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
      if (throw_error) {
        throw TypeError(
            "expected Tensor as element %d in argument %d, but got %s",
            static_cast<int>(idx),
            argnum,
            Py_TYPE(iobj)->tp_name);
      }
      return false;
    }
  }
  return true;
}
|
|
|
|
// Heuristic check that `obj` is a tuple/list of floats or complex numbers.
// NOTE(review): only the FIRST element is inspected (unlike
// is_scalar_list, which validates every element) — presumably full
// validation happens later during conversion; confirm before relying on
// this for anything beyond overload selection.
static bool is_float_or_complex_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }

  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (size > 0) {
    PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
    if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
      return false;
    }
  }

  return true;
}
|
|
|
|
// Whether `obj` can bind to an IntArrayRef parameter. Accepts a
// tuple/list of index-like ints, a bare int when a fixed size is declared
// (`broadcast_size > 0`, e.g. IntArrayRef[2]), and — only under the JIT
// tracer — sequences of 0-dim tensors. Sequences containing SymInt nodes
// are rejected so they bind to the SymIntArrayRef overload instead;
// `failed_idx`, when non-null, reports the rejecting element's index.
static bool is_int_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
    auto len = PySequence_Size(obj);
    if (len == 0) {
      return true;
    }

    auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
    bool int_first = false;
    if (THPUtils_checkIndex(item.ptr())) {
      // we still have to check that the rest of items are NOT symint nodes
      int_first = true;
    }

    // Make sure none of the later arguments are SymInt
    // NB: do NOT check that the later arguments are ints, as this is
    // BC-breaking for FX
    for (int i = 1; i < len; i++) {
      if (torch::is_symint(
              py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
        if (failed_idx != nullptr) {
          *failed_idx = i;
        }
        return false;
      }
    }

    if (int_first) {
      return true;
    }

    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
    // in an intlist argument. Even float or complex scalar tensors.
    bool r =
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
         THPVariable_Unpack(item.ptr()).sizes().empty());
    if (!r && failed_idx != nullptr) {
      *failed_idx = 0;
    }
    return r;
  }
  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
  // int
  return broadcast_size > 0 && THPUtils_checkLong(obj);
}
|
|
|
|
// Whether `obj` can bind to a SymInt parameter: a SymInt node or anything
// accepted as a Python index.
static bool is_int_or_symint(PyObject* obj) {
  // THPUtils_checkIndex may call __index__ or __int__
  // which may have side effects if obj is a symint node
  // so we do `is_symint` check first
  // TODO: maybe we should be using checkLong here?
  return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
}
|
|
|
|
// Whether `obj` can bind to a SymIntArrayRef parameter: a tuple/list
// whose FIRST element is an int or SymInt (later elements are not
// checked here), a bare int/SymInt when a fixed size is declared
// (`broadcast_size > 0`), or — only under the JIT tracer — a sequence of
// 0-dim tensors. `failed_idx` reports the rejecting element.
static bool is_int_or_symint_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
    if (PySequence_Size(obj) == 0) {
      return true;
    }
    auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));

    if (is_int_or_symint(item.ptr())) {
      return true;
    }
    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
    // in an intlist argument. Even float or complex scalar tensors.
    bool r =
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
         THPVariable_Unpack(item.ptr()).sizes().empty());
    if (!r && failed_idx != nullptr) {
      *failed_idx = 0;
    }
    return r;
  }
  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
  // int
  return broadcast_size > 0 && is_int_or_symint(obj);
}
|
|
|
|
// argnum is needed for raising the TypeError, it's used in the error message.
|
|
auto FunctionParameter::check(
|
|
PyObject* obj,
|
|
std::vector<py::handle>& overloaded_args,
|
|
int argnum,
|
|
int64_t* failed_idx) -> bool {
|
|
switch (type_) {
|
|
case ParameterType::TENSOR: {
|
|
if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
|
|
return true;
|
|
}
|
|
if (allow_numbers_as_tensors) {
|
|
return THPUtils_checkScalar(obj);
|
|
}
|
|
return false;
|
|
}
|
|
case ParameterType::SCALAR:
|
|
if (THPUtils_checkScalar(obj)) {
|
|
return true;
|
|
}
|
|
// fallthrough
|
|
case ParameterType::COMPLEX:
|
|
if (PyComplex_Check(obj)) {
|
|
return true;
|
|
}
|
|
// fallthrough
|
|
case ParameterType::DOUBLE: {
|
|
if (THPUtils_checkDouble(obj)) {
|
|
return true;
|
|
}
|
|
if (THPVariable_Check(obj)) {
|
|
const auto& var = THPVariable_Unpack(obj);
|
|
return !var.requires_grad() && var.dim() == 0;
|
|
}
|
|
if (torch::is_symfloat(py::handle(obj))) {
|
|
// This will induce a guard
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
case ParameterType::INT64: {
|
|
if (THPUtils_checkLong(obj)) {
|
|
return true;
|
|
}
|
|
if (THPVariable_Check(obj)) {
|
|
const auto& var = THPVariable_Unpack(obj);
|
|
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
|
|
!var.requires_grad() && var.dim() == 0;
|
|
}
|
|
if (torch::is_symint(py::handle(obj))) {
|
|
// This will induce a guard
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
case ParameterType::DIMNAME:
|
|
return THPUtils_checkDimname(obj);
|
|
case ParameterType::DIMNAME_LIST: {
|
|
if (THPUtils_checkDimnameList(obj)) {
|
|
return true;
|
|
}
|
|
// if a size is specified (e.g. DimnameList[1]) we also allow passing a
|
|
// single Dimname
|
|
return size == 1 && THPUtils_checkDimname(obj);
|
|
}
|
|
case ParameterType::TENSOR_LIST: {
|
|
return is_tensor_list_and_append_overloaded(
|
|
obj, &overloaded_args, argnum, true /* throw_error */);
|
|
}
|
|
case ParameterType::INT_LIST:
|
|
return is_int_list(obj, size, failed_idx);
|
|
case ParameterType::FLOAT_LIST:
|
|
return is_float_or_complex_list(obj);
|
|
case ParameterType::GENERATOR:
|
|
return THPGenerator_Check(obj);
|
|
case ParameterType::BOOL:
|
|
return PyBool_Check(obj);
|
|
case ParameterType::STORAGE:
|
|
return isStorage(obj);
|
|
case ParameterType::PYOBJECT:
|
|
return true;
|
|
case ParameterType::SCALARTYPE:
|
|
return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
|
|
case ParameterType::LAYOUT:
|
|
return THPLayout_Check(obj);
|
|
case ParameterType::MEMORY_FORMAT:
|
|
return THPMemoryFormat_Check(obj);
|
|
case ParameterType::QSCHEME:
|
|
return THPQScheme_Check(obj);
|
|
case ParameterType::DEVICE:
|
|
return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
|
|
THPDevice_Check(obj);
|
|
case ParameterType::STREAM:
|
|
return THPStream_Check(obj);
|
|
case ParameterType::STRING:
|
|
return THPUtils_checkString(obj);
|
|
case ParameterType::SCALAR_LIST:
|
|
return is_scalar_list(obj);
|
|
case ParameterType::SYM_INT:
|
|
return is_int_or_symint(obj);
|
|
case ParameterType::SYM_INT_LIST:
|
|
return is_int_or_symint_list(obj, size, failed_idx);
|
|
default:
|
|
throw std::runtime_error("unknown parameter type");
|
|
}
|
|
}
|
|
|
|
// WARNING: these strings are parsed invalid_arguments.cpp
//
// Human-readable type name used in argument-mismatch error messages.
// Do not change a string here without updating the parser in
// invalid_arguments.cpp accordingly.
std::string FunctionParameter::type_name() const {
  switch (type_) {
    case ParameterType::TENSOR:
      return "Tensor";
    case ParameterType::SCALAR:
      return "Number";
    case ParameterType::INT64:
      // NB: SymInt is intentionally not mentioned here, as conventional user
      // use will only know about ints
    case ParameterType::SYM_INT:
      return "int";
    case ParameterType::DOUBLE:
      return "float";
    case ParameterType::COMPLEX:
      return "complex";
    case ParameterType::TENSOR_LIST:
      return "tuple of Tensors";
    case ParameterType::INT_LIST:
      return "tuple of ints";
    case ParameterType::FLOAT_LIST:
      return "tuple of floats";
    case ParameterType::GENERATOR:
      return "torch.Generator";
    case ParameterType::BOOL:
      return "bool";
    case ParameterType::STORAGE:
      return "torch.Storage";
    case ParameterType::PYOBJECT:
      return "object";
    case ParameterType::SCALARTYPE:
      return "torch.dtype";
    case ParameterType::LAYOUT:
      return "torch.layout";
    case ParameterType::MEMORY_FORMAT:
      return "torch.memory_format";
    case ParameterType::QSCHEME:
      return "torch.qscheme";
    case ParameterType::DEVICE:
      return "torch.device";
    case ParameterType::STRING:
      return "str";
    case ParameterType::DIMNAME:
      return "name";
    case ParameterType::DIMNAME_LIST:
      return "tuple of names";
    case ParameterType::SCALAR_LIST:
      return "tuple of Scalars";
    case ParameterType::SYM_INT_LIST:
      return "tuple of ints";
    default:
      throw std::runtime_error("unknown parameter type");
  }
}
|
|
|
|
// Try to parse `s` as an integer literal (decimal, 0x-hex, or 0-octal,
// per strtoll base 0). Returns nullopt for an empty string or when any
// characters remain after the number (e.g. "1.5").
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return c10::nullopt;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  char* str_end;
  // strtoll (not strtol): `long` is only 32 bits on LLP64 platforms such as
  // Windows, which would silently truncate large int64_t defaults.
  long long ans = strtoll(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
}
|
|
|
|
/*
  Parse the default value of an IntArrayRef declared in native_functions.yaml.

  There are two kinds of default values:
  1. IntArrayRef[2] x=1 (where size=2, value={1,1})
  2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}; note that there cannot
     be a space after a comma, since native_parse.py uses ', ' to split args)
*/
|
|
static inline std::vector<int64_t> parse_intlist_args(
|
|
const std::string& s,
|
|
int64_t size) {
|
|
size_t n = s.size();
|
|
|
|
if (s.empty())
|
|
return std::vector<int64_t>();
|
|
|
|
// case 1. s is an int (e.g., s=2)
|
|
if (s[0] != '{') {
|
|
TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
|
|
return std::vector<int64_t>(size, std::stol(s));
|
|
}
|
|
|
|
// case 2. s is a list of dims (e.g., s={1,2})
|
|
|
|
// since already checked left brace '{' above, here only checks right brace
|
|
// '}'
|
|
TORCH_CHECK(
|
|
s[n - 1] == '}',
|
|
"Default value of IntArrayRef is missing right brace '}', found ",
|
|
s[n - 1]);
|
|
|
|
auto args = std::vector<int64_t>();
|
|
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
|
|
std::string tok;
|
|
|
|
while (std::getline(ss, tok, ',')) {
|
|
args.emplace_back(std::stol(tok));
|
|
}
|
|
return args;
|
|
}
|
|
|
|
// Parse a string literal to remove quotes and escape sequences
|
|
static std::string parse_string_literal(c10::string_view str) {
|
|
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
|
|
|
|
if (str.front() == '"') {
|
|
TORCH_CHECK(
|
|
str.back() == '"', "Mismatched quotes in string default: ", str);
|
|
} else {
|
|
TORCH_CHECK(
|
|
str.front() == '\'' && str.back() == '\'',
|
|
"Invalid quotes in string default: ",
|
|
str)
|
|
}
|
|
|
|
std::string parsed;
|
|
parsed.reserve(str.size());
|
|
for (size_t i = 1; i < str.size() - 1;) {
|
|
if (str[i] != '\\') {
|
|
parsed.push_back(str[i]);
|
|
++i;
|
|
continue;
|
|
}
|
|
|
|
// Handle escape sequences
|
|
TORCH_CHECK(
|
|
i < str.size() - 2, "String ends with escaped final quote: ", str)
|
|
char c = str[i + 1];
|
|
switch (c) {
|
|
case '\\':
|
|
case '\'':
|
|
case '\"':
|
|
break;
|
|
case 'a':
|
|
c = '\a';
|
|
break;
|
|
case 'b':
|
|
c = '\b';
|
|
break;
|
|
case 'f':
|
|
c = '\f';
|
|
break;
|
|
case 'n':
|
|
c = '\n';
|
|
break;
|
|
case 'v':
|
|
c = '\v';
|
|
break;
|
|
case 't':
|
|
c = '\t';
|
|
break;
|
|
default:
|
|
TORCH_CHECK(
|
|
false,
|
|
"Unsupported escape sequence in string default: \\",
|
|
str[i + 1]);
|
|
}
|
|
parsed.push_back(c);
|
|
i += 2;
|
|
}
|
|
return parsed;
|
|
}
|
|
|
|
void FunctionParameter::set_default_str(const std::string& str) {
|
|
if (str == "None") {
|
|
allow_none = true;
|
|
}
|
|
if (type_ == ParameterType::TENSOR) {
|
|
if (str != "None") {
|
|
throw std::runtime_error(
|
|
"default value for Tensor must be none, got: " + str);
|
|
}
|
|
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
|
|
default_int = atol(str.c_str());
|
|
} else if (type_ == ParameterType::BOOL) {
|
|
default_bool = (str == "True" || str == "true");
|
|
} else if (type_ == ParameterType::DOUBLE) {
|
|
default_double = atof(str.c_str());
|
|
} else if (type_ == ParameterType::COMPLEX) {
|
|
default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
|
|
default_complex[1] = 0;
|
|
} else if (type_ == ParameterType::SCALAR) {
|
|
if (str != "None") {
|
|
// we sometimes rely on integer-vs-float values, e.g. with arange.
|
|
const auto as_integer = parse_as_integer(str);
|
|
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
|
|
: at::Scalar(atof(str.c_str()));
|
|
}
|
|
} else if (
|
|
type_ == ParameterType::INT_LIST ||
|
|
type_ == ParameterType::SYM_INT_LIST) {
|
|
if (str != "None") {
|
|
default_intlist = parse_intlist_args(str, size);
|
|
}
|
|
} else if (type_ == ParameterType::FLOAT_LIST) {
|
|
if (str != "None") {
|
|
throw std::runtime_error("Defaults not supported for float[]");
|
|
}
|
|
} else if (type_ == ParameterType::SCALARTYPE) {
|
|
if (str == "None") {
|
|
default_scalartype = at::ScalarType::Undefined;
|
|
} else if (str == "torch.int64") {
|
|
default_scalartype = at::ScalarType::Long;
|
|
} else {
|
|
throw std::runtime_error("invalid default value for ScalarType: " + str);
|
|
}
|
|
} else if (type_ == ParameterType::LAYOUT) {
|
|
if (str == "None") {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
|
|
} else if (str == "torch.strided") {
|
|
default_layout = at::Layout::Strided;
|
|
} else if (str == "torch.sparse_coo") {
|
|
default_layout = at::Layout::Sparse;
|
|
} else {
|
|
throw std::runtime_error("invalid default value for layout: " + str);
|
|
}
|
|
} else if (type_ == ParameterType::DEVICE) {
|
|
if (str != "None") {
|
|
throw std::runtime_error("invalid device: " + str);
|
|
}
|
|
} else if (type_ == ParameterType::STREAM) {
|
|
if (str != "None") {
|
|
throw std::runtime_error("invalid stream: " + str);
|
|
}
|
|
} else if (type_ == ParameterType::STRING) {
|
|
if (str != "None") {
|
|
default_string = parse_string_literal(str);
|
|
}
|
|
}
|
|
// These types weren't handled here before. Adding a default error
|
|
// led to a lot of test failures so adding this skip for now.
|
|
// We should correctly handle these though because it might be causing
|
|
// silent failures.
|
|
else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
|
|
// throw std::runtime_error("Invalid Tensor List");
|
|
} else if (type_ == ParameterType::GENERATOR) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::GENERATOR");
|
|
} else if (type_ == ParameterType::PYOBJECT) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::PYOBJECT");
|
|
} else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::MEMORY_FORMAT");
|
|
} else if (type_ == ParameterType::DIMNAME) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::DIMNAME");
|
|
} else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::DIMNAME_LIST");
|
|
} else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::SCALAR_LIST");
|
|
} else if (type_ == ParameterType::STORAGE) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::STORAGE");
|
|
} else if (type_ == ParameterType::QSCHEME) { // NOLINT
|
|
// throw std::runtime_error("ParameterType::QSCHEME");
|
|
} else {
|
|
throw std::runtime_error("unknown parameter type");
|
|
}
|
|
}
|
|
|
|
// Parse a signature format string such as "add(Tensor other, *, Scalar alpha)"
// (optionally suffixed with "|deprecated" or "|hidden") into name, parameter
// list, and arg-count bookkeeping. Parameters are split on ", " (note: a
// single space is required); a bare "*" marks the start of keyword-only
// parameters. Throws std::runtime_error on a malformed format string.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
    : min_args(0),
      max_args(0),
      max_pos_args(0),
      index(index),
      hidden(false),
      deprecated(false) {
  auto open_paren = fmt.find('(');
  if (open_paren == std::string::npos) {
    throw std::runtime_error("missing opening parenthesis: " + fmt);
  }
  name = fmt.substr(0, open_paren);

  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);

  auto last_offset = open_paren + 1;
  bool keyword_only = false;
  bool done = false;
  while (!done) {
    auto offset = fmt.find(", ", last_offset);
    // next_offset may wrap when offset == npos, but that branch below
    // overwrites it before use.
    auto next_offset = offset + 2;
    if (offset == std::string::npos) {
      // No more ", " separators: the last parameter ends at ')'.
      offset = fmt.find(')', last_offset);
      done = true;
      next_offset = offset + 1;
      // this 'if' happens for an empty parameter list, i.e. fn().
      if (offset == last_offset) {
        last_offset = next_offset;
        break;
      }
    }
    if (offset == std::string::npos) {
      throw std::runtime_error("missing closing parenthesis: " + fmt);
    }
    if (offset == last_offset) {
      throw std::runtime_error("malformed signature: " + fmt);
    }

    auto param_str = fmt.substr(last_offset, offset - last_offset);
    last_offset = next_offset;
    if (param_str == "*") {
      // Everything after the bare "*" is keyword-only.
      keyword_only = true;
    } else {
      params.emplace_back(param_str, keyword_only);
      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
    }
  }

  // Anything after ')' is a flag suffix controlling overload visibility.
  if (fmt.substr(last_offset) == "|deprecated") {
    hidden = true;
    // TODO: raise warning when parsing deprecated signatures
    deprecated = true;
  } else if (fmt.substr(last_offset) == "|hidden") {
    hidden = true;
  }

  max_args = params.size();

  // count the number of non-optional args
  for (auto& param : params) {
    if (!param.optional) {
      min_args++;
    }
    if (!param.keyword_only) {
      max_pos_args++;
    }
  }
}
|
|
|
|
std::string FunctionSignature::toString() const {
|
|
// TODO: consider printing more proper schema strings with defaults,
|
|
// optionals, etc.
|
|
std::ostringstream ss;
|
|
bool keyword_already = false;
|
|
ss << "(";
|
|
int i = 0;
|
|
for (auto& param : params) {
|
|
if (i != 0) {
|
|
ss << ", ";
|
|
}
|
|
if (param.keyword_only && !keyword_already) {
|
|
ss << "*, ";
|
|
keyword_already = true;
|
|
}
|
|
ss << param.type_name() << " " << param.name;
|
|
i++;
|
|
}
|
|
ss << ")";
|
|
return ss.str();
|
|
}
|
|
|
|
[[noreturn]] static void extra_args(
|
|
const FunctionSignature& signature,
|
|
Py_ssize_t nargs) {
|
|
const auto max_pos_args = signature.max_pos_args;
|
|
const auto min_args = signature.min_args;
|
|
const long nargs_ = nargs;
|
|
if (min_args != max_pos_args) {
|
|
throw TypeError(
|
|
"%s() takes from %zu to %zu positional arguments but %ld were given",
|
|
signature.name.c_str(),
|
|
min_args,
|
|
max_pos_args,
|
|
nargs_);
|
|
}
|
|
throw TypeError(
|
|
"%s() takes %zu positional argument%s but %ld %s given",
|
|
signature.name.c_str(),
|
|
max_pos_args,
|
|
max_pos_args == 1 ? "" : "s",
|
|
nargs_,
|
|
nargs == 1 ? "was" : "were");
|
|
}
|
|
|
|
[[noreturn]] static void missing_args(
|
|
const FunctionSignature& signature,
|
|
int idx) {
|
|
int num_missing = 0;
|
|
std::stringstream ss;
|
|
|
|
auto& params = signature.params;
|
|
for (auto it = params.begin() + idx; it != params.end(); ++it) {
|
|
if (!it->optional) {
|
|
if (num_missing > 0) {
|
|
ss << ", ";
|
|
}
|
|
ss << '"' << it->name << '"';
|
|
num_missing++;
|
|
}
|
|
}
|
|
|
|
throw TypeError(
|
|
"%s() missing %d required positional argument%s: %s",
|
|
signature.name.c_str(),
|
|
num_missing,
|
|
num_missing == 1 ? "s" : "",
|
|
ss.str().c_str());
|
|
}
|
|
|
|
// Return the index of the parameter whose Python name equals `name`,
// or -1 if no parameter matches. Throws python_error if the Python-level
// comparison itself fails.
static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
  Py_ssize_t idx = 0;
  for (const auto& param : signature.params) {
    const int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
    if (cmp < 0) {
      throw python_error();
    }
    if (cmp) {
      return idx;
    }
    ++idx;
  }
  return -1;
}
|
|
|
|
// Throw the most specific TypeError for a bad keyword argument: either an
// unrecognized keyword, or a keyword that duplicates a positional argument.
// Callers only invoke this after detecting leftover kwargs, so one of the
// throws below is expected to fire.
[[noreturn]] static void extra_kwargs(
    FunctionSignature& signature,
    PyObject* kwargs,
    Py_ssize_t num_pos_args) {
  PyObject* key = nullptr;
  PyObject* value = nullptr;
  Py_ssize_t iter_pos = 0;

  while (PyDict_Next(kwargs, &iter_pos, &key, &value)) {
    if (!THPUtils_checkString(key)) {
      throw TypeError("keywords must be strings");
    }

    const auto match_idx = find_param(signature, key);
    if (match_idx < 0) {
      throw TypeError(
          "%s() got an unexpected keyword argument '%s'",
          signature.name.c_str(),
          THPUtils_unpackString(key).c_str());
    }
    if (match_idx < num_pos_args) {
      throw TypeError(
          "%s() got multiple values for argument '%s'",
          signature.name.c_str(),
          THPUtils_unpackString(key).c_str());
    }
  }

  // this should never be hit
  throw TypeError("invalid keyword arguments");
}
|
|
|
|
// Attempt to bind the Python call (self, args, kwargs) to this signature.
// On success, fills dst[] with one PyObject* per parameter (nullptr for an
// omitted optional / allowed None) and returns true. On mismatch: returns
// false if raise_exception is false, otherwise throws a descriptive
// TypeError. Also collects arguments that override __torch_function__ into
// overloaded_args as a side effect.
bool FunctionSignature::parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* dst[], // NOLINT
    bool raise_exception) {
  Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
  auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
  size_t arg_pos = 0;
  bool allow_varargs_intlist = false;

  // if there is a single positional IntArrayRef argument, i.e. expand(..),
  // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
  // expand((5,3))
  int int_list_overload = false;
  if (max_pos_args == 1 &&
      (params[0].type_ == ParameterType::INT_LIST ||
       params[0].type_ == ParameterType::SYM_INT_LIST)) {
    allow_varargs_intlist = true;
    if (params[0].type_ == ParameterType::INT_LIST) {
      int_list_overload = true;
    }
  }

  if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
    if (raise_exception) {
      // foo() takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  // Reset state from any previous parse attempt with this signature.
  if (!overloaded_args.empty()) {
    overloaded_args.clear();
  }

  int i = 0;
  if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
    append_overloaded_tensor(&this->overloaded_args, self);
  }
  for (auto& param : params) {
    PyObject* obj = nullptr;
    bool is_kwd = false;
    // Prefer the next positional argument; fall back to kwargs lookup
    // (including any NumPy-compatible alias names for this parameter).
    if (arg_pos < static_cast<size_t>(nargs)) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      obj = PyTuple_GET_ITEM(args, arg_pos);
    } else if (kwargs) {
      obj = PyDict_GetItem(kwargs, param.python_name);
      for (PyObject* numpy_name : param.numpy_python_names) {
        if (obj) {
          break;
        }
        obj = PyDict_GetItem(kwargs, numpy_name);
      }
      is_kwd = true;
    }

    int64_t failed_idx = -1;
    bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
    if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
      dst[i++] = nullptr;
    } else if (!obj) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, i);
      }
      return false;
    } else if (param.check(obj, this->overloaded_args, i, &failed_idx)) {
      dst[i++] = obj;
      // XXX: the Variable check is necessary because sizes become tensors when
      // tracer is enabled. This behavior easily leads to ambiguities, and we
      // should avoid having complex signatures that make use of it...
    } else if (
        varargs_eligible &&
        ((int_list_overload
              ? is_int_list(args, param.size, &failed_idx)
              : is_int_or_symint_list(args, param.size, &failed_idx)))) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
      dst[i++] = args;
      arg_pos = nargs;
      continue;
    } else if (raise_exception) {
      if (is_kwd) {
        // foo(): argument 'other' must be str, not int
        throw TypeError(
            "%s(): argument '%s' must be %s, not %s",
            name.c_str(),
            param.name.c_str(),
            param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      } else {
        // foo(): argument 'other' (position 2) must be str, not int
        if (failed_idx != -1) {
          // A list-type check failed at a specific element; report which
          // element had the wrong type.
          if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
            TORCH_INTERNAL_ASSERT(varargs_eligible);
            obj = args;
          }
          TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
          throw TypeError(
              "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
              name.c_str(),
              param.name.c_str(),
              static_cast<long>(arg_pos + 1),
              param.type_name().c_str(),
              Py_TYPE(py::reinterpret_steal<py::object>(
                          PySequence_GetItem(obj, failed_idx))
                          .ptr())
                  ->tp_name,
              static_cast<long>(failed_idx));
        }
        throw TypeError(
            "%s(): argument '%s' (position %ld) must be %s, not %s",
            name.c_str(),
            param.name.c_str(),
            static_cast<long>(arg_pos + 1),
            param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      }
    } else {
      return false;
    }

    if (!is_kwd) {
      arg_pos++;
    } else if (obj) {
      remaining_kwargs--;
    }
  }

  // Any kwargs not consumed above are either unknown names or duplicates
  // of positional arguments.
  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }
  return true;
}
|
|
|
|
// Build the parser from a list of signature format strings; caches the
// widest arg count and the function name, and orders deprecated overloads
// after current ones so they are only matched as a last resort.
PythonArgParser::PythonArgParser(std::vector<std::string> fmts, bool traceable)
    : max_args(0), traceable(traceable) {
  int sig_index = 0;
  for (auto& fmt : fmts) {
    signatures_.emplace_back(fmt, sig_index);
    ++sig_index;
  }
  for (const auto& signature : signatures_) {
    if (signature.max_args > max_args) {
      max_args = signature.max_args;
    }
  }
  if (!signatures_.empty()) {
    function_name = signatures_[0].name;
  }

  // Check deprecated signatures last
  std::stable_partition(
      signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) {
        return !sig.deprecated;
      });
}
|
|
|
|
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
|
|
if (signature.deprecated) {
|
|
auto msg = c10::str(
|
|
"This overload of ",
|
|
signature.name,
|
|
" is deprecated:\n\t",
|
|
signature.name,
|
|
signature.toString());
|
|
auto signatures = get_signatures();
|
|
if (!signatures.empty()) {
|
|
msg += "\nConsider using one of the following signatures instead:";
|
|
for (const auto& sig : signatures) {
|
|
msg += "\n\t";
|
|
msg += signature.name;
|
|
msg += sig;
|
|
}
|
|
}
|
|
TORCH_WARN_ONCE(msg);
|
|
}
|
|
}
|
|
|
|
// Resolve a Python call against the overload set. With a single signature,
// parse it directly so a mismatch raises the most precise error; otherwise
// try each signature silently and fall back to print_error (which throws)
// when nothing matches.
PythonArgs PythonArgParser::raw_parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  if (signatures_.size() == 1) {
    auto& only = signatures_[0];
    only.parse(self, args, kwargs, parsed_args, true);
    check_deprecated(only);
    return PythonArgs(traceable, only, parsed_args);
  }

  for (auto& signature : signatures_) {
    if (!signature.parse(self, args, kwargs, parsed_args, false)) {
      continue;
    }
    check_deprecated(signature);
    return PythonArgs(traceable, signature, parsed_args);
  }

  print_error(self, args, kwargs, parsed_args);
}
|
|
|
|
// Throw the best-possible TypeError after all overloads failed to match.
// If exactly one visible signature has a plausible arity, re-parse it with
// raise_exception=true to get its precise error; otherwise throw a generic
// invalid-arguments message listing the visible signatures.
void PythonArgParser::print_error(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
  size_t num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0);
  std::vector<unsigned> plausible_idxs;
  unsigned sig_idx = 0;
  for (const auto& signature : signatures_) {
    const bool arity_ok =
        num_args >= signature.min_args && num_args <= signature.max_args;
    if (arity_ok && !signature.hidden) {
      plausible_idxs.push_back(sig_idx);
    }
    sig_idx++;
  }

  if (plausible_idxs.size() == 1) {
    auto& signature = signatures_[plausible_idxs[0]];
    signature.parse(self, args, kwargs, parsed_args, true);
  }

  auto options = get_signatures();
  auto msg =
      torch::format_invalid_args(args, kwargs, function_name + "()", options);
  throw TypeError("%s", msg.c_str());
}
|
|
|
|
std::vector<std::string> PythonArgParser::get_signatures() const {
|
|
std::vector<std::string> options;
|
|
for (auto& signature : signatures_) {
|
|
if (!signature.hidden) {
|
|
options.push_back(signature.toString());
|
|
}
|
|
}
|
|
return options;
|
|
}
|
|
|
|
// Convert argument i to a Tensor when it is not already a THPVariable:
// Python numbers (bool, int, complex, float) and symbolic numbers are
// wrapped into 0-dim "wrapped number" tensors. Returns an undefined Tensor
// for a missing argument; throws TypeError for any other object type.
at::Tensor PythonArgs::tensor_slow(int i) {
  PyObject* obj = args[i];
  if (!obj) {
    return at::Tensor();
  }
  if (THPVariable_Check(obj)) {
    return THPVariable_Unpack(obj);
  }

  bool save_symint = false;
  at::Scalar scalar;
  // NB: bool is checked before long because a Python bool is also an int.
  if (PyBool_Check(obj)) {
    scalar = at::Scalar(THPUtils_unpackBool(obj));
  } else if (THPUtils_checkLong(obj)) {
    scalar = at::Scalar(THPUtils_unpackLong(obj));
  } else if (PyComplex_Check(obj)) {
    scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
  } else if (THPUtils_checkDouble(obj)) {
    scalar = at::Scalar(THPUtils_unpackDouble(obj));
    // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
    // because although Scalar supports SymInt/SymFloat, the subsequent
    // conversion to Tensor does not. Instead, do it out of band.
  } else if (torch::is_symint(py::handle(obj))) {
    save_symint = true;
    // This scalar value doesn't matter, it shouldn't ever actually
    // get read out. Make it a big and weird looking number to help
    // people figure out if there's a problem.
    scalar = at::Scalar(7777777);
  } else if (torch::is_symfloat(py::handle(obj))) {
    save_symint = true;
    scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
  } else if (torch::is_symbool(py::handle(obj))) {
    save_symint = true;
    scalar = at::Scalar(true);
  } else {
    // NB: Are you here because you passed None to a Variable method,
    // and you expected an undefined tensor to be returned? Don't add
    // a test for Py_None here; instead, you need to mark the argument
    // as *allowing none*; you can do this by writing 'Tensor?' instead
    // of 'Tensor' in the ATen metadata.
    throw TypeError(
        "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name);
  }
  // Construct the wrapped-number tensor below autograd/view dispatch and
  // without tracing the construction itself.
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;

  at::Tensor tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);

  if (save_symint) {
    // Stash the original symbolic Python object on the tensor so downstream
    // code can recover the symbolic value (the placeholder Scalar above is
    // not meaningful).
    auto py_tensor = py::cast(tensor);
    if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) {
      throw python_error();
    }
  }

  return tensor;
}
|
|
|
|
// Convert argument i to a Scalar. When JIT tracing is active and the
// argument is a Variable, first stash it with the tracer so the traced
// graph records it as a Number-typed input; then defer to the PyObject*
// overload for the actual conversion.
at::Scalar PythonArgs::scalar_slow(int i) {
  if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
    auto& var = THPVariable_Unpack(args[i]);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, var, c10::NumberType::get());
  }

  return scalar_slow(args[i]);
}
|
|
|
|
// Convert an arbitrary Python object to an at::Scalar. The order of the
// checks is significant: Variable first, then long before bool-sensitive
// paths (note PyBool_Check runs after THPUtils_checkLong here), then
// complex and the symbolic number types, with double as the fallback.
at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
  // currently handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(arg)) {
    return THPVariable_Unpack(arg).item();
  }

  if (THPUtils_checkLong(arg)) {
    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
  }

  if (PyBool_Check(arg)) {
    return at::Scalar(THPUtils_unpackBool(arg));
  }

  if (PyComplex_Check(arg)) {
    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }

  // Symbolic numbers are preserved as SymInt/SymFloat/SymBool scalars
  // rather than being concretized.
  if (torch::is_symint(arg)) {
    return at::Scalar(py::cast<c10::SymInt>(arg));
  }

  if (torch::is_symfloat(arg)) {
    return at::Scalar(py::cast<c10::SymFloat>(arg));
  }

  if (torch::is_symbool(arg)) {
    // Windows build fails with C2440: '<function-style-cast>'
    // when at:Scalar(py::cast<c10::SymBool>(arg))
    auto sym_bool = py::handle(arg).cast<c10::SymBool>();
    return at::Scalar(sym_bool);
  }

  // Fallback: anything else must unpack as a Python float (raises on
  // non-numeric objects).
  return at::Scalar(THPUtils_unpackDouble(arg));
}
|
|
|
|
} // namespace torch
|