Enabled Scalar lists (#48222)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48222

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25074765

Pulled By: izdeby

fbshipit-source-id: 96ebe3c9907178c9338c03fb7993b2ecb26db8f4
Author: Iurii Zdebskyi (2020-12-11 15:46:00 -08:00), committed by Facebook GitHub Bot
commit 5716b7db72
parent bfce69d620
6 changed files with 56 additions and 21 deletions
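
The six diffs below thread the new type through three layers: the kernel-registration type checks in ATen, the Python/JIT binding codegen, and the PythonArgParser runtime. As a rough, illustrative sketch of the end state, a hand-written binding could accept a Scalar[] argument like this (my_op and its schema are invented for illustration; PythonArgParser, ParsedArgs, and tensorlist already exist, and scalarlist is the accessor this commit adds):

#include <torch/csrc/utils/python_arg_parser.h>

// Illustrative only, not part of this commit: a binding for a fictional
// op "my_op(Tensor[] tensors, Scalar[] scalars)".
static PyObject* my_op_binding(PyObject* /*self*/, PyObject* args, PyObject* kwargs) {
  static torch::PythonArgParser parser({
      "my_op(TensorList tensors, ScalarList scalars)",
  });
  torch::ParsedArgs<2> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  std::vector<at::Tensor> tensors = _r.tensorlist(0);
  std::vector<at::Scalar> scalars = _r.scalarlist(1);  // new in this commit
  // ... dispatch into ATen with tensors/scalars ...
  Py_RETURN_NONE;
}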


@@ -119,14 +119,6 @@ namespace impl {
"You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
};
-template<class T, bool AllowDeprecatedTypes>
-struct assert_is_valid_input_type<std::vector<T>, AllowDeprecatedTypes>
-  : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
-  static_assert(!std::is_same<T, at::Scalar>::value,
-    "You tried to register a kernel with an unsupported input type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
-  // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector<T>. Please use List<T> instead.");
-};
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
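
For context, assert_is_valid_input_type works by recursive template specialization: container specializations inherit the check for their element type and static_assert on disallowed elements, and the hunk above deletes the std::vector specialization whose assert rejected std::vector<Scalar>. A minimal self-contained sketch of the pattern (simplified stand-in names, not the actual ATen definitions, where the primary template is stricter):

#include <type_traits>
#include <vector>

struct Scalar {};  // stand-in for at::Scalar

// Primary template: in this sketch, any unconstrained T is accepted.
template <class T>
struct assert_is_valid_input_type {};

// Container specialization: recurse into the element type, then reject
// Scalar elements at compile time. Deleting a specialization like this
// one (as the hunk above does) lifts the restriction.
template <class T>
struct assert_is_valid_input_type<std::vector<T>>
    : assert_is_valid_input_type<T> {
  static_assert(!std::is_same<T, Scalar>::value,
                "std::vector<Scalar> is not a supported kernel input");
};

// Registration triggers the check simply by instantiating it:
// assert_is_valid_input_type<std::vector<Scalar>> fails to compile here.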


@@ -620,6 +620,8 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef'
elif str(t.elem) == 'Tensor':
return f'TensorList[{size}]' if size is not None else 'TensorList'
+elif str(t.elem) == 'Scalar':
+return f'ScalarList[{size}]' if size is not None else 'ScalarList'
elif str(t.elem) == 'Tensor?':
if simple_type:
return 'TensorList'
@@ -1063,7 +1065,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
return 'intlist'
elif str(t) == 'float[]':
return 'doublelist'
+elif str(t) == 'Scalar[]':
+return 'scalarlist'
raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser')
# Return RHS expression for python argument using PythonArgParser output.
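
These two codegen changes are two halves of one mapping: argument_type_str renders a Scalar[] schema argument as ScalarList in the generated PythonArgParser signature string, and arg_parser_unpack_method selects the accessor the generated binding uses to read it back. The net effect in generated code is roughly (fragment, names illustrative):

// "my_op(ScalarList scalars)"        // signature string: Scalar[] -> ScalarList
// auto scalars = _r.scalarlist(0);   // unpack call:      Scalar[] -> scalarlist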


@@ -49,6 +49,7 @@ TYPE_MAP = {
'std::string': 'str',
'std::string?': 'str?',
'Scalar': 'Scalar',
+'ScalarList': 'Scalar[]',
'MemoryFormat': 'MemoryFormat',
'MemoryFormat?': 'MemoryFormat?',
'QScheme': 'QScheme',
@@ -131,6 +132,7 @@ FROM_IVALUE = {
'Tensor?': 'toOptionalTensor({})',
'Tensor?[]': 'toListOfOptionalTensor({})',
'TensorList': '{}.toTensorVector()',
+'ScalarList': '{}.toScalarVector()',
'bool': '{}.toBool()',
'bool?': '{}.toOptional<bool>()',
'double': '{}.toDouble()',
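
On the unboxing side, the TYPE_MAP entry maps the ScalarList C++ name back to the Scalar[] schema type, while FROM_IVALUE tells the codegen how to pull the argument off the interpreter stack. Assuming IValue has the toScalarVector() accessor that the new mapping relies on, a generated wrapper would contain a line shaped like the following (fragment; the peek(stack, ...) offsets are illustrative):

// std::vector<at::Scalar> scalars = (std::move(peek(stack, 1, 2))).toScalarVector();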


@@ -122,7 +122,6 @@ blocklist = [
'floor_divide', 'floor_divide_', 'floor_divide_out',
]
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
'matmul', 'floordiv',
'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic


@@ -39,6 +39,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"std::string", ParameterType::STRING},
{"Dimname", ParameterType::DIMNAME},
{"DimnameList", ParameterType::DIMNAME_LIST},
{"ScalarList", ParameterType::SCALAR_LIST},
};
// Default arg name translations for compatibility with NumPy.
@@ -348,13 +349,28 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error) {
return false;
}
+bool is_scalar_list(PyObject* obj) {
+  auto tuple = six::isTuple(obj);
+  if (!(tuple || PyList_Check(obj))) {
+    return false;
+  }
+  auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
+  for (size_t idx = 0; idx < size; idx++) {
+    PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
+    if (!THPUtils_checkScalar(iobj)) {
+      return false;
+    }
+  }
+  return true;
+}
+
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
}
auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
for (long idx = 0; idx < size; idx++) {
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
if (throw_error) {
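
A small nit in is_scalar_list as committed: idx is a size_t while PyTuple_GET_SIZE/PyList_GET_SIZE return the signed Py_ssize_t, so the loop condition mixes signedness (the neighboring is_tensor_list_and_append_overloaded uses long). A standalone sketch of the same check with Py_ssize_t throughout, using PyNumber_Check as a rough stand-in for torch's THPUtils_checkScalar:

#include <Python.h>

// Returns true if obj is a tuple/list whose elements all look like scalars.
// Sketch only: PyNumber_Check is an approximation of THPUtils_checkScalar.
static bool is_scalar_list_sketch(PyObject* obj) {
  const bool tuple = PyTuple_Check(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  const Py_ssize_t size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (Py_ssize_t idx = 0; idx < size; idx++) {
    PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!PyNumber_Check(iobj)) {
      return false;
    }
  }
  return true;
}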
@@ -453,6 +469,9 @@ auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded
return THPStream_Check(obj);
case ParameterType::STRING: return THPUtils_checkString(obj);
default: throw std::runtime_error("unknown parameter type");
+case ParameterType::SCALAR_LIST: {
+  return is_scalar_list(obj);
+}
}
}
@@ -478,6 +497,7 @@ std::string FunctionParameter::type_name() const {
case ParameterType::STRING: return "str";
case ParameterType::DIMNAME: return "name";
case ParameterType::DIMNAME_LIST: return "tuple of names";
+case ParameterType::SCALAR_LIST: return "tuple of Scalars";
default: throw std::runtime_error("unknown parameter type");
}
}
@@ -1055,24 +1075,28 @@ at::Scalar PythonArgs::scalar_slow(int i) {
        signature.params[i].name, idx, var, jit::NumberType::get());
  }
+  return scalar_slow(args[i]);
+}
+
+at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently
  // handle most NumPy scalar types except np.float64.
-  if (THPVariable_Check(args[i])) {
-    return ((THPVariable*)args[i])->cdata.item();
+  if (THPVariable_Check(arg)) {
+    return ((THPVariable*)arg)->cdata.item();
  }
-  if (THPUtils_checkLong(args[i])) {
-    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(args[i])));
+  if (THPUtils_checkLong(arg)) {
+    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
  }
-  if (PyBool_Check(args[i])) {
-    return at::Scalar(THPUtils_unpackBool(args[i]));
+  if (PyBool_Check(arg)) {
+    return at::Scalar(THPUtils_unpackBool(arg));
  }
-  if (PyComplex_Check(args[i])) {
-    return at::Scalar(THPUtils_unpackComplexDouble(args[i]));
+  if (PyComplex_Check(arg)) {
+    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }
-  return at::Scalar(THPUtils_unpackDouble(args[i]));
+  return at::Scalar(THPUtils_unpackDouble(arg));
}
} // namespace torch
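
The refactor above is what makes element-wise conversion reusable: scalar_slow(int i) now delegates to a PyObject*-based overload that scalarlist (added in the header below) can call once per element. The ordering of the checks is worth a note: THPUtils_checkLong deliberately rejects bool, even though PyBool is a subtype of PyLong in CPython, which is why the long branch can safely come before the bool branch. A plain-CPython sketch of the same ladder has to test bool first:

#include <Python.h>

// Stand-in for at::Scalar's tag, just to make the dispatch order visible.
enum class ScalarKind { Bool, Int, Complex, Double };

static ScalarKind classify_scalar(PyObject* arg) {
  // Unlike torch's THPUtils_checkLong, PyLong_Check also matches bools,
  // so bool must be tested first in this sketch.
  if (PyBool_Check(arg))    return ScalarKind::Bool;
  if (PyLong_Check(arg))    return ScalarKind::Int;
  if (PyComplex_Check(arg)) return ScalarKind::Complex;
  return ScalarKind::Double;  // like scalar_slow, fall back to float conversion
}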


@@ -80,7 +80,7 @@ namespace torch {
enum class ParameterType {
TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
-DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
+DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST
};
struct FunctionParameter;
@@ -158,6 +158,7 @@ struct PythonArgs {
inline c10::optional<at::Tensor> optionalTensor(int i);
inline at::Scalar scalar(int i);
inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
+inline std::vector<at::Scalar> scalarlist(int i);
inline std::vector<at::Tensor> tensorlist(int i);
template<int N>
inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -206,6 +207,7 @@ struct PythonArgs {
private:
at::Tensor tensor_slow(int i);
at::Scalar scalar_slow(int i);
+at::Scalar scalar_slow(PyObject* arg);
};
struct FunctionParameter {
@@ -287,6 +289,19 @@ inline at::Scalar PythonArgs::scalar(int i) {
return scalar_slow(i);
}
+inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
+  if (!args[i]) return std::vector<at::Scalar>();
+  auto tuple = six::isTuple(args[i]);
+  THPObjectPtr arg = six::maybeAsTuple(args[i]);
+  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
+  std::vector<at::Scalar> res(size);
+  for (int idx = 0; idx < size; idx++) {
+    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
+    res[idx] = scalar_slow(obj);
+  }
+  return res;
+}
+
inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) {
if (!args[i]) return default_scalar;
return scalar_slow(i);
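
scalarlist mirrors the existing tensorlist accessor, and two details are worth noting when reading it. First, by the time it runs, FunctionParameter::check has already vetted args[i] with is_scalar_list, so every element is safe to hand to scalar_slow. Second, six::maybeAsTuple only converts structseq objects; a plain list passes through unchanged, which is why the tuple/list branch is still needed after the conversion. From Python, a call through this path could look like the following (illustrative; the intended consumers are ScalarList overloads of ops such as the _foreach_* family):

// >>> torch._foreach_add_(tensors, [1, 2.5, True, 3 + 4j])  # Python, illustrative
// Each list element becomes one at::Scalar via scalar_slow(PyObject*).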