Remove some Type.tensor usages and remove native_tensor without size. (#12355)

Summary:
This is to move us along the path to removing Type from the public API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12355

Reviewed By: ezyang

Differential Revision: D10212616

Pulled By: gchanan

fbshipit-source-id: c9cd128d1111ab219cb0b2f3bf5b632502ab97c0
This commit is contained in:
Gregory Chanan 2018-10-05 11:05:43 -07:00 committed by Facebook Github Bot
parent 9ebac3d7fe
commit 705d80b51e
14 changed files with 31 additions and 38 deletions

View File

@ -4,11 +4,11 @@
namespace at { namespace at {
namespace { namespace {
Backend sparseTensorIdToDenseBackend(TensorTypeId type_id) { DeviceType sparseTensorIdToDeviceType(TensorTypeId type_id) {
if (type_id == SparseCPUTensorId()) { if (type_id == SparseCPUTensorId()) {
return Backend::CPU; return kCPU;
} else if (type_id == SparseCUDATensorId()) { } else if (type_id == SparseCUDATensorId()) {
return Backend::CUDA; return kCUDA;
} else { } else {
AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id); AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id);
} }
@ -33,8 +33,8 @@ SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeM
, size_{0} , size_{0}
, sparseDims_(1) , sparseDims_(1)
, denseDims_(0) , denseDims_(0)
, indices_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), ScalarType::Long)->tensor({1, 0})) , indices_(at::empty({1, 0}, TensorOptions(false).device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
, values_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), dataTypeToScalarType(data_type.id()))->tensor()) {} , values_(at::empty({0}, TensorOptions(false).device(sparseTensorIdToDeviceType(type_id)).dtype(dataTypeToScalarType(data_type.id())))) {}
IntList SparseTensorImpl::sizes() const { IntList SparseTensorImpl::sizes() const {
return size_; return size_;

View File

@ -136,7 +136,7 @@ Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta
Tensor tensor(const Type& dtype) { Tensor tensor(const Type& dtype) {
if (_type_has_native(dtype)) { if (_type_has_native(dtype)) {
return at::getType(dtype.options()).native_tensor(); return at::getType(dtype.options()).native_tensor({0});
} else { } else {
return at::getType(dtype.options()).th_tensor(); return at::getType(dtype.options()).th_tensor();
} }

View File

@ -147,11 +147,11 @@ Tensor empty_like(const Tensor& self) {
Tensor empty_like(const Tensor& self, const TensorOptions& options) { Tensor empty_like(const Tensor& self, const TensorOptions& options) {
if (options.layout() == kSparse && self.type().is_sparse()) { if (options.layout() == kSparse && self.type().is_sparse()) {
auto res = native::empty({0}, options); // to be resized auto res = at::empty({0}, options); // to be resized
res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims()); res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims());
return res; return res;
} }
return native::empty(self.sizes(), options); return at::empty(self.sizes(), options);
} }
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -161,7 +161,7 @@ Tensor eye(int64_t n, const TensorOptions& options) {
} }
Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { Tensor eye(int64_t n, int64_t m, const TensorOptions& options) {
auto tensor = native::empty({0}, options); // to be resized auto tensor = at::empty({0}, options); // to be resized
return at::eye_out(tensor, n, m); return at::eye_out(tensor, n, m);
} }
@ -196,7 +196,7 @@ Tensor full(IntList size, Scalar fill_value, const TensorOptions& options) {
if (options.layout() == kSparse) { if (options.layout() == kSparse) {
AT_ERROR("full(...) is not implemented for sparse layout"); AT_ERROR("full(...) is not implemented for sparse layout");
} }
auto result = native::empty(size, options); auto result = at::empty(size, options);
return result.fill_(fill_value); return result.fill_(fill_value);
} }
@ -287,7 +287,7 @@ Tensor rand(IntList size, const TensorOptions& options) {
} }
Tensor rand(IntList size, Generator* generator, const TensorOptions& options) { Tensor rand(IntList size, Generator* generator, const TensorOptions& options) {
auto result = native::empty(size, options); auto result = at::empty(size, options);
return result.uniform_(0, 1, generator); return result.uniform_(0, 1, generator);
} }
@ -336,7 +336,7 @@ Tensor randint(
IntList size, IntList size,
Generator* generator, Generator* generator,
const TensorOptions& options) { const TensorOptions& options) {
auto result = native::empty(size, options); auto result = at::empty(size, options);
return result.random_(low, high, generator); return result.random_(low, high, generator);
} }
@ -397,7 +397,7 @@ Tensor randn(IntList size, const TensorOptions& options) {
} }
Tensor randn(IntList size, Generator* generator, const TensorOptions& options) { Tensor randn(IntList size, Generator* generator, const TensorOptions& options) {
auto result = native::empty(size, options); auto result = at::empty(size, options);
return result.normal_(0, 1, generator); return result.normal_(0, 1, generator);
} }
@ -454,7 +454,7 @@ Tensor randperm(int64_t n, const TensorOptions& options) {
} }
Tensor randperm(int64_t n, Generator* generator, const TensorOptions& options) { Tensor randperm(int64_t n, Generator* generator, const TensorOptions& options) {
auto tensor = native::empty(n, options); auto tensor = at::empty(n, options);
return at::randperm_out(tensor, n, generator); return at::randperm_out(tensor, n, generator);
} }
@ -499,7 +499,7 @@ Tensor& range_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor zeros(IntList size, const TensorOptions& options) { Tensor zeros(IntList size, const TensorOptions& options) {
auto result = native::empty(size, options); auto result = at::empty(size, options);
return result.zero_(); return result.zero_();
} }
@ -519,7 +519,7 @@ Tensor zeros_like(const Tensor& self) {
Tensor zeros_like(const Tensor& self, const TensorOptions& options) { Tensor zeros_like(const Tensor& self, const TensorOptions& options) {
if (options.layout() == kSparse && self.type().is_sparse()) { if (options.layout() == kSparse && self.type().is_sparse()) {
auto res = native::empty({0}, options); // to be resized auto res = at::empty({0}, options); // to be resized
res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims()); res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims());
return res; return res;
} }
@ -538,7 +538,7 @@ Tensor bartlett_window(
const TensorOptions& options) { const TensorOptions& options) {
window_function_checks("bartlett_window", options, window_length); window_function_checks("bartlett_window", options, window_length);
if (window_length == 0) { if (window_length == 0) {
return native::empty({0}, options); return at::empty({0}, options);
} }
if (window_length == 1) { if (window_length == 1) {
return native::ones({1}, options); return native::ones({1}, options);
@ -606,7 +606,7 @@ Tensor hamming_window(
const TensorOptions& options) { const TensorOptions& options) {
window_function_checks("hamming_window", options, window_length); window_function_checks("hamming_window", options, window_length);
if (window_length == 0) { if (window_length == 0) {
return native::empty({0}, options); return at::empty({0}, options);
} }
if (window_length == 1) { if (window_length == 1) {
return native::ones({1}, options); return native::ones({1}, options);

View File

@ -45,7 +45,7 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
result.resize_({n}); result.resize_({n});
if (result.type().scalarType() == at::ScalarType::Half) { if (result.type().scalarType() == at::ScalarType::Half) {
auto result_float = CUDA(kFloat).tensor({n}); auto result_float = at::empty({n}, TensorOptions(false).device(Device(DeviceType::CUDA)));
result.copy_(randperm_out_cuda(result_float, n, generator)); result.copy_(randperm_out_cuda(result_float, n, generator));
} else { } else {
if (n < 30000) { // For small inputs, we offload it to CPU instead. if (n < 30000) { // For small inputs, we offload it to CPU instead.

View File

@ -1895,13 +1895,6 @@
- func: addmm_(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - func: addmm_(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: method variants: method
- func: native_tensor(Type self_ty) -> Tensor
variants: []
dispatch:
SparseCPU: new_sparse
SparseCUDA: new_sparse
- func: native_tensor(Type self_ty, IntList size) -> Tensor - func: native_tensor(Type self_ty, IntList size) -> Tensor
variants: [] variants: []
dispatch: dispatch:

View File

@ -99,7 +99,7 @@ SparseTensor new_with_tensor_sparse(const LongTensor& indices, const Tensor& val
computed_indices_sizes.add_(1); // len = max_index + 1 computed_indices_sizes.add_(1); // len = max_index + 1
LongTensor cpu_computed_indices_sizes; LongTensor cpu_computed_indices_sizes;
if (computed_indices_sizes.is_cuda()) { if (computed_indices_sizes.is_cuda()) {
cpu_computed_indices_sizes = at::CPU(kLong).tensor(computed_indices_sizes.sizes()); cpu_computed_indices_sizes = at::empty(computed_indices_sizes.sizes(), TensorOptions(false).dtype(kLong));
cpu_computed_indices_sizes.copy_(computed_indices_sizes); cpu_computed_indices_sizes.copy_(computed_indices_sizes);
} else { } else {
cpu_computed_indices_sizes = computed_indices_sizes; cpu_computed_indices_sizes = computed_indices_sizes;

View File

@ -623,7 +623,7 @@ SparseTensor& hspmm_out_sparse_cpu(SparseTensor& r, const SparseTensor& sparse_,
return r; return r;
} }
LongTensor indices = at::CPU(kLong).tensor({1, nnz}); LongTensor indices = at::empty({1, nnz}, TensorOptions(false).dtype(kLong));
// Initialize the sparse matrix that will be used with spaddmm to send rows // Initialize the sparse matrix that will be used with spaddmm to send rows
// from the dense matrix to rows of the output's value tensor // from the dense matrix to rows of the output's value tensor
@ -715,7 +715,7 @@ SparseTensor& _sspaddmm_out_cpu(
int64_t t_nnz = t._nnz(); int64_t t_nnz = t._nnz();
int64_t r_nnz = nnz * dim_k + t_nnz; int64_t r_nnz = nnz * dim_k + t_nnz;
LongTensor newi = native::empty({2, r_nnz}, kLong); LongTensor newi = at::empty({2, r_nnz}, kLong);
LongTensor newv = native::zeros({r_nnz}, values.options()); LongTensor newv = native::zeros({r_nnz}, values.options());
if (t_nnz != 0) { if (t_nnz != 0) {

View File

@ -62,7 +62,7 @@ inline SparseTensor _new_with_dims_and_tensor_sparse(
ArrayRef<int64_t> sizes, ArrayRef<int64_t> sizes,
const LongTensor& indices, const LongTensor& indices,
const Tensor& values) { const Tensor& values) {
SparseTensor self = new_sparse(dtype); SparseTensor self = at::empty({0}, dtype.options());
_get_sparse_impl(self)->resize_(sparseDims, denseDims, sizes); _get_sparse_impl(self)->resize_(sparseDims, denseDims, sizes);
_alias_into_sparse(self, indices, values); _alias_into_sparse(self, indices, values);
return self; return self;

View File

@ -30,8 +30,8 @@ Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device>
} }
AT_CHECK(src.defined(), "attempt to copy an undefined tensor"); AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
Tensor r; Tensor r;
if (is_sparse()) r = this->native_tensor(); if (is_sparse()) r = this->native_tensor({0});
else r = this->tensor(src.sizes()); else r = at::empty(src.sizes(), this->options());
r.copy_(src, non_blocking); r.copy_(src, non_blocking);
return r; return r;
} }
@ -118,7 +118,7 @@ Storage TypeDefault::unsafeStorageFromTH(void * th_pointer, bool retain) const {
Tensor TypeDefault::scalarTensor(Scalar s) const { Tensor TypeDefault::scalarTensor(Scalar s) const {
return tensor({}).fill_(s); return at::empty({}, this->options()).fill_(s);
} }
${type_method_definitions} ${type_method_definitions}

View File

@ -40,7 +40,7 @@ void test(Type& type, IntList shape, int64_t a = 0, int64_t b = 1) {
auto a1 = at::empty({0}, type.options()); auto a1 = at::empty({0}, type.options());
auto a2 = at::empty({0}, type.options()); auto a2 = at::empty({0}, type.options());
auto a3 = at::empty({0}, type.options()); auto a3 = at::empty({0}, type.options());
auto a4 = CPU(kDouble).tensor(); auto a4 = at::empty({0}, at::TensorOptions(false).dtype(kDouble));
std::vector<Tensor> tensors({a0, a1, a2, a3, a4}); std::vector<Tensor> tensors({a0, a1, a2, a3, a4});
for (size_t i = 0; i < tensors.size(); i++) { for (size_t i = 0; i < tensors.size(); i++) {

View File

@ -94,7 +94,7 @@ TEST(atest, atest) {
if (at::hasCUDA()) { if (at::hasCUDA()) {
int isgone = 0; int isgone = 0;
{ {
auto base = CUDA(kFloat).tensor({1, 2, 3}); auto base = at::empty({1,2,3}, TensorOptions(false).device(kCUDA));
auto f2 = CUDA(kFloat).tensorFromBlob( auto f2 = CUDA(kFloat).tensorFromBlob(
base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; }); base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
} }

View File

@ -64,7 +64,7 @@ at::ScalarType DecoderBase::onnxTypeToATenType(onnx::TensorProto_DataType onnx_t
} }
at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) { at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) {
at::Tensor tensor = at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor(); at::Tensor tensor = at::empty({0}, at::TensorOptions(false).dtype(onnxTypeToATenType(tensor_proto.data_type())));
std::vector<int64_t> sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() }; std::vector<int64_t> sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() };
tensor.resize_(sizes); tensor.resize_(sizes);

View File

@ -521,7 +521,7 @@ struct ADTestSpec {
std::vector<Variable> make_vars() const { std::vector<Variable> make_vars() const {
std::vector<Variable> out; std::vector<Variable> out;
for (const auto & m : input_meta) { for (const auto & m : input_meta) {
out.emplace_back(autograd::make_variable(at::CPU(at::kFloat).tensor(m).normal_(), /*requires_grad=*/true)); out.emplace_back(autograd::make_variable(at::empty(m, at::TensorOptions(false)).normal_(), /*requires_grad=*/true));
} }
return out; return out;
} }

View File

@ -231,7 +231,7 @@ Tensor internal_new_from_data(const Type & type, at::optional<Device> device_opt
auto sizes = compute_sizes(data); auto sizes = compute_sizes(data);
ScalarType scalarType = type_inference ? infer_scalar_type(data) : type.scalarType(); ScalarType scalarType = type_inference ? infer_scalar_type(data) : type.scalarType();
auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), /*requires_grad=*/false); auto tensor = autograd::make_variable(at::empty(sizes, at::TensorOptions(false).dtype(scalarType)), /*requires_grad=*/false);
recursive_store( recursive_store(
(char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0, (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
scalarType, tensor.type().elementSizeInBytes(), data); scalarType, tensor.type().elementSizeInBytes(), data);