Remove some Type.tensor usages and remove native_tensor without size. (#12355)
Summary: This is to move us along the path to removing Type from the public API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12355
Reviewed By: ezyang
Differential Revision: D10212616
Pulled By: gchanan
fbshipit-source-id: c9cd128d1111ab219cb0b2f3bf5b632502ab97c0
This commit is contained in:
parent
9ebac3d7fe
commit
705d80b51e
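For context, the recurring change in the hunks below is to stop constructing tensors through a Type object (e.g. CPU(kFloat).tensor({n}) or type.tensor()) and to call the at::empty factory with a TensorOptions carrying the dtype and device instead. The following is a minimal sketch of that idiom, not code from this commit; it assumes an ATen build of roughly this era where at::empty, at::TensorOptions, at::kFloat and at::kCPU behave as they are used in the diff, and the helper name make_empty_float_cpu is invented purely for illustration.

#include <ATen/ATen.h>

// Sketch of the replacement idiom applied by this commit (hypothetical helper).
at::Tensor make_empty_float_cpu(at::IntList sizes) {
  // Old style (being removed): route construction through a Type object, e.g.
  //   at::Tensor t = at::CPU(at::kFloat).tensor(sizes);
  // New style: describe dtype/device via TensorOptions and call at::empty.
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCPU);
  return at::empty(sizes, options);
}

int main() {
  at::Tensor t = make_empty_float_cpu({2, 3});  // uninitialized 2x3 float tensor
  t.zero_();                                    // factories like zeros() are empty() + in-place fill
  return 0;
}

The factory functions touched below (zeros, rand, full, eye, ...) all follow the same shape: allocate with at::empty(size, options), then fill or resize in place.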
@@ -4,11 +4,11 @@
 namespace at {

 namespace {
-  Backend sparseTensorIdToDenseBackend(TensorTypeId type_id) {
+  DeviceType sparseTensorIdToDeviceType(TensorTypeId type_id) {
     if (type_id == SparseCPUTensorId()) {
-      return Backend::CPU;
+      return kCPU;
     } else if (type_id == SparseCUDATensorId()) {
-      return Backend::CUDA;
+      return kCUDA;
     } else {
       AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id);
     }
@@ -33,8 +33,8 @@ SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeM
     , size_{0}
     , sparseDims_(1)
     , denseDims_(0)
-    , indices_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), ScalarType::Long)->tensor({1, 0}))
-    , values_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), dataTypeToScalarType(data_type.id()))->tensor()) {}
+    , indices_(at::empty({1, 0}, TensorOptions(false).device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
+    , values_(at::empty({0}, TensorOptions(false).device(sparseTensorIdToDeviceType(type_id)).dtype(dataTypeToScalarType(data_type.id())))) {}

 IntList SparseTensorImpl::sizes() const {
   return size_;
@@ -136,7 +136,7 @@ Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta

 Tensor tensor(const Type& dtype) {
   if (_type_has_native(dtype)) {
-    return at::getType(dtype.options()).native_tensor();
+    return at::getType(dtype.options()).native_tensor({0});
   } else {
     return at::getType(dtype.options()).th_tensor();
   }
@@ -147,11 +147,11 @@ Tensor empty_like(const Tensor& self) {

 Tensor empty_like(const Tensor& self, const TensorOptions& options) {
   if (options.layout() == kSparse && self.type().is_sparse()) {
-    auto res = native::empty({0}, options); // to be resized
+    auto res = at::empty({0}, options); // to be resized
     res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims());
     return res;
   }
-  return native::empty(self.sizes(), options);
+  return at::empty(self.sizes(), options);
 }

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -161,7 +161,7 @@ Tensor eye(int64_t n, const TensorOptions& options) {
 }

 Tensor eye(int64_t n, int64_t m, const TensorOptions& options) {
-  auto tensor = native::empty({0}, options); // to be resized
+  auto tensor = at::empty({0}, options); // to be resized
   return at::eye_out(tensor, n, m);
 }

@@ -196,7 +196,7 @@ Tensor full(IntList size, Scalar fill_value, const TensorOptions& options) {
   if (options.layout() == kSparse) {
     AT_ERROR("full(...) is not implemented for sparse layout");
   }
-  auto result = native::empty(size, options);
+  auto result = at::empty(size, options);
   return result.fill_(fill_value);
 }

@@ -287,7 +287,7 @@ Tensor rand(IntList size, const TensorOptions& options) {
 }

 Tensor rand(IntList size, Generator* generator, const TensorOptions& options) {
-  auto result = native::empty(size, options);
+  auto result = at::empty(size, options);
   return result.uniform_(0, 1, generator);
 }

@@ -336,7 +336,7 @@ Tensor randint(
     IntList size,
     Generator* generator,
     const TensorOptions& options) {
-  auto result = native::empty(size, options);
+  auto result = at::empty(size, options);
   return result.random_(low, high, generator);
 }

@@ -397,7 +397,7 @@ Tensor randn(IntList size, const TensorOptions& options) {
 }

 Tensor randn(IntList size, Generator* generator, const TensorOptions& options) {
-  auto result = native::empty(size, options);
+  auto result = at::empty(size, options);
   return result.normal_(0, 1, generator);
 }

@@ -454,7 +454,7 @@ Tensor randperm(int64_t n, const TensorOptions& options) {
 }

 Tensor randperm(int64_t n, Generator* generator, const TensorOptions& options) {
-  auto tensor = native::empty(n, options);
+  auto tensor = at::empty(n, options);
   return at::randperm_out(tensor, n, generator);
 }

@@ -499,7 +499,7 @@ Tensor& range_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Tensor zeros(IntList size, const TensorOptions& options) {
-  auto result = native::empty(size, options);
+  auto result = at::empty(size, options);
   return result.zero_();
 }

@@ -519,7 +519,7 @@ Tensor zeros_like(const Tensor& self) {

 Tensor zeros_like(const Tensor& self, const TensorOptions& options) {
   if (options.layout() == kSparse && self.type().is_sparse()) {
-    auto res = native::empty({0}, options); // to be resized
+    auto res = at::empty({0}, options); // to be resized
     res.sparse_resize_and_clear_(self.sizes(), self._sparseDims(), self._denseDims());
     return res;
   }
@@ -538,7 +538,7 @@ Tensor bartlett_window(
     const TensorOptions& options) {
   window_function_checks("bartlett_window", options, window_length);
   if (window_length == 0) {
-    return native::empty({0}, options);
+    return at::empty({0}, options);
   }
   if (window_length == 1) {
     return native::ones({1}, options);
@@ -606,7 +606,7 @@ Tensor hamming_window(
     const TensorOptions& options) {
   window_function_checks("hamming_window", options, window_length);
   if (window_length == 0) {
-    return native::empty({0}, options);
+    return at::empty({0}, options);
   }
   if (window_length == 1) {
     return native::ones({1}, options);
@@ -45,7 +45,7 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
   result.resize_({n});

   if (result.type().scalarType() == at::ScalarType::Half) {
-    auto result_float = CUDA(kFloat).tensor({n});
+    auto result_float = at::empty({n}, TensorOptions(false).device(Device(DeviceType::CUDA)));
     result.copy_(randperm_out_cuda(result_float, n, generator));
   } else {
     if (n < 30000) {  // For small inputs, we offload it to CPU instead.
@@ -1895,13 +1895,6 @@
 - func: addmm_(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: method

-- func: native_tensor(Type self_ty) -> Tensor
-  variants: []
-  dispatch:
-    SparseCPU: new_sparse
-    SparseCUDA: new_sparse
-
 - func: native_tensor(Type self_ty, IntList size) -> Tensor
   variants: []
   dispatch:
@@ -99,7 +99,7 @@ SparseTensor new_with_tensor_sparse(const LongTensor& indices, const Tensor& val
     computed_indices_sizes.add_(1); // len = max_index + 1
     LongTensor cpu_computed_indices_sizes;
     if (computed_indices_sizes.is_cuda()) {
-      cpu_computed_indices_sizes = at::CPU(kLong).tensor(computed_indices_sizes.sizes());
+      cpu_computed_indices_sizes = at::empty(computed_indices_sizes.sizes(), TensorOptions(false).dtype(kLong));
       cpu_computed_indices_sizes.copy_(computed_indices_sizes);
     } else {
       cpu_computed_indices_sizes = computed_indices_sizes;
@@ -623,7 +623,7 @@ SparseTensor& hspmm_out_sparse_cpu(SparseTensor& r, const SparseTensor& sparse_,
     return r;
   }

-  LongTensor indices = at::CPU(kLong).tensor({1, nnz});
+  LongTensor indices = at::empty({1, nnz}, TensorOptions(false).dtype(kLong));

   // Initialize the sparse matrix that will be used with spaddmm to send rows
   // from the dense matrix to rows of the output's value tensor
@@ -715,7 +715,7 @@ SparseTensor& _sspaddmm_out_cpu(

   int64_t t_nnz = t._nnz();
   int64_t r_nnz = nnz * dim_k + t_nnz;
-  LongTensor newi = native::empty({2, r_nnz}, kLong);
+  LongTensor newi = at::empty({2, r_nnz}, kLong);
   LongTensor newv = native::zeros({r_nnz}, values.options());

   if (t_nnz != 0) {
@@ -62,7 +62,7 @@ inline SparseTensor _new_with_dims_and_tensor_sparse(
     ArrayRef<int64_t> sizes,
     const LongTensor& indices,
     const Tensor& values) {
-  SparseTensor self = new_sparse(dtype);
+  SparseTensor self = at::empty({0}, dtype.options());
   _get_sparse_impl(self)->resize_(sparseDims, denseDims, sizes);
   _alias_into_sparse(self, indices, values);
   return self;
@@ -30,8 +30,8 @@ Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device>
   }
   AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
   Tensor r;
-  if (is_sparse()) r = this->native_tensor();
-  else r = this->tensor(src.sizes());
+  if (is_sparse()) r = this->native_tensor({0});
+  else r = at::empty(src.sizes(), this->options());
   r.copy_(src, non_blocking);
   return r;
 }
@@ -118,7 +118,7 @@ Storage TypeDefault::unsafeStorageFromTH(void * th_pointer, bool retain) const {


 Tensor TypeDefault::scalarTensor(Scalar s) const {
-  return tensor({}).fill_(s);
+  return at::empty({}, this->options()).fill_(s);
 }

 ${type_method_definitions}
@@ -40,7 +40,7 @@ void test(Type& type, IntList shape, int64_t a = 0, int64_t b = 1) {
   auto a1 = at::empty({0}, type.options());
   auto a2 = at::empty({0}, type.options());
   auto a3 = at::empty({0}, type.options());
-  auto a4 = CPU(kDouble).tensor();
+  auto a4 = at::empty({0}, at::TensorOptions(false).dtype(kDouble));

   std::vector<Tensor> tensors({a0, a1, a2, a3, a4});
   for (size_t i = 0; i < tensors.size(); i++) {
@@ -94,7 +94,7 @@ TEST(atest, atest) {
   if (at::hasCUDA()) {
     int isgone = 0;
     {
-      auto base = CUDA(kFloat).tensor({1, 2, 3});
+      auto base = at::empty({1,2,3}, TensorOptions(false).device(kCUDA));
       auto f2 = CUDA(kFloat).tensorFromBlob(
           base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
     }
@@ -64,7 +64,7 @@ at::ScalarType DecoderBase::onnxTypeToATenType(onnx::TensorProto_DataType onnx_t
 }

 at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) {
-  at::Tensor tensor = at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor();
+  at::Tensor tensor = at::empty({0}, at::TensorOptions(false).dtype(onnxTypeToATenType(tensor_proto.data_type())));
   std::vector<int64_t> sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() };
   tensor.resize_(sizes);

@@ -521,7 +521,7 @@ struct ADTestSpec {
   std::vector<Variable> make_vars() const {
     std::vector<Variable> out;
     for (const auto & m : input_meta) {
-      out.emplace_back(autograd::make_variable(at::CPU(at::kFloat).tensor(m).normal_(), /*requires_grad=*/true));
+      out.emplace_back(autograd::make_variable(at::empty(m, at::TensorOptions(false)).normal_(), /*requires_grad=*/true));
     }
     return out;
   }
@@ -231,7 +231,7 @@ Tensor internal_new_from_data(const Type & type, at::optional<Device> device_opt

   auto sizes = compute_sizes(data);
   ScalarType scalarType = type_inference ? infer_scalar_type(data) : type.scalarType();
-  auto tensor = autograd::make_variable(CPU(scalarType).tensor(sizes), /*requires_grad=*/false);
+  auto tensor = autograd::make_variable(at::empty(sizes, at::TensorOptions(false).dtype(scalarType)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
       scalarType, tensor.type().elementSizeInBytes(), data);