mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Port new_full to ATen. (#25583)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25583 Following the game plan from https://github.com/pytorch/pytorch/pull/25475 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D17183438 Pulled By: ezyang fbshipit-source-id: 67bd98206f349ddf5ffdd7be0c16e45418c1b1cd
This commit is contained in:
parent
3d9c419648
commit
2e1a5cb80e
|
|
@ -446,6 +446,7 @@ class CAFFE2_API Tensor {
|
|||
Tensor & div_(Scalar other) const;
|
||||
Tensor dot(const Tensor & tensor) const;
|
||||
Tensor new_empty(IntArrayRef size, const TensorOptions & options={}) const;
|
||||
Tensor new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options={}) const;
|
||||
Tensor & resize_(IntArrayRef size) const;
|
||||
Tensor erf() const;
|
||||
Tensor & erf_() const;
|
||||
|
|
|
|||
|
|
@ -760,6 +760,14 @@ inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options)
|
|||
return table->getOp<Tensor (const Tensor &, IntArrayRef, const TensorOptions &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast<Tensor&>(*this), size, options);
|
||||
#endif
|
||||
}
|
||||
inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const {
|
||||
#ifdef USE_STATIC_DISPATCH
|
||||
return TypeDefault::new_full(const_cast<Tensor&>(*this), size, fill_value, options);
|
||||
#else
|
||||
static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor");
|
||||
return table->getOp<Tensor (const Tensor &, IntArrayRef, Scalar, const TensorOptions &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast<Tensor&>(*this), size, fill_value, options);
|
||||
#endif
|
||||
}
|
||||
inline Tensor & Tensor::resize_(IntArrayRef size) const {
|
||||
#ifdef USE_STATIC_DISPATCH
|
||||
switch(tensorTypeIdToBackend(type_id())) {
|
||||
|
|
|
|||
|
|
@ -307,6 +307,16 @@ Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& opt
|
|||
return native::full(self.sizes(), fill_value, options);
|
||||
}
|
||||
|
||||
Tensor new_full(
|
||||
const Tensor& self,
|
||||
IntArrayRef size,
|
||||
Scalar fill_value,
|
||||
const TensorOptions& options
|
||||
) {
|
||||
return at::full(size, fill_value, self.options().merge_in(options));
|
||||
}
|
||||
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Tensor linspace(
|
||||
|
|
|
|||
|
|
@ -813,6 +813,9 @@
|
|||
- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
variants: method
|
||||
|
||||
- func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
variants: method
|
||||
|
||||
- func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
|
||||
dispatch:
|
||||
QuantizedCPU: empty_affine_quantized_cpu
|
||||
|
|
|
|||
|
|
@ -585,15 +585,6 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_new_full(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
|
||||
OptionalDeviceGuard device_guard(device_of(self_));
|
||||
return THPVariable_Wrap(torch::utils::new_full(self_.type_id(), self_.scalar_type(), args, kwargs));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -782,7 +773,6 @@ PyMethodDef variable_methods[] = {
|
|||
{"ndimension", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
|
||||
{"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL},
|
||||
{"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"new_full", (PyCFunction)THPVariable_new_full, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"new_ones", (PyCFunction)THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"new_tensor", (PyCFunction)THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"new_zeros", (PyCFunction)THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
|
|
|
|||
|
|
@ -491,8 +491,6 @@ def gen_pyi(declarations_path, out):
|
|||
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
|
||||
'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
|
||||
'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
|
||||
format(type_to_python('IntArrayRef'), type_to_python('Scalar'), FACTORY_PARAMS)],
|
||||
'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
|
||||
# clamp has no default values in the Declarations
|
||||
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
|
||||
|
|
|
|||
|
|
@ -638,21 +638,6 @@ Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObjec
|
|||
throw std::runtime_error("new_tensor(): invalid arguments");
|
||||
}
|
||||
|
||||
Tensor new_full(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
|
||||
static PythonArgParser parser({
|
||||
"new_full(IntArrayRef size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
|
||||
}, /*traceable=*/true);
|
||||
|
||||
ParsedArgs<5> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
if (r.idx == 0) {
|
||||
const auto actual_type_id = typeIdWithDefault(r, 3, type_id);
|
||||
const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type);
|
||||
return dispatch_full(actual_type_id, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
|
||||
}
|
||||
throw std::runtime_error("new_full(): invalid arguments");
|
||||
}
|
||||
|
||||
Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
|
||||
static PythonArgParser parser({
|
||||
"new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ at::Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scal
|
|||
at::Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_full(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor new_zeros(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user