Port new_full to ATen. (#25583)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25583

Following the game plan from https://github.com/pytorch/pytorch/pull/25475

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17183438

Pulled By: ezyang

fbshipit-source-id: 67bd98206f349ddf5ffdd7be0c16e45418c1b1cd
This commit is contained in:
Edward Yang 2019-09-04 14:29:18 -07:00 committed by Facebook Github Bot
parent 3d9c419648
commit 2e1a5cb80e
8 changed files with 22 additions and 28 deletions

View File

@ -446,6 +446,7 @@ class CAFFE2_API Tensor {
Tensor & div_(Scalar other) const;
Tensor dot(const Tensor & tensor) const;
Tensor new_empty(IntArrayRef size, const TensorOptions & options={}) const;
Tensor new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options={}) const;
Tensor & resize_(IntArrayRef size) const;
Tensor erf() const;
Tensor & erf_() const;

View File

@ -760,6 +760,14 @@ inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options)
return table->getOp<Tensor (const Tensor &, IntArrayRef, const TensorOptions &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast<Tensor&>(*this), size, options);
#endif
}
inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const {
#ifdef USE_STATIC_DISPATCH
return TypeDefault::new_full(const_cast<Tensor&>(*this), size, fill_value, options);
#else
static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor");
return table->getOp<Tensor (const Tensor &, IntArrayRef, Scalar, const TensorOptions &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast<Tensor&>(*this), size, fill_value, options);
#endif
}
inline Tensor & Tensor::resize_(IntArrayRef size) const {
#ifdef USE_STATIC_DISPATCH
switch(tensorTypeIdToBackend(type_id())) {

View File

@ -307,6 +307,16 @@ Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& opt
return native::full(self.sizes(), fill_value, options);
}
Tensor new_full(
const Tensor& self,
IntArrayRef size,
Scalar fill_value,
const TensorOptions& options
) {
return at::full(size, fill_value, self.options().merge_in(options));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor linspace(

View File

@ -813,6 +813,9 @@
- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
- func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
- func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
dispatch:
QuantizedCPU: empty_affine_quantized_cpu

View File

@ -585,15 +585,6 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_new_full(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
OptionalDeviceGuard device_guard(device_of(self_));
return THPVariable_Wrap(torch::utils::new_full(self_.type_id(), self_.scalar_type(), args, kwargs));
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -782,7 +773,6 @@ PyMethodDef variable_methods[] = {
{"ndimension", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
{"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL},
{"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_full", (PyCFunction)THPVariable_new_full, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_ones", (PyCFunction)THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_tensor", (PyCFunction)THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL},
{"new_zeros", (PyCFunction)THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL},

View File

@ -491,8 +491,6 @@ def gen_pyi(declarations_path, out):
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
format(type_to_python('IntArrayRef'), type_to_python('Scalar'), FACTORY_PARAMS)],
'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"

View File

@ -638,21 +638,6 @@ Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObjec
throw std::runtime_error("new_tensor(): invalid arguments");
}
Tensor new_full(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
"new_full(IntArrayRef size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
}, /*traceable=*/true);
ParsedArgs<5> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto actual_type_id = typeIdWithDefault(r, 3, type_id);
const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type);
return dispatch_full(actual_type_id, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
}
throw std::runtime_error("new_full(): invalid arguments");
}
Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
"new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",

View File

@ -17,7 +17,6 @@ at::Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scal
at::Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor new_full(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor new_zeros(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);