diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 699c1347368..b111e304e1d 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -99,7 +99,7 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons TORCH_CHECK(values.device().type() == device().type(), "device type of values (", values.device().type(), ") must match device type of device().type()", device().type(), ")"); TORCH_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (", values.scalar_type(), ") must match dtype of sparse tensor (", typeMetaToScalarType(dtype()), ")"); TORCH_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor"); - TORCH_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")"); + TORCH_CHECK(indices.options().backend() == values.options().backend(), "backend of indices (", indices.options().backend(), ") must match backend of values (", values.options().backend(), ")"); TORCH_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")"); TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes()); diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 8e4e7a2d2ac..45ccd0bbcde 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -134,7 +134,7 @@ void checkAllSameGPU(CheckedFrom c, ArrayRef tensors) { void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { TORCH_CHECK( - t1->type() == t2->type(), + t1->options().type_equal(t2->options()), "Expected tensor for ", t1, " to have the same type as tensor for ", t2, "; but type ", t1->toString(), " does not equal ", t2->toString(), " (while checking arguments for ", c, ")"); @@ -196,9 +196,9 @@ void checkAllDefined(CheckedFrom c, ArrayRef ts) { void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) { TORCH_CHECK( - !t.defined() || t.type().backend() == backend, + !t.defined() || t.options().backend() == backend, "Expected tensor to have ", toString(backend), - " Backend, but got tensor with ", toString(t.type().backend()), " Backend ", + " Backend, but got tensor with ", toString(t.options().backend()), " Backend ", "(while checking arguments for ", c, ")"); } @@ -210,9 +210,9 @@ void checkBackend(CheckedFrom c, at::ArrayRef tensors, at::Backend backe void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) { TORCH_CHECK( - !t.defined() || t.type().device_type() == device_type, + !t.defined() || t.device().type() == device_type, "Expected tensor to have ", device_type, - " DeviceType, but got tensor with ", t.type().device_type(), " DeviceType ", + " DeviceType, but got tensor with ", t.device().type(), " DeviceType ", "(while checking arguments for ", c, ")"); } diff --git a/aten/src/ATen/core/Tensor.cpp b/aten/src/ATen/core/Tensor.cpp index 217eed0462f..d7c9d592829 100644 --- a/aten/src/ATen/core/Tensor.cpp +++ b/aten/src/ATen/core/Tensor.cpp @@ -28,14 +28,20 @@ void Tensor::enforce_invariants() { void Tensor::print() const { if (defined()) { - std::cerr << "[" << type().toString() << " " << sizes() << "]" << std::endl; + std::cerr << "[" << toString() << " " << sizes() << "]" << std::endl; } else { std::cerr << "[UndefinedTensor]" << std::endl; } } std::string Tensor::toString() const { - return type().toString(); + std::string base_str; + if (scalar_type() == ScalarType::Undefined) { + base_str = "UndefinedType"; + } else { + base_str = std::string(at::toString(options().backend())) + at::toString(scalar_type()) + "Type"; + } + return base_str; } Tensor Tensor::variable_data() const { diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index d2253ed2532..087b2ab7f70 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -179,7 +179,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool { return false; } return (input.is_mkldnn()) || // input is mkldnn Tensor - (input.type().backend() == at::Backend::CPU && + (input.options().backend() == at::Backend::CPU && input.scalar_type() == kFloat && // only on CPU Float Tensors !is_dilated() && // doesn't support dilation !transposed && // or transposed tensors @@ -190,7 +190,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool { auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool { #if AT_NNPACK_ENABLED() return at::_nnpack_available() && - input.type().backend() == at::Backend::CPU && + input.options().backend() == at::Backend::CPU && input.scalar_type() == kFloat && // only on CPU Float Tensors !is_dilated() && // or dilation !transposed && // or transposed tensors @@ -594,11 +594,11 @@ at::Tensor _convolution( output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation); } } else if (params.use_cudnn(input)) { - TORCH_CHECK(input.type() == weight.type(), - "Input type (", input.type().toString(), ") and weight type (", weight.type().toString(), + TORCH_CHECK(input.options().type_equal(weight.options()), + "Input type (", input.toString(), ") and weight type (", weight.toString(), ") should be the same"); - TORCH_CHECK(!bias.defined() || (input.type() == bias.type()), - "Input type (", input.type().toString(), ") and bias type (", bias.type().toString(), + TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())), + "Input type (", input.toString(), ") and bias type (", bias.toString(), ") should be the same"); if (params.transposed) { @@ -611,11 +611,11 @@ at::Tensor _convolution( params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic); } } else if (params.use_miopen(input)) { - TORCH_CHECK(input.type() == weight.type(), - "Input type (", input.type().toString(), ") and weight type (", weight.type().toString(), + TORCH_CHECK(input.options().type_equal(weight.options()), + "Input type (", input.toString(), ") and weight type (", weight.toString(), ") should be the same"); - TORCH_CHECK(!bias.defined() || (input.type() == bias.type()), - "Input type (", input.type().toString(), ") and bias type (", bias.type().toString(), + TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())), + "Input type (", input.toString(), ") and bias type (", bias.toString(), ") should be the same"); if (params.transposed) { @@ -629,11 +629,11 @@ at::Tensor _convolution( } } else if (params.use_mkldnn(input)) { #if AT_MKLDNN_ENABLED() - TORCH_CHECK(input.type() == weight.type(), - "Input type (", input.type().toString(), ") and weight type (", weight.type().toString(), + TORCH_CHECK(input.options().type_equal(weight.options()), + "Input type (", input.toString(), ") and weight type (", weight.toString(), ") should be the same"); - TORCH_CHECK(!bias.defined() || (input.type() == bias.type()), - "Input type (", input.type().toString(), ") and bias type (", bias.type().toString(), + TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())), + "Input type (", input.toString(), ") and bias type (", bias.toString(), ") should be the same"); if (!input_is_mkldnn) { output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias, diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 163d8285222..d1cfe3ca2a2 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -95,7 +95,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return at::copy_sparse_to_sparse_(self, src, non_blocking); } else if (self.is_sparse() || src.is_sparse()) { AT_ERROR("copy_() between dense and sparse Tensors is not implemented! Found self type = ", - self.type(), " and src type = ", src.type()); + self.toString(), " and src type = ", src.toString()); } if (self.is_same(src)) { diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp index 698686eb5ae..99fc2c297c6 100644 --- a/aten/src/ATen/native/Cross.cpp +++ b/aten/src/ATen/native/Cross.cpp @@ -15,11 +15,11 @@ Tensor cross(const Tensor & input, const Tensor & other, const c10::optional dimension) { - auto device_res = input.type().device_type(); + auto device_res = input.device().type(); TORCH_CHECK(device_res == kCPU || device_res == kCUDA, "cross only supports CPU and CUDA devices, out got: ", device_res); - auto device1 = input.type().device_type(); + auto device1 = input.device().type(); TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cross only supports CPU and CUDA devices, input got: ", device1); - auto device2 = other.type().device_type(); + auto device2 = other.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cross only supports CPU and CUDA devices, other got: ", device2); TORCH_CHECK(device_res == device1, "out and input must have the same device type. out: ", device_res, " input: ", device1); TORCH_CHECK(device1 == device2, "input and other must have the same device type. input: ", device1, " other: ", device2); diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index 52ef4d8bf7f..902947e236d 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -39,10 +39,10 @@ Tensor euclidean_dist_out(const Tensor& x1, const Tensor& x2) { static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10::optional compute_mode) { TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); - auto device1 = x1.type().device_type(); + auto device1 = x1.device().type(); TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1); TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); - auto device2 = x2.type().device_type(); + auto device2 = x2.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2); TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2); @@ -123,9 +123,9 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous"); int64_t n = x1.size(-2); int64_t m = x1.size(-1); - auto device1 = x1.type().device_type(); + auto device1 = x1.device().type(); TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1); - auto device2 = x2.type().device_type(); + auto device2 = x2.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2); IntArrayRef batch_tensor1(x1.sizes().data(), std::max(x1.dim() - 2, 0)); int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies()); @@ -136,7 +136,7 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c Tensor _pdist_forward(const Tensor& self, const double p) { TORCH_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input"); - auto device = self.type().device_type(); + auto device = self.device().type(); TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device); Tensor result = at::empty({0}, self.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); if (self.size(0) <= 1) { @@ -157,7 +157,7 @@ Tensor _pdist_forward(const Tensor& self, const double p) { Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, const Tensor& pdist) { TORCH_CHECK(self.is_contiguous(), "_pdist_backward requires self to be contiguous"); TORCH_CHECK(pdist.is_contiguous(), "_pdist_backward requires pdist to be contiguous"); - auto device = self.type().device_type(); + auto device = self.device().type(); TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device); Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); pdist_backward_stub(device, result, grad, self, p, pdist); diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 644888a6cc0..d091e997c11 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -321,7 +321,7 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo } else { result.resize_({n_sample}); } - multinomial_stub(result.type().device_type(), result, self, n_sample, with_replacement, gen); + multinomial_stub(result.device().type(), result, self, n_sample, with_replacement, gen); return result; } diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp index a0f7ecd4850..dca29bb5643 100644 --- a/aten/src/ATen/native/Indexing.cpp +++ b/aten/src/ATen/native/Indexing.cpp @@ -165,7 +165,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) // For CUDA tensors, force all index tensors to have the same striding to // simplify the CUDA kernel. - if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) { + if (indices.size() >= 2 && this->src.device().type() == kCUDA) { if (!all_strides_match(indices)) { for (size_t i = 0; i < indices.size(); i++) { indices[i] = indices[i].contiguous(); @@ -251,8 +251,8 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu if (indices.size() > (size_t)self.dim()) { AT_INDEX_ERROR("too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); } - if (accumulate && self.type().device_type() == kCUDA) { - index_put_accum_stub(self.type().device_type(), self, indices, value, unsafe); + if (accumulate && self.device().type() == kCUDA) { + index_put_accum_stub(self.device().type(), self, indices, value, unsafe); return self; } auto info = make_info(self, indices); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index d4e768508b7..f6dbddeb512 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -88,7 +88,7 @@ std::tuple slogdet(const Tensor& self) { Tensor pinverse(const Tensor& self, double rcond) { TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() >= 2, - "pinverse(", self.type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions " + "pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions " "of floating types"); if (self.numel() == 0) { // Match NumPy @@ -118,7 +118,7 @@ static inline Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) { Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) { TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2, - "matrix_rank(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " + "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor " "of floating types"); Tensor S = _matrix_rank_helper(self, symmetric); @@ -127,7 +127,7 @@ Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) { Tensor matrix_rank(const Tensor& self, bool symmetric) { TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2, - "matrix_rank(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " + "matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor " "of floating types"); Tensor S = _matrix_rank_helper(self, symmetric); @@ -479,7 +479,7 @@ Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor Tensor matrix_power(const Tensor& a, int64_t n) { TORCH_CHECK(a.dim() >= 2 && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())), - "matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor " + "matrix_power(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor " "of floating types with dim at least 2"); if (n == 0) { return a.clone(at::MemoryFormat::Contiguous).copy_(at::eye(a.size(-2), a.options()).expand_as(a)); diff --git a/aten/src/ATen/native/Memory.cpp b/aten/src/ATen/native/Memory.cpp index d0b24fff1ec..a69c7c62a98 100644 --- a/aten/src/ATen/native/Memory.cpp +++ b/aten/src/ATen/native/Memory.cpp @@ -14,8 +14,8 @@ bool is_pinned(const Tensor& self) { } Tensor pin_memory(const Tensor& self) { - if (self.type().backend() != Backend::CPU) { - AT_ERROR("cannot pin '", self.type().toString(), "' only dense CPU tensors can be pinned"); + if (self.options().backend() != Backend::CPU) { + AT_ERROR("cannot pin '", self.toString(), "' only dense CPU tensors can be pinned"); } if (self.is_pinned()) { return self; diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp index 191524a1987..ae4bd65afcb 100644 --- a/aten/src/ATen/native/PackedSequence.cpp +++ b/aten/src/ATen/native/PackedSequence.cpp @@ -4,7 +4,7 @@ namespace at { namespace native { void checkLongTensor(const Tensor& tensor) { - TORCH_CHECK(tensor.dim() == 1 && tensor.type().device_type() == at::kCPU && tensor.scalar_type() == at::kLong, + TORCH_CHECK(tensor.dim() == 1 && tensor.device().type() == at::kCPU && tensor.scalar_type() == at::kLong, "'lengths' argument should be a 1D CPU int64 tensor"); } diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp index 125a33309af..b9e8de9541f 100644 --- a/aten/src/ATen/native/PointwiseOps.cpp +++ b/aten/src/ATen/native/PointwiseOps.cpp @@ -35,7 +35,7 @@ Tensor& addcmul_out( const Tensor& tensor1, const Tensor& tensor2, Scalar value) { - checkBackend("addcmul_cpu", result, self.type().backend()); + checkBackend("addcmul_cpu", result, self.options().backend()); auto iter = at::TensorIterator(); iter.set_check_mem_overlap(true); iter.add_output(result); @@ -70,7 +70,7 @@ Tensor& addcdiv_out( const Tensor& tensor1, const Tensor& tensor2, Scalar value) { - checkBackend("addcdiv_cpu", result, self.type().backend()); + checkBackend("addcdiv_cpu", result, self.options().backend()); auto iter = at::TensorIterator(); iter.set_check_mem_overlap(true); iter.add_output(result); diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 7c355dbd138..031182c7d38 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -850,13 +850,13 @@ std::tuple NAME( \ bool batch_first) { \ if (at::cudnn_is_acceptable(_input)) { \ Tensor output, hy; \ - NAME##_cudnn_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \ + NAME##_cudnn_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \ num_layers, dropout_p, train, bidirectional, batch_first); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ if (use_miopen(_input, dropout_p)) { \ Tensor output, hy; \ - NAME##_miopen_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \ + NAME##_miopen_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \ num_layers, dropout_p, train, bidirectional, batch_first); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ @@ -883,13 +883,13 @@ std::tuple NAME( \ bool bidirectional) { \ if (at::cudnn_is_acceptable(data)) { \ Tensor output, hy; \ - NAME##_packed_cudnn_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \ + NAME##_packed_cudnn_stub(data.device().type(), output, hy, data, batch_sizes, hx, \ _params, has_biases, num_layers, dropout_p, train, bidirectional); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ if (use_miopen(data, dropout_p)) { \ Tensor output, hy; \ - NAME##_packed_miopen_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \ + NAME##_packed_miopen_stub(data.device().type(), output, hy, data, batch_sizes, hx, \ _params, has_biases, num_layers, dropout_p, train, bidirectional); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ @@ -914,7 +914,7 @@ std::tuple NAME( \ bool batch_first) { \ if (at::cudnn_is_acceptable(_input)) { \ Tensor output, hy; \ - gru_cudnn_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \ + gru_cudnn_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \ num_layers, dropout_p, train, bidirectional, batch_first); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ @@ -941,7 +941,7 @@ std::tuple NAME( \ bool bidirectional) { \ if (at::cudnn_is_acceptable(data)) { \ Tensor output, hy; \ - gru_packed_cudnn_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \ + gru_packed_cudnn_stub(data.device().type(), output, hy, data, batch_sizes, hx, \ _params, has_biases, num_layers, dropout_p, train, bidirectional); \ return std::make_tuple(std::move(output), std::move(hy)); \ } \ @@ -976,14 +976,14 @@ std::tuple lstm( TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(_input)) { Tensor output, hy, cy; - lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases, + lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } if (use_miopen(_input, dropout_p)) { Tensor output, hy, cy; - lstm_miopen_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases, + lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } @@ -1005,14 +1005,14 @@ std::tuple lstm( TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(data)) { Tensor output, hy, cy; - lstm_packed_cudnn_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx, + lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } if (use_miopen(data, dropout_p)) { Tensor output, hy, cy; - lstm_packed_miopen_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx, + lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } @@ -1154,7 +1154,7 @@ std::tuple quantized_lstm( TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(_input)) { Tensor output, hy, cy; - lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases, + lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } @@ -1202,7 +1202,7 @@ std::tuple quantized_lstm( TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states"); if (at::cudnn_is_acceptable(data)) { Tensor output, hy, cy; - lstm_packed_cudnn_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx, + lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 7813b677396..4e7e607e127 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -136,9 +136,9 @@ static TensorIterator make_reduction( for (const Tensor *t: {&result1, &result2}) { const Tensor& result = *t; TORCH_CHECK( - !result.defined() || result.type().scalarType() == dtype, + !result.defined() || result.scalar_type() == dtype, name, ": provided dtype must match dtype of result. Got ", - toString(result.type().scalarType()), + toString(result.scalar_type()), " and ", toString(dtype), "."); @@ -161,8 +161,8 @@ static TensorIterator make_reduction( // efficiency. // We don't generalize this to common mismatched input/output types to avoid cross // product of templated kernel launches. - if (self.type().scalarType() == dtype || - (self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) { + if (self.scalar_type() == dtype || + (self.is_cuda() && self.scalar_type() == kHalf && dtype == kFloat)) { return TensorIterator::reduce_op(viewed_result1, viewed_result2, self); } return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype)); @@ -434,8 +434,9 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool static Tensor& norm_out(Tensor &result, const Tensor &self, optional opt_p, IntArrayRef dim, bool keepdim, optional opt_dtype) { auto p = opt_p.value_or(2.0); - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); + ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); TORCH_CHECK( @@ -458,8 +459,8 @@ static inline Tensor _norm(const Tensor &self, Scalar p) { if (self.is_sparse()) { return at::native_norm(self, p); } else { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "norm only supports floating-point dtypes"); @@ -510,9 +511,9 @@ inline Tensor & _all(Tensor & result, TensorIterator & iter) { } Tensor all(const Tensor& self) { - TORCH_CHECK(self.type().backend() == Backend::CPU || - self.type().backend() == Backend::CUDA, "all only supports CPU AND CUDA " - "backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || + self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA " + "backend, got: ", toString(self.options().backend())); TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, "all only supports torch.uint8 and torch.bool dtypes"); @@ -528,9 +529,9 @@ Tensor all(const Tensor& self, int64_t dim, bool keepdim) { } Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { - TORCH_CHECK(self.type().backend() == Backend::CPU || - self.type().backend() == Backend::CUDA, "all only supports CPU AND CUDA " - "backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || + self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA " + "backend, got: ", toString(self.options().backend())); TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, "all only supports torch.uint8 and torch.bool dtypes"); dim = maybe_wrap_dim(dim, self.dim()); @@ -554,11 +555,11 @@ inline Tensor & _any(Tensor & result, TensorIterator & iter) { } Tensor any(const Tensor& self) { - TORCH_CHECK(self.type().backend() == Backend::CPU || - self.type().backend() == Backend::CUDA || - self.type().backend() == Backend::SparseCPU || - self.type().backend() == Backend::SparseCUDA, "any only supports CPU, CUDA, " - "SparseCPU and SparseCUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || + self.options().backend() == Backend::CUDA || + self.options().backend() == Backend::SparseCPU || + self.options().backend() == Backend::SparseCUDA, "any only supports CPU, CUDA, " + "SparseCPU and SparseCUDA backend, got: ", toString(self.options().backend())); TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, "all only supports torch.uint8 and torch.bool dtypes"); @@ -574,9 +575,9 @@ Tensor any(const Tensor& self, int64_t dim, bool keepdim) { } Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) { - TORCH_CHECK(self.type().backend() == Backend::CPU || - self.type().backend() == Backend::CUDA, "any only supports CPU AND CUDA " - "backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || + self.options().backend() == Backend::CUDA, "any only supports CPU AND CUDA " + "backend, got: ", toString(self.options().backend())); TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool, "all only supports torch.uint8 and torch.bool dtypes"); dim = maybe_wrap_dim(dim, self.dim()); @@ -636,7 +637,7 @@ Tensor& argmax_out(Tensor& result, const Tensor& self, c10::optional di in = self.reshape({-1}); keepdim = false; } - if (self.type().backend() != Backend::CPU && self.type().backend() != Backend::CUDA) { + if (self.options().backend() != Backend::CPU && self.options().backend() != Backend::CUDA) { Tensor ignored = at::empty({0}, self.options()); return std::get<1>(at::max_out(ignored, result, in, dim.value_or(0), keepdim)); } @@ -661,7 +662,7 @@ Tensor& argmin_out(Tensor& result, const Tensor& self, c10::optional di in = self.reshape({-1}); keepdim = false; } - if (self.type().backend() != Backend::CPU && self.type().backend() != Backend::CUDA) { + if (self.options().backend() != Backend::CPU && self.options().backend() != Backend::CUDA) { Tensor ignored = at::empty({0}, self.options()); return std::get<1>(at::min_out(ignored, result, in, dim.value_or(0), keepdim)); } @@ -677,8 +678,8 @@ Tensor argmin(const Tensor& self, c10::optional dim, bool keepdims) { } static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "std and var only support CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "std and var only support CPU AND CUDA backend, got: ", toString(self.options().backend())); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "std and var only support floating-point dtypes"); @@ -716,15 +717,13 @@ static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim, static std::tuple std_var_mean_out(const char* fname, Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) { AT_ASSERT(result1.defined() && result2.defined()); - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - fname, " only support CPU and CUDA backend, got: ", toString(self.type().backend())); - TORCH_CHECK(at::isFloatingType(self.type().scalarType()) || at::isComplexType(self.scalar_type()), - fname, " only support floating-point dtypes"); - TORCH_CHECK(result1.type().scalarType() == result2.type().scalarType(), + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, fname, " only support CPU AND CUDA backend, got: ", toString(self.options().backend())); + TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), fname, " only support floating-point dtypes"); + TORCH_CHECK(result1.scalar_type() == result2.scalar_type(), "provided by result1 dtype must match dtype of result2. Got ", - toString(result1.type().scalarType()), + toString(result1.scalar_type()), " and ", - toString(result2.type().scalarType()), + toString(result2.scalar_type()), "."); if (at::isComplexType(self.scalar_type())){ ScalarType dtype = c10::toValueType(get_dtype(result1, self, {}, true)); @@ -805,8 +804,8 @@ std::tuple var_mean(const Tensor& self, bool unbiased) { } Tensor var(const Tensor& self, bool unbiased) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "var only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "var only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "var only supports floating-point dtypes"); auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits::quiet_NaN()); @@ -823,8 +822,8 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias } Tensor std(const Tensor& self, bool unbiased) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "std only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), "std only supports floating-point dtypes"); auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits::quiet_NaN()); diff --git a/aten/src/ATen/native/SortingUtils.h b/aten/src/ATen/native/SortingUtils.h index 58617890e59..6b83be84ce9 100644 --- a/aten/src/ATen/native/SortingUtils.h +++ b/aten/src/ATen/native/SortingUtils.h @@ -55,7 +55,7 @@ inline void _reduction_with_indices_allocate_or_resize_output( } if (values.defined()) { TORCH_CHECK( - self.type() == values.type(), + self.options().type_equal(values.options()), "output values must be of same type as input"); if (!keepdim && values.dim() == self.dim() - 1) { // unsqueeze to preserve passed in noncontiguous tensor in resize @@ -95,7 +95,7 @@ inline void _allocate_or_resize_output_with_indices( } if (values.defined()) { TORCH_CHECK( - self.type() == values.type(), + self.options().type_equal(values.options()), "output values must be of same type as input"); values.resize_(result_sizes); } else { diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 83d37452e47..00abe39da80 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -29,7 +29,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, signal_ndim); TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Expected an input tensor of floating types, but got input=", - self.type(), self.sizes()); + self.toString(), self.sizes()); auto signal_tensor_ndim = signal_ndim + static_cast(complex_input); // add complex dim if (self.dim() < signal_tensor_ndim) { @@ -39,7 +39,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, if (complex_input) { ss << " (complex input adds an extra dimension)"; } - ss << ", but got input=" << self.type() << self.sizes(); + ss << ", but got input=" << self.toString() << self.sizes(); AT_ERROR(ss.str()); } @@ -65,7 +65,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, TORCH_CHECK(input.size(signal_ndim + 1) == 2, "Expected an input tensor with a last dimension of size 2 " "representing real + imaginary components, but got input ", - self.type(), self.sizes()); + self.toString(), self.sizes()); } // build signal_sizes and output_size @@ -101,7 +101,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim, TORCH_CHECK(signal_sizes.size() == 0 || signal_sizes[i] == checked_signal_sizes[i], "Expected given signal_sizes=", signal_sizes," to have same " "shape with input at signal dimension ", i, ", but got " - "signal_sizes=", signal_sizes, " and input=", self.type(), + "signal_sizes=", signal_sizes, " and input=", self.toString(), self.sizes()); } } @@ -177,11 +177,11 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional hop const optional win_lengthOpt, const Tensor& window, const bool normalized, const bool onesided) { #define REPR(SS) \ - SS << "stft(" << self.type() << self.sizes() << ", n_fft=" << n_fft \ + SS << "stft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \ << ", hop_length=" << hop_length << ", win_length=" << win_length \ << ", window="; \ if (window.defined()) { \ - SS << window.type() << "{" << window.sizes() << "}"; \ + SS << window.toString() << "{" << window.sizes() << "}"; \ } else { \ SS << "None"; \ } \ diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 9fa76ad16b6..372dd7b346a 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -150,8 +150,8 @@ std::tuple mode(const Tensor& self, int64_t dim, bool keepdim) { std::tuple mode_out(Tensor& values, Tensor& indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "mode only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "mode only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "mode")) { AT_ASSERT(values.dim() == 0); @@ -202,8 +202,8 @@ std::tuple max(const Tensor& self, int64_t dim, bool keepdim) { static std::tuple max_out_impl(Tensor& max, Tensor& max_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "max only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "max only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) { AT_ASSERT(max.dim() == 0); @@ -262,8 +262,8 @@ std::tuple min(const Tensor& self, int64_t dim, bool keepdim) { static std::tuple min_out_impl(Tensor& min, Tensor& min_indices, const Tensor& self, int64_t dim, bool keepdim) { - TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, - "min only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); + TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, + "min only supports CPU AND CUDA backend, got: ", toString(self.options().backend())); dim = maybe_wrap_dim(dim, self.dim()); if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) { AT_ASSERT(min.dim() == 0); diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index c8c00bbc6dc..193a2a4cb82 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -67,7 +67,7 @@ inline void check_size_nonnegative(IntArrayRef size) { inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) { TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(), - "n is too large for result tensor type: '", tensor.type().toString(), "'"); + "n is too large for result tensor type: '", tensor.toString(), "'"); // Ensure sufficient precision for floating point representation. switch (tensor.scalar_type()) { diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 48ba96d2d58..6d51d955c65 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -350,7 +350,7 @@ Tensor expand(const Tensor& self, IntArrayRef size, bool implicit) { // requested by the user, because it is legal to remove implicit expands // from the graph, but not legal to remove the explicit ones. TORCH_CHECK(size.size() >= (size_t)self.dim(), - "expand(", self.type(), "{", self.sizes(), "}, size=", size, + "expand(", self.toString(), "{", self.sizes(), "}, size=", size, "): the number of sizes provided (", size.size(), ") ", "must be greater or equal to the number of dimensions in the tensor (", self.dim(), ")"); diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 8573388df88..783f25c06ab 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -267,7 +267,7 @@ Tensor& _clamp_min_out_cpu(Tensor& result, const Tensor& self, Scalar min) { Tensor mvlgamma(const Tensor& self, int64_t p) { TORCH_CHECK(at::isFloatingType(self.scalar_type()), - "mvlgamma is not implemented for ", self.type()); + "mvlgamma is not implemented for ", self.scalar_type()); TORCH_CHECK((self > 0.5 * (p - 1.)).all().item(), "Condition for computing multivariate log-gamma not met"); TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1"); @@ -278,7 +278,7 @@ Tensor mvlgamma(const Tensor& self, int64_t p) { Tensor& mvlgamma_(Tensor& self, int64_t p) { TORCH_CHECK(at::isFloatingType(self.scalar_type()), - "mvlgamma is not implemented for ", self.type()); + "mvlgamma is not implemented for ", self.scalar_type()); TORCH_CHECK((self > 0.5 * (p - 1.)).all().item(), "Condition for computing multivariate log-gamma not met"); TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1"); diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 55bdaa5acbc..c9036158d74 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -238,6 +238,7 @@ class CAFFE2_API Tensor { return impl_->itemsize(); } + C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( tensorTypeIdToBackend(legacyExtractTypeId(type_set())), diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index 2ff2585a26f..f5772d69823 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -20,7 +20,7 @@ struct Foo { static void apply(Tensor a, Tensor b) { scalar_type s = 1; std::stringstream ss; - ss << "hello, dispatch: " << a.type().toString() << s << "\n"; + ss << "hello, dispatch: " << a.toString() << s << "\n"; auto data = (scalar_type*)a.data_ptr(); (void)data; } @@ -110,7 +110,7 @@ TEST(TestScalar, TestScalar) { scalar_t s = 1; std::stringstream ss; ASSERT_NO_THROW( - ss << "hello, dispatch" << x.type().toString() << s << "\n"); + ss << "hello, dispatch" << x.toString() << s << "\n"); auto data = (scalar_t*)x.data_ptr(); (void)data; }); diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index bdae437877f..b6e5b5d1a2b 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -292,6 +292,15 @@ struct C10_API TensorOptions { return has_pinned_memory_; } + /// Returns if the layout is sparse + bool is_sparse() const { + return layout_ == c10::Layout::Sparse; + } + + // For compatibility with legacy tensor.type() comparisons + bool type_equal(const TensorOptions& other) const { + return backend() == other.backend() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype()); + } /// Returns the `pinned_memory` property of the `TensorOptions`, or /// `c10::nullopt` if `pinned_memory` is not specified. @@ -538,6 +547,12 @@ inline TensorOptions dtype() { return dtype(caffe2::TypeMeta::Make()); } +inline std::string toString(const TensorOptions options) { + std::ostringstream stream; + stream << options; + return stream.str(); +} + // This is intended to be a centralized location by which we can determine // what an appropriate TensorTypeId for a tensor is. // diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp index 7ed87e9906f..5de56139702 100644 --- a/test/cpp/api/tensor_options.cpp +++ b/test/cpp/api/tensor_options.cpp @@ -21,7 +21,7 @@ using namespace torch::test; ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \ ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \ ASSERT_EQ(tensor.scalar_type(), (type_)); \ - ASSERT_TRUE(tensor.type().layout() == (layout_)) + ASSERT_TRUE(tensor.options().layout() == (layout_)) TEST(TensorOptionsTest, DefaultsToTheRightValues) { TensorOptions options; diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp index f896b73158d..15d5ebd98e4 100644 --- a/test/cpp/api/tensor_options_cuda.cpp +++ b/test/cpp/api/tensor_options_cuda.cpp @@ -29,7 +29,7 @@ at::Device CUDADevice(DeviceIndex index) { ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \ ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \ ASSERT_EQ(tensor.scalar_type(), (type_)); \ - ASSERT_TRUE(tensor.type().layout() == (layout_)) + ASSERT_TRUE(tensor.options().layout() == (layout_)) TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) { auto options = CUDA(kFloat).options(); diff --git a/test/cpp/jit/test_argument_spec.cpp b/test/cpp/jit/test_argument_spec.cpp index 0dfe92cc48b..0baac09b02f 100644 --- a/test/cpp/jit/test_argument_spec.cpp +++ b/test/cpp/jit/test_argument_spec.cpp @@ -6,7 +6,7 @@ namespace torch { namespace jit { int device(const autograd::Variable& v) { - return v.type().is_cuda() ? v.get_device() : -1; + return v.device().is_cuda() ? v.get_device() : -1; } bool isEqual(at::IntArrayRef lhs, at::IntArrayRef rhs) { @@ -29,18 +29,18 @@ bool isEqual(const ArgumentInfo& ti, const autograd::Variable& v) { ti.type() == v.scalar_type() && ti.dim() == v.dim(); } -autograd::Variable var(at::DeprecatedTypeProperties& t, at::IntArrayRef sizes, bool requires_grad) { - return autograd::make_variable(at::rand(sizes, t.options()), requires_grad); +autograd::Variable var(at::TensorOptions t, at::IntArrayRef sizes, bool requires_grad) { + return autograd::make_variable(at::rand(sizes, t), requires_grad); } autograd::Variable undef() { return autograd::Variable(); } void testCompleteArgumentSpec() { - auto& CF = at::CPU(at::kFloat); - auto& CD = at::CPU(at::kDouble); - auto& GF = at::CUDA(at::kFloat); - auto& GD = at::CUDA(at::kDouble); + auto const CF = at::CPU(at::kFloat); + auto const CD = at::CPU(at::kDouble); + auto const GF = at::CUDA(at::kFloat); + auto const GD = at::CUDA(at::kDouble); auto list = createStack({var(CF, {1}, true), var(CD, {1, 2}, false), diff --git a/test/cpp_extensions/cuda_extension.cpp b/test/cpp_extensions/cuda_extension.cpp index d4b9c6cd9af..5fe9c06c257 100644 --- a/test/cpp_extensions/cuda_extension.cpp +++ b/test/cpp_extensions/cuda_extension.cpp @@ -6,8 +6,8 @@ void sigmoid_add_cuda(const float* x, const float* y, float* output, int size); torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) { - TORCH_CHECK(x.type().is_cuda(), "x must be a CUDA tensor"); - TORCH_CHECK(y.type().is_cuda(), "y must be a CUDA tensor"); + TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(y.device().is_cuda(), "y must be a CUDA tensor"); auto output = torch::zeros_like(x); sigmoid_add_cuda( x.data_ptr(), y.data_ptr(), output.data_ptr(), output.numel()); diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 105e23c00a2..73694650a4d 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -656,7 +656,7 @@ def create_python_bindings(python_functions, has_self, is_module=False): } python_binding_arguments.append(dtype_arg) if is_factory_function or is_like_or_new_function_with_options: - py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_or_new_function_with_options else None + py_default_layout = '*torch::getLayout(self.options().backend())' if is_like_or_new_function_with_options else None layout_arg = { 'default': 'torch.strided', 'dynamic_type': 'Layout', diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h index abf40e1ea9d..95b4a1c8f85 100644 --- a/tools/autograd/templates/Functions.h +++ b/tools/autograd/templates/Functions.h @@ -32,17 +32,17 @@ inline std::vector unpack_list(at::ArrayRef xs) { } struct TypeAndSize { - TypeAndSize() : type(nullptr) {} + TypeAndSize() : options(at::TensorOptions()) {} /* implicit */ TypeAndSize(const Tensor & t) : sizes(t.sizes().vec()) - , type(&t.type()) {} + , options(t.options()) {} - Tensor zeros() { return at::zeros(sizes, *type); } + Tensor zeros() { return at::zeros(sizes, options); } private: std::vector sizes; - at::DeprecatedTypeProperties* type; + at::TensorOptions options; }; ${autograd_function_declarations} diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 47fa7b672ba..1da51b94b93 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -783,7 +783,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.isNone(0)) { - return THPUtils_packString(torch::utils::type_to_string(self_.type())); + return THPUtils_packString(torch::utils::options_to_string(self_.options())); } auto obj = r.pyobject(0); auto opt_memory_format = r.memoryformatOptional(2); @@ -807,9 +807,9 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa if (is_dtype) { scalar_type = r.scalartype(0); } else { - at::DeprecatedTypeProperties* type = torch::utils::type_from_string(type_name); - scalar_type = type->scalarType(); - auto device_type = backendToDeviceType(type->backend()); + at::TensorOptions options = torch::utils::options_from_string(type_name); + scalar_type = at::typeMetaToScalarType(options.dtype()); + auto device_type = options.device().type(); if (device_type != device.type()) { device = at::Device(device_type); } diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 7ffe1b231b2..2abb40dbb49 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -102,7 +102,7 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state } auto& tensor = ((THPVariable*)_new_state)->cdata; if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) { - auto type_name = torch::utils::type_to_string(tensor.type()); + auto type_name = torch::utils::options_to_string(tensor.options()); throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str()); } if (self->cdata->device().type() == at::kCPU) { diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index e29ff09b7e9..e3ac3e8b80f 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -101,7 +101,7 @@ void set_data(const Tensor & self, const Tensor & new_data) { const auto prior_device = prior_accumulator->input_metadata(0).device(); const auto new_device = new_data.device(); - if (new_data.type() != self.type() || prior_device != new_device) { + if (!new_data.options().type_equal(self.options()) || prior_device != new_device) { autograd_meta->grad_accumulator_.reset(); } } diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index c4de008d55e..d3207f857bd 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -99,7 +99,7 @@ variable_list _wrap_outputs(const variable_list &input_vars, } void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) { - if (original.type() != result.type()) { + if (!original.options().type_equal(result.options())) { std::stringstream ss; ss << "hook '" << hook_name << "' has changed the type of value ("; ss << "was " << original.toString() << " got "; diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index eb7bfae35b7..711797f3b96 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -411,11 +411,10 @@ static variable_list call_post_hooks(Node& fn, variable_list outputs, const vari return outputs; } -static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) { +static bool is_compatible_type(const at::TensorOptions& expected, const at::TensorOptions& actual) { // Types are compatible if they exactly match or if the gradient is a sparse // version of the expected type. - return expected == actual || (actual.is_sparse() && - expected == actual.toBackend(toDense(actual.backend()))); + return expected.type_equal(actual) || (actual.is_sparse() && expected.device().type() == actual.device().type()); } void validate_outputs( @@ -451,14 +450,14 @@ void validate_outputs( } grads[i] = at::sum_to(std::move(grads[i]), metadata.shape()); } - TORCH_CHECK(isFloatingType(grads[i].type().scalarType())); - if (metadata.type().scalarType() != grads[i].type().scalarType()) { - grads[i] = grads[i].to(metadata.type().scalarType()); + TORCH_CHECK(isFloatingType(grads[i].scalar_type())); + if (c10::typeMetaToScalarType(metadata.options().dtype()) != grads[i].scalar_type()) { + grads[i] = grads[i].to(c10::typeMetaToScalarType(metadata.options().dtype())); } - if (!is_compatible_type(metadata.type(), grads[i].type())) { + if (!is_compatible_type(metadata.options(), grads[i].options())) { std::stringstream ss; ss << "invalid gradient at index " << i << " - expected type "; - ss << metadata.type() << " but got " << grads[i].type(); + ss << metadata.options() << " but got " << grads[i].options(); AT_ERROR(format_error(ss.str())); } auto output_device = output.device(); diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index dc42a6da95f..c98368c04ce 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -137,11 +137,11 @@ struct TORCH_API Node : std::enable_shared_from_this { /// Adds the type and shape metadata for a new input. Returns the index of /// of the new input. uint32_t add_input_metadata( - const at::DeprecatedTypeProperties& type + const at::TensorOptions& options , at::IntArrayRef shape , at::Device device) noexcept { uint32_t input_nr = input_metadata_.size(); - input_metadata_.emplace_back(type, shape, device); + input_metadata_.emplace_back(options, shape, device); return input_nr; } diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp index 1c48b0bbe4d..82f11471305 100644 --- a/torch/csrc/autograd/functions/comm.cpp +++ b/torch/csrc/autograd/functions/comm.cpp @@ -77,7 +77,7 @@ variable_list Gather::apply(variable_list&& inputs) { TORCH_CHECK( input.is_cuda(), "All inputs to Gather must be CUDA tensors, got ", - input.type()); + input.toString()); if (input.dim() > 0) { all_are_zero_dim = false; } diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 4c1074db2bd..0765f1b2321 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -20,21 +20,16 @@ namespace torch { namespace autograd { struct InputMetadata { InputMetadata() = default; - InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device) - : type_{&type}, shape_{shape}, device_{device} { + InputMetadata(const at::TensorOptions options, at::IntArrayRef shape, at::Device device) + : options_{options}, shape_{shape}, device_{device} { stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_); } InputMetadata(const at::Tensor& t) - : InputMetadata(t.type(), t.sizes(), t.device()) { } + : InputMetadata(t.options(), t.sizes(), t.device()) { } - bool is_valid() const { - return type_ != nullptr; - } - - const at::DeprecatedTypeProperties& type() const { - AT_ASSERT(type_); - return *type_; + const at::TensorOptions options() const { + return options_; } at::IntArrayRef shape() const { @@ -50,11 +45,11 @@ struct InputMetadata { } at::Tensor zeros_like() const { - return at::zeros(shape_, type_->options(device_)); + return at::zeros(shape_, options_); } private: - const at::DeprecatedTypeProperties* type_ = nullptr; + const at::TensorOptions options_; at::DimVector shape_; at::Device device_ = at::kCPU; c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device_); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 308620700a8..16834abfb70 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -275,7 +275,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) bool gradIsSparse = (var.dtype() == grad.dtype() && var.device().type() == grad.device().type() && grad.layout() == kSparse); - THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse, + THPUtils_assertRet(-1, grad.options().type_equal(var.options()) || gradIsSparse, "assigned grad has data of a different type"); if (var.is_cuda()) { THPUtils_assertRet(-1, grad.get_device() == var.get_device(), @@ -487,7 +487,7 @@ static PyObject *THPVariable_dtype(THPVariable *self, void *unused) static PyObject * THPVariable_layout(THPVariable* self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; - return torch::autograd::utils::wrap(torch::getLayout(self_.type().backend())); + return torch::autograd::utils::wrap(torch::getLayout(self_.options().backend())); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index cfd5b4aab1d..7611680803d 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -149,7 +149,7 @@ static Variable valueToTensor(c10::TensorOptions options, PyObject* value) { throw TypeError( "can't assign a %s to a %s", Py_TYPE(value)->tp_name, - torch::utils::type_to_string(getDeprecatedTypeProperties(options.backend(), typeMetaToScalarType(options.dtype()))).c_str()); + torch::utils::options_to_string(options).c_str()); } static Variable boolToIndexingTensor(const Variable& self, bool value) { diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index 9d6bb29f569..756657cb4f5 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -57,7 +57,7 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { if (saved_version_ != version_counter_.current_version()) { std::stringstream message; message << "one of the variables needed for gradient computation has been " - "modified by an inplace operation: [" << data_.type().toString() << " " + "modified by an inplace operation: [" << data_.toString() << " " << data_.sizes() << "]"; if (grad_fn) { message << ", which is output " << output_nr_ diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index d771da70a54..9f478292ba2 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -331,7 +331,7 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso fn->storage_offset = self.storage_offset(); fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_)); fn->add_input_metadata( - diff_view_meta->base_.type() + diff_view_meta->base_.options() , self.sizes() // Note: sizes(), not base_.sizes(), is intentional , diff_view_meta->base_.device()); diff_view_meta->grad_fn_ = std::move(fn); diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 018fd0c501e..f517d564ff2 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -193,11 +193,11 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) { if (grad.defined()) { // Ensure that the gradient type matches the bucket type. AT_ASSERTM( - grad.type() == bucket_view.type(), + grad.options().type_equal(bucket_view.options()), "Expected ", - bucket_view.type(), + bucket_view.toString(), ", got ", - grad.type()); + grad.toString()); // Assert that the grad tensor and the bucket don't share storage. // If they did, we could avoid the copy altogether. // The reason for not doing this is that existing code calls diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index fd730dd4c9f..346bdedddcf 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -277,7 +277,7 @@ struct DifferentiableGraphBackward : public autograd::Node { // NB: since our requires_grad setting is only a heuristic we might end // up wanting to differentiate through integral tensors, which is // generally a hard error in autograd. - if (at::isFloatingType(output.type().scalarType())) { + if (at::isFloatingType(output.scalar_type())) { autograd::create_gradient_edge(output, shared_from_this()); output.set_requires_grad(true); } else { diff --git a/torch/csrc/jit/node_hashing.cpp b/torch/csrc/jit/node_hashing.cpp index 6c372efc733..18330d1be3d 100644 --- a/torch/csrc/jit/node_hashing.cpp +++ b/torch/csrc/jit/node_hashing.cpp @@ -16,7 +16,7 @@ namespace jit { namespace { bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { - return lhs.type() == rhs.type() && lhs.equal(rhs); + return lhs.options().type_equal(rhs.options()) && lhs.equal(rhs); } bool tensorListEqual( diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 45555604421..1c476e88083 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -296,7 +296,7 @@ struct PythonPrintImpl { // because it doesn't hash any information about the tensors. // We will probably need to optimize this at some point using hashing. for (size_t i = 0; i < tensor_table_.size(); ++i) { - if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) { + if (t.options().type_equal(tensor_table_[i].options()) && t.equal(tensor_table_[i])) { return i; } } diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index 8b6debff6a4..27c6f99fe43 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -54,7 +54,7 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di } Tensor & apply_(Tensor & self, PyObject* fn) { - if (self.type().backend() != Backend::CPU) { + if (self.options().backend() != Backend::CPU) { throw TypeError("apply_ is only implemented on CPU tensors"); } auto scalarType = self.scalar_type(); @@ -63,12 +63,12 @@ Tensor & apply_(Tensor & self, PyObject* fn) { } Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) { - if (self.type().backend() != Backend::CPU) { + if (self.options().backend() != Backend::CPU) { throw TypeError("map_ is only implemented on CPU tensors"); } - if (other_.type() != self.type()) { + if (!other_.options().type_equal(self.options())) { throw TypeError("map_: expected %s for 'other' (got %s)", - self.type().toString().c_str(), other_.type().toString().c_str()); + self.toString().c_str(), other_.toString().c_str()); } Tensor other; std::tie(other) = expand_inplace(self, other_, "map_"); @@ -78,16 +78,16 @@ Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) { } Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) { - if (self.type().backend() != Backend::CPU || x_.type().backend() != Backend::CPU || y_.type().backend() != Backend::CPU) { + if (self.options().backend() != Backend::CPU || x_.options().backend() != Backend::CPU || y_.options().backend() != Backend::CPU) { throw TypeError("map2_ is only implemented on CPU tensors"); } - if (x_.type() != self.type()) { + if (!x_.options().type_equal(self.options())) { throw TypeError("map2_: expected %s for argument 'x' (got %s)", - self.type().toString().c_str(), x_.type().toString().c_str()); + self.toString().c_str(), x_.toString().c_str()); } - if (y_.type() != self.type()) { + if (!y_.options().type_equal(self.options())) { throw TypeError("map2_: expected %s for argument 'y' (got %s)", - self.type().toString().c_str(), y_.type().toString().c_str()); + self.toString().c_str(), y_.toString().c_str()); } Tensor other1, other2; std::tie(other1, other2) = expand_inplace(self, x_, y_, "map2_"); diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index fbcb6a67b8e..6ea89fb050a 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -30,7 +30,7 @@ static PyObject* recursive_to_list( PyObject* tensor_to_list(const Tensor& tensor) { Tensor data = tensor; - if (data.type().backend() != Backend::CPU) { + if (data.options().backend() != Backend::CPU) { pybind11::gil_scoped_release no_gil; data = data.toBackend(Backend::CPU); } diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 521ed0e09f8..b6282aed7db 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -81,8 +81,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) { "can't convert sparse tensor to numpy. Use Tensor.to_dense() to " "convert to a dense tensor first."); } - if (tensor.type().backend() != Backend::CPU) { - throw TypeError("NumPy conversion for %s is not supported", tensor.type().toString().c_str()); + if (tensor.options().backend() != Backend::CPU) { + throw TypeError("NumPy conversion for %s is not supported", tensor.toString().c_str()); } if (tensor.requires_grad()) { throw std::runtime_error( diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 5e3cd8d13a4..d9d082622b9 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -28,13 +28,19 @@ static const char* backend_to_string(const at::Backend& backend) { } } +std::string options_to_string(const at::TensorOptions options) { + std::ostringstream ss; + ss << backend_to_string(options.backend()) << "." << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor"; + return ss.str(); +} + std::string type_to_string(const at::DeprecatedTypeProperties& type) { std::ostringstream ss; ss << backend_to_string(type.backend()) << "." << toString(type.scalarType()) << "Tensor"; return ss.str(); } -at::DeprecatedTypeProperties* type_from_string(const std::string& str) { +at::TensorOptions options_from_string(const std::string& str) { static std::string cuda_prefix("torch.cuda."); static std::once_flag cpu_once; static std::once_flag cuda_once; @@ -46,7 +52,7 @@ at::DeprecatedTypeProperties* type_from_string(const std::string& str) { if (str == "torch.Tensor") { auto backend = tensorTypeIdToBackend(torch::tensors::get_default_tensor_type_id()); auto scalar_type = torch::tensors::get_default_scalar_type(); - return &getDeprecatedTypeProperties(backend, scalar_type); + return getDeprecatedTypeProperties(backend, scalar_type).options(); } if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()).first == cuda_prefix.end()) { @@ -70,7 +76,7 @@ at::DeprecatedTypeProperties* type_from_string(const std::string& str) { if (it == map->end()) { throw ValueError("invalid type: '%s'", str.c_str()); } - return it->second; + return it->second->options(); } std::vector> all_declared_types() { diff --git a/torch/csrc/utils/tensor_types.h b/torch/csrc/utils/tensor_types.h index 39bd46a0ecb..86258320b7e 100644 --- a/torch/csrc/utils/tensor_types.h +++ b/torch/csrc/utils/tensor_types.h @@ -6,8 +6,9 @@ namespace torch { namespace utils { +std::string options_to_string(const at::TensorOptions options); std::string type_to_string(const at::DeprecatedTypeProperties& type); -at::DeprecatedTypeProperties* type_from_string(const std::string& str); +at::TensorOptions options_from_string(const std::string& str); // return a vector of all "declared" types, even those that weren't compiled std::vector> all_declared_types(); diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index dc4e3693d97..60c9a0c1f84 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -1297,7 +1297,7 @@ std::shared_ptr ProcessGroupGloo::allreduce_coalesced( // tensors must have the same device, layout and type. assertLayoutMatch(invalidArgument, tensors); if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { - return t.type() == tensors[0].type(); + return t.options().type_equal(tensors[0].options()); })) { invalidArgument("tensors must all have the same type"); } @@ -1670,11 +1670,11 @@ std::shared_ptr ProcessGroupGloo::allgather( assertDense(invalidArgument, inputs); // Expect all input/output tensors to have the same type and sizes - const auto& type = inputs[0].type(); + const auto& options = inputs[0].options(); const auto& sizes = inputs[0].sizes(); - assertTypeAndSizesMatch(invalidArgument, inputs, type, sizes); + assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes); for (size_t i = 0; i < outputs.size(); i++) { - assertTypeAndSizesMatch(invalidArgument, outputs[i], type, sizes); + assertTypeAndSizesMatch(invalidArgument, outputs[i], options, sizes); } const auto& device = inputs[0].device(); @@ -1807,11 +1807,11 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( " (expected length " + toString(expected) + ", got " + toString(actual) + ")"); } - if (input_list[i].type() != output_list[i].type()) { + if (!input_list[i].options().type_equal(output_list[i].options())) { invalidArgument( "invalid tensor type at index " + std::to_string(i) + - " (expected " + input_list[i].type().toString() + ", got " + - output_list[i].type().toString() + ")"); + " (expected " + input_list[i].toString() + ", got " + + output_list[i].toString() + ")"); } } } @@ -1992,9 +1992,9 @@ std::shared_ptr ProcessGroupGloo::gather( invalidArgument(ss.str()); } - const auto& type = inputs[0].type(); + const auto& options = inputs[0].options(); const auto& sizes = inputs[0].sizes(); - assertTypeAndSizesMatch(invalidArgument, outputs[0], type, sizes); + assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes); } else { if (outputs.size() != 0) { invalidArgument("requires empty output on non-root"); @@ -2178,9 +2178,9 @@ std::shared_ptr ProcessGroupGloo::scatter( << ", same as size of the process group."; invalidArgument(ss.str()); } - const auto& type = outputs[0].type(); + const auto& options = outputs[0].options(); const auto& sizes = outputs[0].sizes(); - assertTypeAndSizesMatch(invalidArgument, inputs[0], type, sizes); + assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes); } else { if (inputs.size() != 0) { invalidArgument("requires empty input on non-root"); diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp index 4a81c009912..fc796da799d 100644 --- a/torch/lib/c10d/Utils.hpp +++ b/torch/lib/c10d/Utils.hpp @@ -43,9 +43,9 @@ inline void assertSameType( const at::DeprecatedTypeProperties& type, const std::vector& tensors) { for (size_t i = 0; i < tensors.size(); i++) { - if (tensors[i].type() != type) { + if (!tensors[i].options().type_equal(type.options())) { const std::string expected = type.toString(); - const std::string actual = tensors[i].type().toString(); + const std::string actual = tensors[i].toString(); throw std::invalid_argument( "mixed types (" + expected + " and " + actual + ")"); } @@ -72,12 +72,12 @@ inline void assertSameSizeAndType(const std::vector& tensors) { } // Ensure all tensors have identical type and shape - auto type = tensors[0].type(); + auto options = tensors[0].options(); auto sizes = tensors[0].sizes(); for (size_t i = 1; i < tensors.size(); i++) { - if (tensors[i].type() != type) { - const std::string expected = type.toString(); - const std::string actual = tensors[i].type().toString(); + if (!tensors[i].options().type_equal(options)) { + const auto expected = toString(options); + const auto actual = toString(tensors[i].options()); throw std::invalid_argument( "argument contains mixed types (" + expected + " and " + actual + ")"); @@ -97,12 +97,24 @@ inline void assertTypeMatch( const at::DeprecatedTypeProperties& type, const at::ArrayRef& tensors, size_t index) { - if (tensors[index].type() != type) { + if (!tensors[index].options().type_equal(type.options())) { fn("invalid tensor type at index " + std::to_string(index) + " (expected " + - type.toString() + ", got " + tensors[index].type().toString() + ")"); + type.toString() + ", got " + tensors[index].toString() + ")"); } } +inline void assertTypeMatch( + std::function fn, + const at::TensorOptions& options, + const at::ArrayRef& tensors, + size_t index) { + if (!tensors[index].options().type_equal(options)) { + fn("invalid tensor type at index " + std::to_string(index) + " (expected " + + toString(options) + ", got " + toString(tensors[index].options()) + ")"); + } +} + + inline void assertSizesMatch( std::function fn, const at::IntArrayRef& sizes, @@ -228,12 +240,23 @@ inline void assertTypeAndSizesMatch( } } +inline void assertTypeAndSizesMatch( + std::function fn, + const at::ArrayRef& tensors, + const at::TensorOptions& options, + const at::IntArrayRef& sizes) { + for (size_t i = 0; i < tensors.size(); i++) { + assertTypeMatch(fn, options, tensors, i); + assertSizesMatch(fn, sizes, tensors, i); + } +} + inline void assertTypeAndSizesMatch( std::function fn, const at::ArrayRef& tensors) { - const auto& type = tensors[0].type(); + const auto& options = tensors[0].options(); const auto sizes = tensors[0].sizes(); - assertTypeAndSizesMatch(fn, tensors.slice(1), type, sizes); + assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes); } // Copied from ATen/core/functional.h. @@ -303,7 +326,7 @@ inline std::vector> getSizes( inline std::vector getDevices(const std::vector& tensors) { std::vector devices(tensors.size(), -1); - if (tensors[0].type().is_cuda()) { + if (tensors[0].device().is_cuda()) { for (size_t i = 0; i < tensors.size(); i++) { devices[i] = tensors[i].storage().device().index(); }