pytorch/torch/csrc/utils/python_arg_parser.cpp
Brian Hirsh da54f3c519 reorder proxy / fake modes so they always run last (#104482)
**Update:** I refactored the original PR. See the original description below; here I'll describe the updates:

(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.

I added a `TorchDispatchModeKey` enum that (for now) just contains PROXY and FAKE. The mode TLS used to contain only a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with FUNCTIONAL coming later in the stack).

`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key which, if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped, if there was one; that way, if we push that same mode back onto the TLS later, we know where it goes.

`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes".

`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).
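
Putting those pieces together, here's a Python sketch of the new TLS layout and its operations (illustrative pseudocode only; the real implementation is the C++ in `TorchDispatchModeTLS.h/cpp`, and error checking is elided):

```python
from enum import Enum
from typing import List, Optional, Tuple

class TorchDispatchModeKey(Enum):
    PROXY = 0
    FAKE = 1
    # FUNCTIONAL is added later in this stack

class ModeTLS:
    """Illustrative Python model of c10::impl::TorchDispatchModeTLS."""

    def __init__(self) -> None:
        self.stack: List[object] = []  # "user" mode stack
        # one dedicated slot per infra mode key
        self.infra_modes: List[Optional[object]] = [None] * len(TorchDispatchModeKey)

    def push_onto_stack(self, mode, mode_key: Optional[TorchDispatchModeKey] = None) -> None:
        if mode_key is not None:
            # infra modes go into their dedicated slot, not onto the stack
            self.infra_modes[mode_key.value] = mode
        else:
            self.stack.append(mode)

    def pop_stack(self) -> Tuple[object, Optional[TorchDispatchModeKey]]:
        # user modes pop first; infra modes pop only once the user stack is
        # empty. The returned key tells the caller which slot to re-push into.
        if self.stack:
            return self.stack.pop(), None
        for key in reversed(TorchDispatchModeKey):  # pop order among slots is illustrative
            if self.infra_modes[key.value] is not None:
                mode, self.infra_modes[key.value] = self.infra_modes[key.value], None
                return mode, key
        raise RuntimeError("pop_stack: no modes are active")

    def dispatch_mode_enabled(self, skip_infra_modes: bool = False) -> bool:
        if self.stack:  # any user modes at all?
            return True
        return not skip_infra_modes and any(m is not None for m in self.infra_modes)

    def get_mode(self, key: TorchDispatchModeKey) -> Optional[object]:
        return self.infra_modes[key.value]

    def set_mode(self, mode, key: TorchDispatchModeKey) -> None:
        self.infra_modes[key.value] = mode

    def unset_mode(self, key: TorchDispatchModeKey) -> Optional[object]:
        mode, self.infra_modes[key.value] = self.infra_modes[key.value], None
        return mode
```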

There were also some mild codegen changes to support the new enum.

(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`

The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter__/__exit__()` (in `_python_dispatch.py`) now check whether the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instruct `TorchDispatchModeTLS` where to put the mode). The same goes for `ProxyTorchDispatchMode`.
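
A rough sketch of that plumbing (class names follow the PR; the bodies are illustrative, and `_push_mode`/`_pop_mode` are hypothetical stand-ins for the helpers that forward to `TorchDispatchModeTLS::push_onto_stack`/`pop_stack`):

```python
import torch

def _push_mode(mode, mode_key=None):
    ...  # hypothetical stand-in: forwards to TorchDispatchModeTLS::push_onto_stack

def _pop_mode(mode_key=None):
    ...  # hypothetical stand-in: forwards to TorchDispatchModeTLS::pop_stack

class TorchDispatchMode:  # sketch of the relevant bits of _python_dispatch.py
    def __enter__(self):
        # infra modes carry a _mode_key; plumbing it through tells the TLS to
        # put the mode in its dedicated infra slot instead of the user stack
        _push_mode(self, getattr(self, "_mode_key", None))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode(getattr(self, "_mode_key", None))

class FakeTensorMode(TorchDispatchMode):
    def __init__(self):
        # the mode key marks this mode (and FakeTensor) as "infra"
        self._mode_key = torch._C.TorchDispatchModeKey.FAKE
```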

I also had to change both modes' `__enter__`/`__exit__` to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to carry a `self.enter_stack: List[Optional[TorchDispatchMode]]`: whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack and save it in `enter_stack`, so that on exit we can restore the state properly.
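
Continuing the sketch above, the enter/exit handling looks roughly like this (`_unset_infra_mode` is a hypothetical name for a helper that removes and returns the currently-installed infra mode, if any):

```python
from typing import List, Optional

def _unset_infra_mode(mode_key):
    ...  # hypothetical stand-in: clears and returns the infra mode at mode_key

class ProxyTorchDispatchMode(TorchDispatchMode):
    def __init__(self):
        self._mode_key = torch._C.TorchDispatchModeKey.PROXY
        # ambient proxy modes displaced by __enter__, restored on __exit__
        self.enter_stack: List[Optional[TorchDispatchMode]] = []

    def __enter__(self):
        # at most one proxy mode can be installed at a time, so stash the
        # current one (possibly None) before installing ourselves
        self.enter_stack.append(_unset_infra_mode(self._mode_key))
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        out = super().__exit__(exc_type, exc_val, exc_tb)
        prev = self.enter_stack.pop()
        if prev is not None:
            _push_mode(prev, prev._mode_key)  # reinstall the displaced mode
        return out
```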

(3) dispatching logic in `python_arg_parser.cpp`

This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode()  # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```

Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this already works today: if the proxy/fake modes run in step (a), they'll return NotImplemented when they see a user subclass, allowing us to redispatch to the user subclass.
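
Continuing the sketch, the pattern looks roughly like this (an illustration of the idea, not the exact proxy/fake implementation):

```python
class SomeInfraMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # `types` holds the subclass types participating in this call; if any
        # of them is a user subclass we don't understand, bail out so that
        # multiple dispatch retries with that subclass's __torch_dispatch__
        if any(not hasattr(t, "_mode_key") for t in types):
            return NotImplemented
        ...  # otherwise, trace/fakeify the op as usual
```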

How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass, so we filter on whether an argument's type has a `_mode_key`.
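
In Python terms, the filter is roughly (illustrative; the real check is done in C++):

```python
def split_overloaded_args(overloaded_args):
    # infra subclasses (FakeTensor, later FunctionalTensor) expose _mode_key;
    # user subclasses don't
    user, infra = [], []
    for arg in overloaded_args:
        typ = arg if isinstance(arg, type) else type(arg)
        (infra if hasattr(typ, "_mode_key") else user).append(arg)
    return user, infra
```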

(4) I also renamed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD pointed out that `DoubleTensor` would be easily confused with `torch.FloatTensor` and friends).

----- original description below -----

The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit

Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.

**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**

This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.

I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.

This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.

**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**

This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: first we inspect the mode stack and any non-fake subclass inputs, then we check the proxy mode slot, then the fake mode slot, before finally checking for any fake subclass inputs.
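
In other words, the checking order is:
```
(1) user mode stack (+ any non-fake subclass inputs)
(2) proxy mode slot
(3) fake mode slot
(4) fake subclass inputs
```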

**(3) new Python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` APIs**

Before, if you wanted to see whether proxy or fake modes were active in Python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new APIs to directly check whether either proxy or fake modes are active.

**(4) Allow traceable tensor subclasses to access storages from Python**

This is convenient later in the stack, where AOTAutograd needs to detect aliasing between inputs and outputs that might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to Python if the tensor subclass you are passing in is "traceable".
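
For example (a sketch, where `MyTraceableSubclass` is a hypothetical traceable wrapper subclass):

```python
x = MyTraceableSubclass(torch.randn(3))
s = x.untyped_storage()  # previously raised for subclasses; now returns the storage
```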

**(5) Fixed subclass fakeification**

@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
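
The invariant being enforced is roughly (illustrative; `fakeify` stands in for the `meta_utils.py` conversion):

```python
fake = fakeify(subclass_tensor)
# the fakeified subclass must carry the same autograd metadata as the original
assert fake.requires_grad == subclass_tensor.requires_grad
assert fake.is_leaf == subclass_tensor.is_leaf
```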

**(6) Make tensor subclasses resizeable**

Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.

**(7) Added a basic DoubleTensor subclass for testing**

I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.
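
For reference, a minimal wrapper subclass along these lines might look like the following (a sketch using the post-rename `TwoTensor` name, not the exact test subclass from the PR):

```python
import torch
from torch.utils._pytree import tree_map

class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b):
        assert a.shape == b.shape
        return torch.Tensor._make_wrapper_subclass(
            cls, a.shape, dtype=a.dtype, device=a.device)

    def __init__(self, a, b):
        self.a = a
        self.b = b

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # run the op once per inner tensor, then rewrap the pair
        unwrap_a = lambda t: t.a if isinstance(t, cls) else t
        unwrap_b = lambda t: t.b if isinstance(t, cls) else t
        out_a = func(*tree_map(unwrap_a, args), **tree_map(unwrap_a, kwargs))
        out_b = func(*tree_map(unwrap_b, args), **tree_map(unwrap_b, kwargs))
        return cls(out_a, out_b) if isinstance(out_a, torch.Tensor) else out_a
```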

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
2023-08-29 02:36:48 +00:00

#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <ATen/ATen.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>
#include <c10/util/irange.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch {
static std::unordered_map<std::string, ParameterType> type_map = {
{"Tensor", ParameterType::TENSOR},
{"Scalar", ParameterType::SCALAR},
{"int64_t", ParameterType::INT64},
{"SymInt", ParameterType::SYM_INT},
{"double", ParameterType::DOUBLE},
{"complex", ParameterType::COMPLEX},
{"TensorList", ParameterType::TENSOR_LIST},
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
{"IntArrayRef", ParameterType::INT_LIST},
{"SymIntArrayRef", ParameterType::SYM_INT_LIST},
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
{"Generator", ParameterType::GENERATOR},
{"bool", ParameterType::BOOL},
{"Storage", ParameterType::STORAGE},
{"PyObject*", ParameterType::PYOBJECT},
{"ScalarType", ParameterType::SCALARTYPE},
{"Layout", ParameterType::LAYOUT},
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
{"QScheme", ParameterType::QSCHEME},
{"Device", ParameterType::DEVICE},
{"DeviceIndex", ParameterType::INT64},
{"Stream", ParameterType::STREAM},
{"std::string", ParameterType::STRING},
{"c10::string_view", ParameterType::STRING},
{"Dimname", ParameterType::DIMNAME},
{"DimnameList", ParameterType::DIMNAME_LIST},
{"ScalarList", ParameterType::SCALAR_LIST},
};
// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
static const std::unordered_map<std::string, std::vector<std::string>>
numpy_compatibility_arg_names = {
{"dim", {"axis"}},
{"keepdim", {"keepdims"}},
{"input", {"x", "a", "x1"}},
{"other", {"x2"}},
};
// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blocklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
bool should_allow_numbers_as_tensors(const std::string& name) {
static std::unordered_set<std::string> allowed = {
"add", "add_", "add_out",
"div", "div_", "div_out",
"divide", "divide_", "divide_out", // alias of div
"mul", "mul_", "mul_out",
"multiply", "multiply_", "multiply_out", // alias of mul
"sub", "sub_", "sub_out",
"subtract", "subtract_", "subtract_out", // alias of sub
"true_divide", "true_divide_", "true_divide_out",
"to", "_to_copy", "copy_",
"floor_divide", "floor_divide_", "floor_divide_out"};
return allowed.find(name) != allowed.end();
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
: optional(false),
allow_none(false),
keyword_only(keyword_only),
size(0),
default_scalar(0) {
auto space = fmt.find(' ');
if (space == std::string::npos) {
throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
}
auto type_str = fmt.substr(0, space);
auto question = type_str.find('?');
if (question != std::string::npos) {
allow_none = true;
type_str = type_str.substr(0, question);
}
// Parse and remove brackets from type_str
auto bracket = type_str.find('[');
if (bracket != std::string::npos) {
auto size_str =
type_str.substr(bracket + 1, type_str.length() - bracket - 2);
size = atoi(size_str.c_str());
type_str = type_str.substr(0, bracket);
}
auto name_str = fmt.substr(space + 1);
auto it = type_map.find(type_str);
if (it == type_map.end()) {
throw std::runtime_error(
"FunctionParameter(): invalid type string: " + type_str);
}
type_ = it->second;
auto eq = name_str.find('=');
if (eq != std::string::npos) {
name = name_str.substr(0, eq);
optional = true;
set_default_str(name_str.substr(eq + 1));
} else {
name = name_str;
}
python_name = THPUtils_internString(name);
auto np_compat_it = numpy_compatibility_arg_names.find(name);
if (np_compat_it != numpy_compatibility_arg_names.end()) {
for (const auto& str : np_compat_it->second) {
numpy_python_names.push_back(THPUtils_internString(str));
}
}
}
auto handle_torch_function_getter(
THPVariable* self,
const std::string& property_name) -> PyObject* {
py::object torch_api = PyObject_FastGetAttrString(
THPVariableClass, (char*)property_name.c_str());
std::string module_name = "torch.Tensor." + property_name;
return handle_torch_function(
(PyObject*)self,
"__get__",
nullptr,
nullptr,
torch_api.ptr(),
module_name);
}
auto handle_torch_function_setter(
THPVariable* self,
const std::string& property_name,
PyObject* value) -> int {
py::object torch_api = PyObject_FastGetAttrString(
THPVariableClass, (char*)property_name.c_str());
std::string module_name = "torch.Tensor." + property_name;
if (value != nullptr) {
py::tuple args_ = py::make_tuple(py::handle(value));
handle_torch_function(
(PyObject*)self,
"__set__",
args_.ptr(),
nullptr,
torch_api.ptr(),
module_name);
} else {
handle_torch_function(
(PyObject*)self,
"__delete__",
nullptr,
nullptr,
torch_api.ptr(),
module_name);
}
return 0;
}
// Combines self and args into one tuple.
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
if (args == nullptr) {
return py::make_tuple(py::handle(self));
} else if (self == nullptr) {
return py::reinterpret_borrow<py::tuple>(args);
}
auto py_args = py::reinterpret_borrow<py::tuple>(args);
size_t n = py_args.size();
auto args_ = py::tuple(n + 1);
args_[0] = py::handle(self);
for (const auto i : c10::irange(n)) {
args_[i + 1] = py_args[i];
}
return args_;
}
// TODO: I'm not sure if I should call this __torch_function__ or
// torch_function. The former makes it easier to take an existing
// Tensor-like __torch_function__ object and turn it into a mode;
// but in general modes don't have to be Tensor-like (and we will
// improperly accept mode objects as arguments when they shouldn't
// be passed around in this way).
const char* torch_function_mode_name = "__torch_function__";
auto handle_torch_function(
PyObject* self,
const std::string& func_name,
PyObject* args,
PyObject* kwargs,
PyObject* torch_api,
const std::string& module_name) -> PyObject* {
py::object torch_api_function =
PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
TORCH_INTERNAL_ASSERT(
torch_api_function.ptr() != nullptr, "torch API function must exist");
py::tuple args_ = combine_self_args(self, args);
return handle_torch_function_no_python_arg_parser(
{self},
args_.ptr(),
kwargs,
func_name.c_str(),
torch_api_function.ptr(),
module_name.c_str(),
TorchFunctionName::TorchFunction);
}
// Note: [Overloaded args]
// An overloaded arg may be one of the following:
// - an instance of an object that has a __torch_function__ method
// - an instance of an object that has a __torch_dispatch__ classmethod
// - a class type that has a __torch_dispatch__ classmethod
//
// This function returns the type of the arg (if the arg is an instance),
// otherwise, it returns the arg.
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
if (PyType_Check(obj_or_type)) {
return obj_or_type;
}
return (PyObject*)Py_TYPE(obj_or_type);
}
static py::object dispatch_on_subclass(
PyObject* args,
PyObject* kwargs,
at::ArrayRef<PyObject*> overloaded_args,
py::tuple py_types,
PyObject* torch_api_function,
bool is_torch_function,
const char* torch_function_name_str,
c10::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
c10::nullopt) {
py::object ret;
for (auto& arg : overloaded_args) {
py::object torch_function =
PyObject_FastGetAttrString(arg, torch_function_name_str);
if (!torch_function) {
TORCH_INTERNAL_ASSERT(0);
}
if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) {
// During __torch_dispatch__, don't dispatch on args with a disabled
// torch_dispatch. This code runs before infra modes, so we need to make
// sure that infra modes can run first. (In theory, maybe we can rearrange
// things so that infra modes are *always* attempted first, and just
// return NotImplemented when there are any user subclasses. Maybe that
// would fix this problem?)
continue;
}
// See https://github.com/pytorch/pytorch/issues/63767
if (is_torch_function &&
PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
.is(py::handle(arg)) &&
torch_function.ptr() != torch::disabled_torch_function_impl()) {
TORCH_WARN(
"Defining your `",
torch_function_name_str,
"` as a plain method is deprecated ",
"and will be an error in future, please define it as a classmethod.");
}
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
torch_function.ptr(),
torch_api_function,
py_types.ptr(),
args,
kwargs,
NULL));
if (ret.ptr() == nullptr) {
throw python_error();
}
if (ret.ptr() != Py_NotImplemented) {
// Return the reference to the result. This also covers the case where
// ret is NULL and __torch_function__/__torch_dispatch__ raised an
// exception, which we throw below
break;
}
}
return ret;
}
static std::tuple<py::object, py::object> dispatch_on_mode(
PyObject* args,
PyObject* kwargs,
py::tuple py_types,
PyObject* torch_api_function,
bool is_torch_function,
const char* torch_function_name_str) {
// Disable mode on the inside; this makes for a more user-friendly
// experience if you try to, e.g., print your tensors.
at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
py::object mode_obj;
// NB: We only really need to keep the mode_obj live if the function call
// fails for error reporting, but whatever, Python refcounts are cheap
if (is_torch_function) {
tf_g.emplace();
mode_obj = py::reinterpret_borrow<py::object>(
tf_g->get_cur_mode()->ptr(getPyInterpreter()));
} else {
td_g.emplace();
mode_obj = py::reinterpret_borrow<py::object>(
td_g->get_cur_mode()->ptr(getPyInterpreter()));
}
py::object torch_function =
PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
if (!torch_function) {
TORCH_INTERNAL_ASSERT(0);
}
TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
TORCH_INTERNAL_ASSERT(args != nullptr);
TORCH_CHECK(
PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
"Defining your mode's `",
torch_function_name_str,
"` as a classmethod is not supported, please make it a plain method");
// Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
// because the nullptr terminates the argument list ick ick ick.
py::object ret;
if (kwargs == nullptr) {
ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
mode_obj.ptr(),
torch_function_name_str,
"OOO",
torch_api_function,
py_types.ptr(),
args));
} else {
ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
mode_obj.ptr(),
torch_function_name_str,
"OOOO",
torch_api_function,
py_types.ptr(),
args,
kwargs));
}
if (ret.ptr() == nullptr) {
throw python_error();
}
return std::make_tuple(ret, mode_obj);
}
// See Note: [Overloaded args] for what they hold
auto handle_torch_function_no_python_arg_parser(
at::ArrayRef<PyObject*> overloaded_args,
PyObject* args,
PyObject* kwargs,
const char* func_name,
PyObject* torch_api_function,
const char* module_name,
TorchFunctionName torch_function_name) -> PyObject* {
const char* torch_function_name_str = nullptr;
switch (torch_function_name) {
case TorchFunctionName::TorchFunction:
torch_function_name_str = "__torch_function__";
break;
case TorchFunctionName::TorchDispatch:
torch_function_name_str = "__torch_dispatch__";
break;
default:
TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
}
// overloaded_args already all have unique types
// nb: modes don't go in the overloaded types list, as they are not
// necessarily types
std::vector<py::object> overloaded_types;
overloaded_types.reserve(overloaded_args.size());
for (auto& arg : overloaded_args) {
overloaded_types.push_back(
py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
}
py::tuple py_types = py::cast(overloaded_types);
py::object ret;
py::object mode_obj;
// Step 1: Try to dispatch based on the mode stack, *ignoring* infra
// torch_dispatch modes.
const bool is_torch_function =
torch_function_name == TorchFunctionName::TorchFunction;
const auto is_mode_active = [&]() {
return is_torch_function
? at::impl::torch_function_mode_enabled()
// Check if any *user* torch_dispatch modes are active (not including
// fake and proxy modes, which are special)
: c10::impl::dispatch_mode_enabled();
};
// Note [__torch_dispatch__ dispatching order]
// The high-level idea motivating the dispatching
// order below is that: (1) modes get higher dispatch precedence over
// subclasses, and (2) "user" modes/subclasses get higher dispatch precedence
// over "infra" modes/subclasses.
//
// To give a complete example: let's say we are running torch.compile, with
// the following "user" modes and subclasses:
// mode_stack: [ModeA]
// user_args: [MyWrapperSubclassB(torchTensor)]
// During AOTAutograd tracing, we use some additional infra modes
// and subclasses to perform tracing:
// FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode,
// FunctionalTensor, FakeTensor
// The modified mode stack and tracing arguments will look like this:
// mode_stack (user modes): [ModeA]
// mode_stack (infra modes): [
// FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode
// ]
// tracing_args: [
// MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor)))
// ]
// And the dispatching order that we want is as follows:
// (1) ModeA.__torch_dispatch__ (user modes highest)
// (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest)
// (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest)
// (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest)
// (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest)
// (6) FakeTensor.__torch_dispatch__ (infra subclasses next highest)
// Why do FunctionalTensor and FakeTensor even need to be special-cased
// in the ordering?
// In theory we could remove their __torch_dispatch__, but both of these
// subclasses override sizes/strides metadata calls with __torch_dispatch__,
// which would mean a mode would be **required** to access their metadata.
if (is_mode_active()) {
// Step 1: Try to dispatch on any user TorchDispatchModes (including infra
// modes, which will always be at the bottom of the mode stack).
auto ret_ = dispatch_on_mode(
args,
kwargs,
py_types,
torch_api_function,
is_torch_function,
torch_function_name_str);
ret = std::get<0>(ret_);
mode_obj = std::get<1>(ret_);
}
// Step 2: Try to dispatch based on any user subclasses,
// ignoring any subclasses that have a _mode_key field
// (corresponding to infra subclasses)
// Note: user subclasses should always run *before* infra modes like
// proxy/fake. This is handled by having proxy/fake modes return
// NotImplemented when they see a user subclass that they don't understand.
if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
auto curr_ret = dispatch_on_subclass(
args,
kwargs,
overloaded_args,
py_types,
torch_api_function,
is_torch_function,
torch_function_name_str);
if (curr_ret.ptr() != nullptr) {
ret = curr_ret;
}
}
if (ret.ptr() == nullptr) {
// if an exception occurred in a user's implementation of
// __torch_function__, throw it
throw python_error();
} else if (ret.ptr() == Py_NotImplemented) {
// all __torch_function__ implementations in overloaded_args
// returned NotImplemented, so we raise a TypeError.
std::stringstream ss;
ss << "Multiple dispatch failed for '";
if (module_name && func_name) {
ss << module_name << "." << func_name;
} else {
py::handle fn = torch_api_function;
ss << py::str(fn.attr("__module__")) << "."
<< py::str(fn.attr("__name__"));
}
ss << "'; all " << torch_function_name_str
<< " handlers returned NotImplemented:\n\n";
if (mode_obj) {
ss << " - mode object " << py::repr(mode_obj) << "\n";
}
for (auto& arg : overloaded_args) {
ss << " - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
<< "\n";
}
ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
const std::string& tmp = ss.str();
PyErr_SetString(PyExc_TypeError, tmp.c_str());
throw python_error();
}
return ret.release().ptr();
}
auto handle_torch_function(
PythonArgs& r,
PyObject* self,
PyObject* args,
PyObject* kwargs,
PyObject* torch_api,
const char* module_name,
const char* func_name_override) -> PyObject* {
py::object torch_api_function = PyObject_FastGetAttrString(
torch_api,
(char*)(func_name_override ? func_name_override : r.get_func_name().c_str()));
TORCH_INTERNAL_ASSERT(
torch_api_function.ptr() != nullptr, "torch API function must exist");
py::tuple args_ = combine_self_args(self, args);
return handle_torch_function_no_python_arg_parser(
r.overloaded_args,
args_.ptr(),
kwargs,
r.get_func_name().c_str(),
torch_api_function.ptr(),
module_name);
}
auto handle_torch_function(
PythonArgs& r,
PyObject* args,
PyObject* kwargs,
PyObject* torch_api,
const char* module_name,
const char* func_name_override) -> PyObject* {
return handle_torch_function(
r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
}
auto handle_torch_function_indexing(
PyObject* self,
PyObject* index,
PyObject* val) -> PyObject* {
const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
py::object index_tup;
if (PyTuple_Check(index)) {
index_tup = py::reinterpret_borrow<py::object>(index);
} else {
index_tup = py::make_tuple(py::handle(index));
}
std::vector<PyObject*> overridable_args;
is_tensor_and_append_overloaded(self, &overridable_args);
auto size = PyTuple_GET_SIZE(index_tup.ptr());
for (auto i : c10::irange(size)) {
auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
is_tensor_and_append_overloaded(obj, &overridable_args);
}
if (val != nullptr) {
is_tensor_and_append_overloaded(val, &overridable_args);
}
py::object func =
PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
py::object args = (val == nullptr)
? py::make_tuple(py::handle(self), py::handle(index))
: py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
return handle_torch_function_no_python_arg_parser(
overridable_args,
args.ptr(),
nullptr,
func_name,
func.ptr(),
"torch.Tensor");
}
/*
* obj has a __torch_function__ implementation and may either be a
* subclass of Tensor or a Tensor-like duck type. We may need to
* append this object to the overloaded_args vector, which tracks all
* of the arguments with distinct __torch_function__ implementations
* we've seen so far.
*
* If this is the first argument we've seen with __torch_function__
* defined, we unconditionally add obj to the overloaded_args vector.
*
* If we've already seen arguments with __torch_function__ defined,
* then we first need to check if obj is the same type as any of the
* entries in overloaded_args. If so, we can ignore obj since we
* already have an entry in overloaded_args with the same
* __torch_function__ implementation.
*
* If it's a different type, we then need to check if it's a subclass
* of one of the types we've already seen. If so, we need to insert an
* entry in overloaded_args for this type with higher precedence than
* the superclass.
*
* See torch._overrides._get_overloaded_args for the equivalent
* function in the Python __torch_function__ implementation.
*
* The precedence-determining algorithm implemented in this function is
* described in NEP-0018:
* https://numpy.org/neps/nep-0018-array-function-protocol.html
*
* 'overloaded_args' is a raw pointer to a vector of pybind11 handles
* that have distinct __torch_function__ implementations, in order of calling
* precedence.
*
* 'obj' is an object to check for a __torch_function__ implementation
*
* If changing this file in a way that can affect the __torch_function__
* overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
* See the instructions in the 'README.md' in that directory.
*
*/
static void append_overloaded_arg(
std::vector<PyObject*>* overloaded_args,
PyObject* obj,
bool obj_is_type) {
bool class_not_seen_yet = true;
PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
for (auto& arg : *overloaded_args) {
if (obj_type == get_type_of_overloaded_arg(arg)) {
// obj is the same type as another parameter we've seen in a prior
// iteration of the loop over parameters so we already have an entry
// with the proper __torch_function__ implementation to call, so skip
// this parameter
class_not_seen_yet = false;
break;
}
}
if (class_not_seen_yet) {
auto arg_index = overloaded_args->size();
for (const auto j : c10::irange(arg_index)) {
if (PyObject_IsSubclass(
obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
// obj is a subclass of another object we've seen already so its
// __torch_function__ should be called first, therefore we
// insert it into overloaded_args before the superclass
arg_index = j;
break;
}
}
// add object to overloaded_args. If it's a subclass of another class
// we've already seen it will be inserted before the superclass,
// otherwise it will be inserted at the end of the array
overloaded_args->insert(
overloaded_args->begin() + static_cast<long>(arg_index), obj);
}
}
void append_overloaded_tensor(
std::vector<PyObject*>* overloaded_args,
PyObject* obj) {
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
}
void append_overloaded_type(
std::vector<PyObject*>* overloaded_args,
PyObject* obj) {
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
}
bool is_tensor_and_append_overloaded(
PyObject* obj,
std::vector<PyObject*>* overloaded_args) {
if (THPVariable_CheckExact(obj)) {
// torch.Tensor instances (not subclasses, except for Parameter)
return true;
}
if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
// tensor subclasses and unrelated objects with __torch_function__
append_overloaded_tensor(overloaded_args, obj);
return true;
} else if (THPVariable_Check(obj)) {
// tensor subclasses without __torch_function__
return true;
}
return false;
}
static bool is_scalar_list(PyObject* obj) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
for (const auto idx : c10::irange(size)) {
PyObject* iobj =
tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
if (!THPUtils_checkScalar(iobj)) {
return false;
}
}
return true;
}
bool is_tensor_list_and_append_overloaded(
PyObject* obj,
std::vector<PyObject*>* overloaded_args,
int argnum,
bool throw_error) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
for (long idx = 0; idx < size; idx++) {
PyObject* iobj =
tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
if (throw_error) {
throw TypeError(
"expected Tensor as element %d in argument %d, but got %s",
static_cast<int>(idx),
argnum,
Py_TYPE(iobj)->tp_name);
}
return false;
}
}
return true;
}
static bool is_float_or_complex_list(PyObject* obj) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
if (size > 0) {
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
return false;
}
}
return true;
}
static bool is_int_list(
PyObject* obj,
int broadcast_size,
int64_t* failed_idx = nullptr) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
auto len = PySequence_Size(obj);
if (len == 0) {
return true;
}
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
bool int_first = false;
if (THPUtils_checkIndex(item.ptr())) {
// we still have to check that the rest of items are NOT symint nodes
int_first = true;
}
// Make sure none of the later arguments are SymInt
// NB: do NOT check that the later arguments are ints, as this is
// BC-breaking for FX
for (int i = 1; i < len; i++) {
if (torch::is_symint(
py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
if (failed_idx != nullptr) {
*failed_idx = i;
}
return false;
}
}
if (int_first) {
return true;
}
// in dynamo, FakeTensor is qualified for INT_LIST
if (is_dynamo_compiling && THPVariable_Check(item.ptr())) {
auto& var = THPVariable_Unpack(item.ptr());
if (var.numel() != 1 || !var.sizes().empty() ||
!at::isIntegralType(
var.dtype().toScalarType(), /*include_bool*/ true)) {
if (failed_idx != nullptr) {
*failed_idx = 0;
}
return false;
}
return true;
}
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
bool r =
(jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes().empty());
if (!r && failed_idx != nullptr) {
*failed_idx = 0;
}
return r;
}
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
// int
return broadcast_size > 0 && THPUtils_checkLong(obj);
}
static bool is_int_or_symint(PyObject* obj) {
// THPUtils_checkIndex may call __index__ or __int__,
// which may have side effects if obj is a symint node,
// so we do the `is_symint` check first.
// TODO: maybe we should be using checkLong here?
if (torch::is_symint(py::handle(obj))) {
return true;
}
if (THPUtils_checkIndex(obj)) {
return true;
}
// FakeTensor(..., size=()) is qualified for SymInt param
if (is_dynamo_compiling && THPVariable_Check(obj)) {
auto& var = THPVariable_Unpack(obj);
if (var.numel() == 1 && var.sizes().empty() &&
at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
return true;
}
}
return false;
}
static bool is_int_or_symint_list(
PyObject* obj,
int broadcast_size,
int64_t* failed_idx = nullptr) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
if (PySequence_Size(obj) == 0) {
return true;
}
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
if (is_int_or_symint(item.ptr())) {
return true;
}
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
bool r =
(jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes().empty());
if (!r && failed_idx != nullptr) {
*failed_idx = 0;
}
return r;
}
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
// int
return broadcast_size > 0 && is_int_or_symint(obj);
}
// argnum is needed for raising the TypeError; it's used in the error message.
auto FunctionParameter::check(
PyObject* obj,
std::vector<PyObject*>& overloaded_args,
int argnum,
int64_t* failed_idx) -> bool {
switch (type_) {
case ParameterType::TENSOR: {
if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
return true;
}
if (allow_numbers_as_tensors) {
return THPUtils_checkScalar(obj);
}
return false;
}
case ParameterType::SCALAR:
if (THPUtils_checkScalar(obj)) {
return true;
}
[[fallthrough]];
case ParameterType::COMPLEX:
if (PyComplex_Check(obj)) {
return true;
}
[[fallthrough]];
case ParameterType::DOUBLE: {
if (THPUtils_checkDouble(obj)) {
return true;
}
if (THPVariable_Check(obj)) {
const auto& var = THPVariable_Unpack(obj);
return !var.requires_grad() && var.dim() == 0;
}
if (torch::is_symfloat(py::handle(obj))) {
// This will induce a guard
return true;
}
return false;
}
case ParameterType::INT64: {
if (THPUtils_checkLong(obj)) {
return true;
}
if (THPVariable_Check(obj)) {
const auto& var = THPVariable_Unpack(obj);
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
!var.requires_grad() && var.dim() == 0;
}
if (torch::is_symint(py::handle(obj))) {
// This will induce a guard
return true;
}
return false;
}
case ParameterType::DIMNAME:
return THPUtils_checkDimname(obj);
case ParameterType::DIMNAME_LIST: {
if (THPUtils_checkDimnameList(obj)) {
return true;
}
// if a size is specified (e.g. DimnameList[1]) we also allow passing a
// single Dimname
return size == 1 && THPUtils_checkDimname(obj);
}
case ParameterType::TENSOR_LIST: {
return is_tensor_list_and_append_overloaded(
obj, &overloaded_args, argnum, true /* throw_error */);
}
case ParameterType::INT_LIST:
return is_int_list(obj, size, failed_idx);
case ParameterType::FLOAT_LIST:
return is_float_or_complex_list(obj);
case ParameterType::GENERATOR:
return THPGenerator_Check(obj);
case ParameterType::BOOL:
return PyBool_Check(obj);
case ParameterType::STORAGE:
return isStorage(obj);
case ParameterType::PYOBJECT:
return true;
case ParameterType::SCALARTYPE:
return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
case ParameterType::LAYOUT:
return THPLayout_Check(obj);
case ParameterType::MEMORY_FORMAT:
return THPMemoryFormat_Check(obj);
case ParameterType::QSCHEME:
return THPQScheme_Check(obj);
case ParameterType::DEVICE:
return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
THPDevice_Check(obj);
case ParameterType::STREAM:
return THPStream_Check(obj);
case ParameterType::STRING:
return THPUtils_checkString(obj);
case ParameterType::SCALAR_LIST:
return is_scalar_list(obj);
case ParameterType::SYM_INT:
return is_int_or_symint(obj);
case ParameterType::SYM_INT_LIST:
return is_int_or_symint_list(obj, size, failed_idx);
default:
throw std::runtime_error("unknown parameter type");
}
}
// WARNING: these strings are parsed in invalid_arguments.cpp
std::string FunctionParameter::type_name() const {
switch (type_) {
case ParameterType::TENSOR:
return "Tensor";
case ParameterType::SCALAR:
return "Number";
case ParameterType::INT64:
// NB: SymInt is intentionally not mentioned here, as conventional user
// use will only know about ints
case ParameterType::SYM_INT:
return "int";
case ParameterType::DOUBLE:
return "float";
case ParameterType::COMPLEX:
return "complex";
case ParameterType::TENSOR_LIST:
return "tuple of Tensors";
case ParameterType::INT_LIST:
return "tuple of ints";
case ParameterType::FLOAT_LIST:
return "tuple of floats";
case ParameterType::GENERATOR:
return "torch.Generator";
case ParameterType::BOOL:
return "bool";
case ParameterType::STORAGE:
return "torch.Storage";
case ParameterType::PYOBJECT:
return "object";
case ParameterType::SCALARTYPE:
return "torch.dtype";
case ParameterType::LAYOUT:
return "torch.layout";
case ParameterType::MEMORY_FORMAT:
return "torch.memory_format";
case ParameterType::QSCHEME:
return "torch.qscheme";
case ParameterType::DEVICE:
return "torch.device";
case ParameterType::STRING:
return "str";
case ParameterType::DIMNAME:
return "name";
case ParameterType::DIMNAME_LIST:
return "tuple of names";
case ParameterType::SCALAR_LIST:
return "tuple of Scalars";
case ParameterType::SYM_INT_LIST:
return "tuple of ints";
default:
throw std::runtime_error("unknown parameter type");
}
}
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
if (s.empty())
return c10::nullopt;
char* str_end = nullptr;
long ans = strtol(s.c_str(), &str_end, 0);
// *str_end == 0 if the entire string was parsed as an integer.
return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
}
/*
Parse default value of IntArrayRef declared in native_functions.yaml
There are two kinds of default values:
1. IntArrayRef[2] x=1 (where size=2, value={1,1})
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}; note that there cannot be
a space after the comma, since native_parse.py uses ', ' to split args)
*/
static inline std::vector<int64_t> parse_intlist_args(
const std::string& s,
int64_t size) {
size_t n = s.size();
if (s.empty())
return std::vector<int64_t>();
// case 1. s is an int (e.g., s=2)
if (s[0] != '{') {
TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
return std::vector<int64_t>(size, std::stol(s));
}
// case 2. s is a list of dims (e.g., s={1,2})
// since we already checked the left brace '{' above, here we only check the
// right brace '}'
TORCH_CHECK(
s[n - 1] == '}',
"Default value of IntArrayRef is missing right brace '}', found ",
s[n - 1]);
auto args = std::vector<int64_t>();
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
std::string tok;
while (std::getline(ss, tok, ',')) {
args.emplace_back(std::stol(tok));
}
return args;
}
// Parse a string literal to remove quotes and escape sequences
static std::string parse_string_literal(c10::string_view str) {
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
if (str.front() == '"') {
TORCH_CHECK(
str.back() == '"', "Mismatched quotes in string default: ", str);
} else {
TORCH_CHECK(
str.front() == '\'' && str.back() == '\'',
"Invalid quotes in string default: ",
str)
}
std::string parsed;
parsed.reserve(str.size());
for (size_t i = 1; i < str.size() - 1;) {
if (str[i] != '\\') {
parsed.push_back(str[i]);
++i;
continue;
}
// Handle escape sequences
TORCH_CHECK(
i < str.size() - 2, "String ends with escaped final quote: ", str)
char c = str[i + 1];
switch (c) {
case '\\':
case '\'':
case '\"':
break;
case 'a':
c = '\a';
break;
case 'b':
c = '\b';
break;
case 'f':
c = '\f';
break;
case 'n':
c = '\n';
break;
case 'v':
c = '\v';
break;
case 't':
c = '\t';
break;
default:
TORCH_CHECK(
false,
"Unsupported escape sequence in string default: \\",
str[i + 1]);
}
parsed.push_back(c);
i += 2;
}
return parsed;
}
void FunctionParameter::set_default_str(const std::string& str) {
if (str == "None") {
allow_none = true;
}
if (type_ == ParameterType::TENSOR) {
if (str != "None") {
throw std::runtime_error(
"default value for Tensor must be none, got: " + str);
}
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
default_int = atol(str.c_str());
} else if (type_ == ParameterType::BOOL) {
default_bool = (str == "True" || str == "true");
} else if (type_ == ParameterType::DOUBLE) {
default_double = atof(str.c_str());
} else if (type_ == ParameterType::COMPLEX) {
default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
default_complex[1] = 0;
} else if (type_ == ParameterType::SCALAR) {
if (str != "None") {
// we sometimes rely on integer-vs-float values, e.g. with arange.
const auto as_integer = parse_as_integer(str);
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
: at::Scalar(atof(str.c_str()));
}
} else if (
type_ == ParameterType::INT_LIST ||
type_ == ParameterType::SYM_INT_LIST) {
if (str != "None") {
default_intlist = parse_intlist_args(str, size);
}
} else if (type_ == ParameterType::FLOAT_LIST) {
if (str != "None") {
throw std::runtime_error("Defaults not supported for float[]");
}
} else if (type_ == ParameterType::SCALARTYPE) {
if (str == "None") {
default_scalartype = at::ScalarType::Undefined;
} else if (str == "torch.int64") {
default_scalartype = at::ScalarType::Long;
} else {
throw std::runtime_error("invalid default value for ScalarType: " + str);
}
} else if (type_ == ParameterType::LAYOUT) {
if (str == "None") {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
} else if (str == "torch.strided") {
default_layout = at::Layout::Strided;
} else if (str == "torch.sparse_coo") {
default_layout = at::Layout::Sparse;
} else {
throw std::runtime_error("invalid default value for layout: " + str);
}
} else if (type_ == ParameterType::DEVICE) {
if (str != "None") {
throw std::runtime_error("invalid device: " + str);
}
} else if (type_ == ParameterType::STREAM) {
if (str != "None") {
throw std::runtime_error("invalid stream: " + str);
}
} else if (type_ == ParameterType::STRING) {
if (str != "None") {
default_string = parse_string_literal(str);
}
}
// These types weren't handled here before. Adding a default error
// led to a lot of test failures so adding this skip for now.
// We should correctly handle these though because it might be causing
// silent failures.
else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
// throw std::runtime_error("Invalid Tensor List");
} else if (type_ == ParameterType::GENERATOR) { // NOLINT
// throw std::runtime_error("ParameterType::GENERATOR");
} else if (type_ == ParameterType::PYOBJECT) { // NOLINT
// throw std::runtime_error("ParameterType::PYOBJECT");
} else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
// throw std::runtime_error("ParameterType::MEMORY_FORMAT");
} else if (type_ == ParameterType::DIMNAME) { // NOLINT
// throw std::runtime_error("ParameterType::DIMNAME");
} else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
// throw std::runtime_error("ParameterType::DIMNAME_LIST");
} else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
// throw std::runtime_error("ParameterType::SCALAR_LIST");
} else if (type_ == ParameterType::STORAGE) { // NOLINT
// throw std::runtime_error("ParameterType::STORAGE");
} else if (type_ == ParameterType::QSCHEME) { // NOLINT
// throw std::runtime_error("ParameterType::QSCHEME");
} else {
throw std::runtime_error("unknown parameter type");
}
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
: min_args(0),
max_args(0),
max_pos_args(0),
index(index),
hidden(false),
deprecated(false) {
auto open_paren = fmt.find('(');
if (open_paren == std::string::npos) {
throw std::runtime_error("missing opening parenthesis: " + fmt);
}
name = fmt.substr(0, open_paren);
bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
auto last_offset = open_paren + 1;
bool keyword_only = false;
bool done = false;
while (!done) {
auto offset = fmt.find(", ", last_offset);
auto next_offset = offset + 2;
if (offset == std::string::npos) {
offset = fmt.find(')', last_offset);
done = true;
next_offset = offset + 1;
// this 'if' happens for an empty parameter list, i.e. fn().
if (offset == last_offset) {
last_offset = next_offset;
break;
}
}
if (offset == std::string::npos) {
throw std::runtime_error("missing closing parenthesis: " + fmt);
}
if (offset == last_offset) {
throw std::runtime_error("malformed signature: " + fmt);
}
auto param_str = fmt.substr(last_offset, offset - last_offset);
last_offset = next_offset;
if (param_str == "*") {
keyword_only = true;
} else {
params.emplace_back(param_str, keyword_only);
params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
}
}
if (fmt.substr(last_offset) == "|deprecated") {
hidden = true;
// TODO: raise warning when parsing deprecated signatures
deprecated = true;
} else if (fmt.substr(last_offset) == "|hidden") {
hidden = true;
}
max_args = params.size();
// count the number of non-optional args
for (auto& param : params) {
if (!param.optional) {
min_args++;
}
if (!param.keyword_only) {
max_pos_args++;
}
}
}
std::string FunctionSignature::toString() const {
// TODO: consider printing more proper schema strings with defaults,
// optionals, etc.
std::ostringstream ss;
bool keyword_already = false;
ss << "(";
int i = 0;
for (auto& param : params) {
if (i != 0) {
ss << ", ";
}
if (param.keyword_only && !keyword_already) {
ss << "*, ";
keyword_already = true;
}
ss << param.type_name() << " " << param.name;
i++;
}
ss << ")";
return ss.str();
}
[[noreturn]] static void extra_args(
const FunctionSignature& signature,
Py_ssize_t nargs) {
const auto max_pos_args = signature.max_pos_args;
const auto min_args = signature.min_args;
const long nargs_ = nargs;
if (min_args != max_pos_args) {
throw TypeError(
"%s() takes from %zu to %zu positional arguments but %ld were given",
signature.name.c_str(),
min_args,
max_pos_args,
nargs_);
}
throw TypeError(
"%s() takes %zu positional argument%s but %ld %s given",
signature.name.c_str(),
max_pos_args,
max_pos_args == 1 ? "" : "s",
nargs_,
nargs == 1 ? "was" : "were");
}
[[noreturn]] static void missing_args(
const FunctionSignature& signature,
int idx) {
int num_missing = 0;
std::stringstream ss;
auto& params = signature.params;
for (auto it = params.begin() + idx; it != params.end(); ++it) {
if (!it->optional) {
if (num_missing > 0) {
ss << ", ";
}
ss << '"' << it->name << '"';
num_missing++;
}
}
throw TypeError(
"%s() missing %d required positional argument%s: %s",
signature.name.c_str(),
num_missing,
num_missing == 1 ? "s" : "",
ss.str().c_str());
}
static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
Py_ssize_t i = 0;
for (auto& param : signature.params) {
int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
if (cmp < 0) {
throw python_error();
} else if (cmp) {
return i;
}
i++;
}
return -1;
}
[[noreturn]] static void extra_kwargs(
FunctionSignature& signature,
PyObject* kwargs,
Py_ssize_t num_pos_args) {
PyObject* key = nullptr;
PyObject* value = nullptr;
Py_ssize_t pos = 0;
while (PyDict_Next(kwargs, &pos, &key, &value)) {
if (!THPUtils_checkString(key)) {
throw TypeError("keywords must be strings");
}
auto param_idx = find_param(signature, key);
if (param_idx < 0) {
throw TypeError(
"%s() got an unexpected keyword argument '%s'",
signature.name.c_str(),
THPUtils_unpackString(key).c_str());
}
if (param_idx < num_pos_args) {
throw TypeError(
"%s() got multiple values for argument '%s'",
signature.name.c_str(),
THPUtils_unpackString(key).c_str());
}
}
// this should never be hit
throw TypeError("invalid keyword arguments");
}
bool FunctionSignature::parse(
PyObject* self,
PyObject* args,
PyObject* kwargs,
PyObject* dst[], // NOLINT
std::vector<PyObject*>& overloaded_args,
bool raise_exception) {
Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
size_t arg_pos = 0;
bool allow_varargs_intlist = false;
// if there is a single positional IntArrayRef argument, i.e. expand(..),
// view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
// expand((5,3))
bool int_list_overload = false;
if (max_pos_args == 1 &&
(params[0].type_ == ParameterType::INT_LIST ||
params[0].type_ == ParameterType::SYM_INT_LIST)) {
allow_varargs_intlist = true;
if (params[0].type_ == ParameterType::INT_LIST) {
int_list_overload = true;
}
}
if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
if (raise_exception) {
// foo() takes 2 positional arguments but 3 were given
extra_args(*this, nargs);
}
return false;
}
int i = 0;
if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
append_overloaded_tensor(&overloaded_args, self);
}
for (auto& param : params) {
PyObject* obj = nullptr;
bool is_kwd = false;
if (arg_pos < static_cast<size_t>(nargs)) {
// extra positional args given after single positional IntArrayRef arg
if (param.keyword_only) {
if (raise_exception) {
extra_args(*this, nargs);
}
return false;
}
obj = PyTuple_GET_ITEM(args, arg_pos);
} else if (kwargs) {
obj = PyDict_GetItem(kwargs, param.python_name);
for (PyObject* numpy_name : param.numpy_python_names) {
if (obj) {
break;
}
obj = PyDict_GetItem(kwargs, numpy_name);
}
is_kwd = true;
}
int64_t failed_idx = -1;
bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
dst[i++] = nullptr;
} else if (!obj) {
if (raise_exception) {
// foo() missing 1 required positional argument: "b"
missing_args(*this, i);
}
return false;
} else if (param.check(obj, overloaded_args, i, &failed_idx)) {
dst[i++] = obj;
// XXX: the Variable check is necessary because sizes become tensors when
// tracer is enabled. This behavior easily leads to ambiguities, and we
// should avoid having complex signatures that make use of it...
} else if (
varargs_eligible &&
((int_list_overload
? is_int_list(args, param.size, &failed_idx)
: is_int_or_symint_list(args, param.size, &failed_idx)))) {
// take all positional arguments as this parameter
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
dst[i++] = args;
arg_pos = nargs;
continue;
} else if (raise_exception) {
if (is_kwd) {
// foo(): argument 'other' must be str, not int
throw TypeError(
"%s(): argument '%s' must be %s, not %s",
name.c_str(),
param.name.c_str(),
param.type_name().c_str(),
Py_TYPE(obj)->tp_name);
} else {
// foo(): argument 'other' (position 2) must be str, not int
if (failed_idx != -1) {
if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
TORCH_INTERNAL_ASSERT(varargs_eligible);
obj = args;
}
TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
throw TypeError(
"%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
name.c_str(),
param.name.c_str(),
static_cast<long>(arg_pos + 1),
param.type_name().c_str(),
Py_TYPE(py::reinterpret_steal<py::object>(
PySequence_GetItem(obj, failed_idx))
.ptr())
->tp_name,
static_cast<long>(failed_idx));
}
throw TypeError(
"%s(): argument '%s' (position %ld) must be %s, not %s",
name.c_str(),
param.name.c_str(),
static_cast<long>(arg_pos + 1),
param.type_name().c_str(),
Py_TYPE(obj)->tp_name);
}
} else {
return false;
}
if (!is_kwd) {
arg_pos++;
} else if (obj) {
remaining_kwargs--;
}
}
if (remaining_kwargs > 0) {
if (raise_exception) {
// foo() got an unexpected keyword argument "b"
extra_kwargs(*this, kwargs, nargs);
}
return false;
}
return true;
}
PythonArgParser::PythonArgParser(
const std::vector<std::string>& fmts,
bool traceable)
: max_args(0), traceable(traceable) {
int index = 0;
for (auto& fmt : fmts) {
signatures_.emplace_back(fmt, index);
++index;
}
for (auto& signature : signatures_) {
if (signature.max_args > max_args) {
max_args = signature.max_args;
}
}
if (!signatures_.empty()) {
function_name = signatures_[0].name;
}
// Check deprecated signatures last
std::stable_partition(
signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) {
return !sig.deprecated;
});
}
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
if (signature.deprecated) {
auto msg = c10::str(
"This overload of ",
signature.name,
" is deprecated:\n\t",
signature.name,
signature.toString());
auto signatures = get_signatures();
if (!signatures.empty()) {
msg += "\nConsider using one of the following signatures instead:";
for (const auto& sig : signatures) {
msg += "\n\t";
msg += signature.name;
msg += sig;
}
}
TORCH_WARN_ONCE(msg);
}
}
PythonArgs PythonArgParser::raw_parse(
PyObject* self,
PyObject* args,
PyObject* kwargs,
PyObject* parsed_args[]) { // NOLINT
if (signatures_.size() == 1) {
auto& signature = signatures_[0];
std::vector<PyObject*> overloaded_args;
signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
check_deprecated(signature);
return PythonArgs(
traceable, signature, parsed_args, std::move(overloaded_args));
}
for (auto& signature : signatures_) {
std::vector<PyObject*> overloaded_args;
if (signature.parse(
self, args, kwargs, parsed_args, overloaded_args, false)) {
check_deprecated(signature);
return PythonArgs(
traceable, signature, parsed_args, std::move(overloaded_args));
}
}
print_error(self, args, kwargs, parsed_args);
}
void PythonArgParser::print_error(
PyObject* self,
PyObject* args,
PyObject* kwargs,
PyObject* parsed_args[]) { // NOLINT
size_t num_args =
(args ? PyTuple_GET_SIZE(args) : 0) + (kwargs ? PyDict_Size(kwargs) : 0);
std::vector<unsigned> plausible_idxs;
unsigned i = 0;
for (auto& signature : signatures_) {
if (num_args >= signature.min_args && num_args <= signature.max_args &&
!signature.hidden) {
plausible_idxs.push_back(i);
}
i++;
}
if (plausible_idxs.size() == 1) {
auto& signature = signatures_[plausible_idxs[0]];
std::vector<PyObject*> overloaded_args;
signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
}
auto options = get_signatures();
auto msg =
torch::format_invalid_args(args, kwargs, function_name + "()", options);
throw TypeError("%s", msg.c_str());
}
std::vector<std::string> PythonArgParser::get_signatures() const {
std::vector<std::string> options;
for (auto& signature : signatures_) {
if (!signature.hidden) {
options.push_back(signature.toString());
}
}
return options;
}
at::Tensor PythonArgs::tensor_slow(int i) {
PyObject* obj = args[i];
if (!obj) {
return at::Tensor();
}
if (THPVariable_Check(obj)) {
return THPVariable_Unpack(obj);
}
bool save_symint = false;
at::Scalar scalar;
if (PyBool_Check(obj)) {
scalar = at::Scalar(THPUtils_unpackBool(obj));
} else if (THPUtils_checkLong(obj)) {
scalar = at::Scalar(THPUtils_unpackLong(obj));
} else if (PyComplex_Check(obj)) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
} else if (THPUtils_checkDouble(obj)) {
scalar = at::Scalar(THPUtils_unpackDouble(obj));
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
// because although Scalar supports SymInt/SymFloat, the subsequent
// conversion to Tensor does not. Instead, do it out of band.
} else if (torch::is_symint(py::handle(obj))) {
save_symint = true;
// This scalar value doesn't matter; it shouldn't ever actually
// get read out. Make it a big, weird-looking number to help
// people figure out if there's a problem.
scalar = at::Scalar(7777777);
} else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else if (torch::is_symbool(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(true);
} else {
// NB: Are you here because you passed None to a Variable method,
// and you expected an undefined tensor to be returned? Don't add
// a test for Py_None here; instead, you need to mark the argument
// as *allowing none*; you can do this by writing 'Tensor?' instead
// of 'Tensor' in the ATen metadata.
throw TypeError(
"expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name);
}
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
at::Tensor tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
if (save_symint) {
auto py_tensor = py::cast(tensor);
if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) {
throw python_error();
}
}
return tensor;
}
at::Scalar PythonArgs::scalar_slow(int i) {
if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
auto& var = THPVariable_Unpack(args[i]);
jit::tracer::ArgumentStash::stashValue(
signature.params[i].name, idx, var, c10::NumberType::get());
}
return scalar_slow(args[i]);
}
at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
// Zero-dim tensors are converted to Scalars as-is. Note this doesn't
// currently handle most NumPy scalar types except np.float64.
if (THPVariable_Check(arg)) {
return THPVariable_Unpack(arg).item();
}
if (THPUtils_checkLong(arg)) {
return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
}
if (PyBool_Check(arg)) {
return at::Scalar(THPUtils_unpackBool(arg));
}
if (PyComplex_Check(arg)) {
return at::Scalar(THPUtils_unpackComplexDouble(arg));
}
if (torch::is_symint(arg)) {
return at::Scalar(py::cast<c10::SymInt>(arg));
}
if (torch::is_symfloat(arg)) {
return at::Scalar(py::cast<c10::SymFloat>(arg));
}
if (torch::is_symbool(arg)) {
// Windows build fails with C2440: '<function-style-cast>'
// when at::Scalar(py::cast<c10::SymBool>(arg))
auto sym_bool = py::handle(arg).cast<c10::SymBool>();
return at::Scalar(sym_bool);
}
return at::Scalar(THPUtils_unpackDouble(arg));
}
} // namespace torch