mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enabled Scalar lists (#48222)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48222 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D25074765 Pulled By: izdeby fbshipit-source-id: 96ebe3c9907178c9338c03fb7993b2ecb26db8f4
This commit is contained in:
parent
bfce69d620
commit
5716b7db72
|
|
@ -119,14 +119,6 @@ namespace impl {
|
|||
"You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
||||
};
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes>
|
||||
struct assert_is_valid_input_type<std::vector<T>, AllowDeprecatedTypes>
|
||||
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
||||
static_assert(!std::is_same<T, at::Scalar>::value,
|
||||
"You tried to register a kernel with an unsupported input type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
|
||||
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector<T>. Please use List<T> instead.");
|
||||
};
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes>
|
||||
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
|
||||
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
|
||||
|
|
|
|||
|
|
@ -620,6 +620,8 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
|
|||
return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef'
|
||||
elif str(t.elem) == 'Tensor':
|
||||
return f'TensorList[{size}]' if size is not None else 'TensorList'
|
||||
elif str(t.elem) == 'Scalar':
|
||||
return f'ScalarList[{size}]' if size is not None else 'ScalarList'
|
||||
elif str(t.elem) == 'Tensor?':
|
||||
if simple_type:
|
||||
return 'TensorList'
|
||||
|
|
@ -1063,7 +1065,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
|
|||
return 'intlist'
|
||||
elif str(t) == 'float[]':
|
||||
return 'doublelist'
|
||||
|
||||
elif str(t) == 'Scalar[]':
|
||||
return 'scalarlist'
|
||||
raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser')
|
||||
|
||||
# Return RHS expression for python argument using PythonArgParser output.
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ TYPE_MAP = {
|
|||
'std::string': 'str',
|
||||
'std::string?': 'str?',
|
||||
'Scalar': 'Scalar',
|
||||
'ScalarList': 'Scalar[]',
|
||||
'MemoryFormat': 'MemoryFormat',
|
||||
'MemoryFormat?': 'MemoryFormat?',
|
||||
'QScheme': 'QScheme',
|
||||
|
|
@ -131,6 +132,7 @@ FROM_IVALUE = {
|
|||
'Tensor?': 'toOptionalTensor({})',
|
||||
'Tensor?[]': 'toListOfOptionalTensor({})',
|
||||
'TensorList': '{}.toTensorVector()',
|
||||
'ScalarList': '{}.toScalarVector()',
|
||||
'bool': '{}.toBool()',
|
||||
'bool?': '{}.toOptional<bool>()',
|
||||
'double': '{}.toDouble()',
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ blocklist = [
|
|||
'floor_divide', 'floor_divide_', 'floor_divide_out',
|
||||
]
|
||||
|
||||
|
||||
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
|
||||
'matmul', 'floordiv',
|
||||
'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
|||
{"std::string", ParameterType::STRING},
|
||||
{"Dimname", ParameterType::DIMNAME},
|
||||
{"DimnameList", ParameterType::DIMNAME_LIST},
|
||||
{"ScalarList", ParameterType::SCALAR_LIST},
|
||||
};
|
||||
|
||||
// Default arg name translations for compatibility with NumPy.
|
||||
|
|
@ -348,13 +349,28 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
|
|||
return false;
|
||||
}
|
||||
|
||||
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error) {
|
||||
bool is_scalar_list(PyObject* obj) {
|
||||
auto tuple = six::isTuple(obj);
|
||||
if (!(tuple || PyList_Check(obj))) {
|
||||
return false;
|
||||
}
|
||||
auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
|
||||
for (size_t idx = 0; idx < size; idx++) {
|
||||
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
|
||||
if (!THPUtils_checkScalar(iobj)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args, int argnum, bool throw_error) {
|
||||
auto tuple = six::isTuple(obj);
|
||||
if (!(tuple || PyList_Check(obj))) {
|
||||
return false;
|
||||
}
|
||||
auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
|
||||
for (long idx = 0; idx < size; idx++) {
|
||||
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
|
||||
if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
|
||||
if (throw_error) {
|
||||
|
|
@ -453,6 +469,9 @@ auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded
|
|||
return THPStream_Check(obj);
|
||||
case ParameterType::STRING: return THPUtils_checkString(obj);
|
||||
default: throw std::runtime_error("unknown parameter type");
|
||||
case ParameterType::SCALAR_LIST: {
|
||||
return is_scalar_list(obj);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -478,6 +497,7 @@ std::string FunctionParameter::type_name() const {
|
|||
case ParameterType::STRING: return "str";
|
||||
case ParameterType::DIMNAME: return "name";
|
||||
case ParameterType::DIMNAME_LIST: return "tuple of names";
|
||||
case ParameterType::SCALAR_LIST: return "tuple of Scalars";
|
||||
default: throw std::runtime_error("unknown parameter type");
|
||||
}
|
||||
}
|
||||
|
|
@ -1055,24 +1075,28 @@ at::Scalar PythonArgs::scalar_slow(int i) {
|
|||
signature.params[i].name, idx, var, jit::NumberType::get());
|
||||
}
|
||||
|
||||
return scalar_slow(args[i]);
|
||||
}
|
||||
|
||||
at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
|
||||
// Zero-dim tensors are converted to Scalars as-is. Note this doesn't currently
|
||||
// handle most NumPy scalar types except np.float64.
|
||||
if (THPVariable_Check(args[i])) {
|
||||
return ((THPVariable*)args[i])->cdata.item();
|
||||
if (THPVariable_Check(arg)) {
|
||||
return ((THPVariable*)arg)->cdata.item();
|
||||
}
|
||||
|
||||
if (THPUtils_checkLong(args[i])) {
|
||||
return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(args[i])));
|
||||
if (THPUtils_checkLong(arg)) {
|
||||
return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
|
||||
}
|
||||
|
||||
if (PyBool_Check(args[i])) {
|
||||
return at::Scalar(THPUtils_unpackBool(args[i]));
|
||||
if (PyBool_Check(arg)) {
|
||||
return at::Scalar(THPUtils_unpackBool(arg));
|
||||
}
|
||||
|
||||
if (PyComplex_Check(args[i])) {
|
||||
return at::Scalar(THPUtils_unpackComplexDouble(args[i]));
|
||||
if (PyComplex_Check(arg)) {
|
||||
return at::Scalar(THPUtils_unpackComplexDouble(arg));
|
||||
}
|
||||
return at::Scalar(THPUtils_unpackDouble(args[i]));
|
||||
return at::Scalar(THPUtils_unpackDouble(arg));
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ namespace torch {
|
|||
enum class ParameterType {
|
||||
TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
|
||||
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
|
||||
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
|
||||
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST
|
||||
};
|
||||
|
||||
struct FunctionParameter;
|
||||
|
|
@ -158,6 +158,7 @@ struct PythonArgs {
|
|||
inline c10::optional<at::Tensor> optionalTensor(int i);
|
||||
inline at::Scalar scalar(int i);
|
||||
inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
||||
inline std::vector<at::Scalar> scalarlist(int i);
|
||||
inline std::vector<at::Tensor> tensorlist(int i);
|
||||
template<int N>
|
||||
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
||||
|
|
@ -206,6 +207,7 @@ struct PythonArgs {
|
|||
private:
|
||||
at::Tensor tensor_slow(int i);
|
||||
at::Scalar scalar_slow(int i);
|
||||
at::Scalar scalar_slow(PyObject* arg);
|
||||
};
|
||||
|
||||
struct FunctionParameter {
|
||||
|
|
@ -287,6 +289,19 @@ inline at::Scalar PythonArgs::scalar(int i) {
|
|||
return scalar_slow(i);
|
||||
}
|
||||
|
||||
inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
|
||||
if (!args[i]) return std::vector<at::Scalar>();
|
||||
auto tuple = six::isTuple(args[i]);
|
||||
THPObjectPtr arg = six::maybeAsTuple(args[i]);
|
||||
auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
|
||||
std::vector<at::Scalar> res(size);
|
||||
for (int idx = 0; idx < size; idx++) {
|
||||
PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
|
||||
res[idx] = scalar_slow(obj);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) {
|
||||
if (!args[i]) return default_scalar;
|
||||
return scalar_slow(i);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user