mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Bug fixes: torch::tensor(floating-point values) -> default dtype, and torch::tensor(integer values) ->at::kLong (#32367)
Summary: Some of the `torch::tensor` behavior is updated to better match Python API. Fixes https://github.com/pytorch/pytorch/issues/32234. This PR is BC-breaking in the following way: - `torch::tensor({1.0f, 2.0f})`: float -> default dtype - `torch::tensor(at::ArrayRef<int>({1, 2, 3}))`: int -> at::kLong - `torch::tensor(std::vector<int>({1, 2, 3}))`: int -> at::kLong - `torch::tensor(at::ArrayRef<float>({1.f, 2.f, 3.f}))`: float -> default dtype - `torch::tensor(std::vector<float>({1.f, 2.f, 3.f}))`: float -> default dtype - `torch::tensor(at::ArrayRef<double>({1., 2., 3.}))`: double -> default dtype - `torch::tensor(std::vector<double>({1., 2., 3.}))`: double -> default dtype Pull Request resolved: https://github.com/pytorch/pytorch/pull/32367 Differential Revision: D19498484 Pulled By: yf225 fbshipit-source-id: 19c8dc2a56476266153cff4c404e7f84d309eb12
This commit is contained in:
parent
4cc6e6bbbe
commit
b564eaf7a8
|
|
@ -105,10 +105,10 @@ void assert_eq(T val, T act, T exp) {
|
|||
}
|
||||
|
||||
template<typename Vals, typename Pows>
|
||||
void tensor_pow_scalar(const Vals vals, const Pows pows) {
|
||||
void tensor_pow_scalar(const Vals vals, c10::ScalarType vals_dtype, const Pows pows, c10::ScalarType pows_dtype) {
|
||||
using T = typename Vals::value_type;
|
||||
|
||||
const auto tensor = torch::tensor(vals);
|
||||
const auto tensor = torch::tensor(vals, vals_dtype);
|
||||
|
||||
for (const auto pow : pows) {
|
||||
auto actual_pow = tensor.pow(pow);
|
||||
|
|
@ -144,10 +144,10 @@ void tensor_pow_scalar(const Vals vals, const Pows pows) {
|
|||
}
|
||||
|
||||
template<typename Vals, typename Pows>
|
||||
void scalar_pow_tensor(const Vals vals, const Pows pows) {
|
||||
void scalar_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, const Pows pows, c10::ScalarType pows_dtype) {
|
||||
using T = typename Pows::value_type;
|
||||
|
||||
const auto pow_tensor = torch::tensor(pows);
|
||||
const auto pow_tensor = torch::tensor(pows, pows_dtype);
|
||||
|
||||
for (const auto val : vals) {
|
||||
const auto actual_pow = torch::pow(val, pow_tensor);
|
||||
|
|
@ -175,15 +175,15 @@ void scalar_pow_tensor(const Vals vals, const Pows pows) {
|
|||
}
|
||||
|
||||
template<typename Vals, typename Pows>
|
||||
void tensor_pow_tensor(const Vals vals, Pows pows) {
|
||||
void tensor_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, Pows pows, c10::ScalarType pows_dtype) {
|
||||
using T = typename Vals::value_type;
|
||||
|
||||
typedef std::numeric_limits< double > dbl;
|
||||
std::cout.precision(dbl::max_digits10);
|
||||
|
||||
const auto vals_tensor = torch::tensor(vals);
|
||||
const auto vals_tensor = torch::tensor(vals, vals_dtype);
|
||||
for (size_t shift = 0; shift < pows.size(); shift++) {
|
||||
const auto pows_tensor = torch::tensor(pows);
|
||||
const auto pows_tensor = torch::tensor(pows, pows_dtype);
|
||||
|
||||
const auto actual_pow = vals_tensor.pow(pows_tensor);
|
||||
|
||||
|
|
@ -222,67 +222,67 @@ void tensor_pow_tensor(const Vals vals, Pows pows) {
|
|||
}
|
||||
|
||||
TEST(PowTest, IntTensorPowAllScalars) {
|
||||
tensor_pow_scalar(ints, non_neg_ints);
|
||||
tensor_pow_scalar(ints, non_neg_longs);
|
||||
tensor_pow_scalar(ints, floats);
|
||||
tensor_pow_scalar(ints, doubles);
|
||||
tensor_pow_scalar(ints, c10::kInt, non_neg_ints, c10::kInt);
|
||||
tensor_pow_scalar(ints, c10::kInt, non_neg_longs, c10::kLong);
|
||||
tensor_pow_scalar(ints, c10::kInt, floats, c10::kFloat);
|
||||
tensor_pow_scalar(ints, c10::kInt, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, LongTensorPowAllScalars) {
|
||||
tensor_pow_scalar(longs, non_neg_ints);
|
||||
tensor_pow_scalar(longs, non_neg_longs);
|
||||
tensor_pow_scalar(longs, floats);
|
||||
tensor_pow_scalar(longs, doubles);
|
||||
tensor_pow_scalar(longs, c10::kLong, non_neg_ints, c10::kInt);
|
||||
tensor_pow_scalar(longs, c10::kLong, non_neg_longs, c10::kLong);
|
||||
tensor_pow_scalar(longs, c10::kLong, floats, c10::kFloat);
|
||||
tensor_pow_scalar(longs, c10::kLong, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, FloatTensorPowAllScalars) {
|
||||
tensor_pow_scalar(floats, ints);
|
||||
tensor_pow_scalar(floats, longs);
|
||||
tensor_pow_scalar(floats, floats);
|
||||
tensor_pow_scalar(floats, doubles);
|
||||
tensor_pow_scalar(floats, c10::kFloat, ints, c10::kInt);
|
||||
tensor_pow_scalar(floats, c10::kFloat, longs, c10::kLong);
|
||||
tensor_pow_scalar(floats, c10::kFloat, floats, c10::kFloat);
|
||||
tensor_pow_scalar(floats, c10::kFloat, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, DoubleTensorPowAllScalars) {
|
||||
tensor_pow_scalar(doubles, ints);
|
||||
tensor_pow_scalar(doubles, longs);
|
||||
tensor_pow_scalar(doubles, floats);
|
||||
tensor_pow_scalar(doubles, doubles);
|
||||
tensor_pow_scalar(doubles, c10::kDouble, ints, c10::kInt);
|
||||
tensor_pow_scalar(doubles, c10::kDouble, longs, c10::kLong);
|
||||
tensor_pow_scalar(doubles, c10::kDouble, floats, c10::kFloat);
|
||||
tensor_pow_scalar(doubles, c10::kDouble, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, IntScalarPowAllTensors) {
|
||||
scalar_pow_tensor(ints, ints);
|
||||
scalar_pow_tensor(ints, longs);
|
||||
scalar_pow_tensor(ints, floats);
|
||||
scalar_pow_tensor(ints, doubles);
|
||||
scalar_pow_tensor(ints, c10::kInt, ints, c10::kInt);
|
||||
scalar_pow_tensor(ints, c10::kInt, longs, c10::kLong);
|
||||
scalar_pow_tensor(ints, c10::kInt, floats, c10::kFloat);
|
||||
scalar_pow_tensor(ints, c10::kInt, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, LongScalarPowAllTensors) {
|
||||
scalar_pow_tensor(longs, longs);
|
||||
scalar_pow_tensor(longs, floats);
|
||||
scalar_pow_tensor(longs, doubles);
|
||||
scalar_pow_tensor(longs, c10::kLong, longs, c10::kLong);
|
||||
scalar_pow_tensor(longs, c10::kLong, floats, c10::kFloat);
|
||||
scalar_pow_tensor(longs, c10::kLong, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, FloatScalarPowAllTensors) {
|
||||
scalar_pow_tensor(floats, floats);
|
||||
scalar_pow_tensor(floats, doubles);
|
||||
scalar_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
|
||||
scalar_pow_tensor(floats, c10::kFloat, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, DoubleScalarPowAllTensors) {
|
||||
scalar_pow_tensor(doubles, doubles);
|
||||
scalar_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
|
||||
}
|
||||
|
||||
TEST(PowTest, IntTensorPowIntTensor) {
|
||||
tensor_pow_tensor(ints, ints);
|
||||
tensor_pow_tensor(ints, c10::kInt, ints, c10::kInt);
|
||||
}
|
||||
|
||||
TEST(PowTest, LongTensorPowLongTensor) {
|
||||
tensor_pow_tensor(longs, longs);
|
||||
tensor_pow_tensor(longs, c10::kLong, longs, c10::kLong);
|
||||
}
|
||||
|
||||
TEST(PowTest, FloatTensorPowFloatTensor) {
|
||||
tensor_pow_tensor(floats, floats);
|
||||
tensor_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
|
||||
}
|
||||
|
||||
TEST(PowTest, DoubleTensorPowDoubleTensor) {
|
||||
tensor_pow_tensor(doubles, doubles);
|
||||
tensor_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
|
|||
ASSERT_EQ(tensor.numel(), 1);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kLong);
|
||||
ASSERT_EQ(tensor.item<int32_t>(), 123);
|
||||
ASSERT_EQ(tensor.item<int64_t>(), 123);
|
||||
}
|
||||
|
||||
void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType default_dtype) {
|
||||
|
|
@ -236,7 +236,7 @@ void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType defau
|
|||
auto tensor = torch::tensor(123.456f);
|
||||
ASSERT_EQ(tensor.numel(), 1);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor, 123.456f));
|
||||
|
||||
tensor = torch::tensor(123.456);
|
||||
|
|
@ -283,7 +283,7 @@ TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
|
|||
tensor = torch::tensor(at::ArrayRef<int>({1, 2, 3}));
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kInt);
|
||||
ASSERT_EQ(tensor.dtype(), at::kLong);
|
||||
ASSERT_TRUE(exactly_equal(tensor[0], 1));
|
||||
ASSERT_TRUE(exactly_equal(tensor[1], 2));
|
||||
ASSERT_TRUE(exactly_equal(tensor[2], 3));
|
||||
|
|
@ -291,7 +291,7 @@ TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
|
|||
tensor = torch::tensor(std::vector<int>({1, 2, 3}));
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kInt);
|
||||
ASSERT_EQ(tensor.dtype(), at::kLong);
|
||||
ASSERT_TRUE(exactly_equal(tensor[0], 1));
|
||||
ASSERT_TRUE(exactly_equal(tensor[1], 2));
|
||||
ASSERT_TRUE(exactly_equal(tensor[2], 3));
|
||||
|
|
@ -328,31 +328,31 @@ void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType de
|
|||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125f));
|
||||
|
||||
tensor = torch::tensor(at::ArrayRef<float>({1.5, 2.25, 3.125}));
|
||||
tensor = torch::tensor(at::ArrayRef<float>({1.5f, 2.25f, 3.125f}));
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
||||
tensor = torch::tensor(std::vector<float>({1.5, 2.25, 3.125}));
|
||||
tensor = torch::tensor(std::vector<float>({1.5f, 2.25f, 3.125f}));
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
||||
tensor = torch::tensor(at::ArrayRef<double>({1.5, 2.25, 3.125}));
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.dtype(), at::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
|
@ -360,7 +360,7 @@ void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType de
|
|||
tensor = torch::tensor(std::vector<double>({1.5, 2.25, 3.125}));
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
|
@ -511,19 +511,19 @@ TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
|
|||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Float");
|
||||
"Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Float");
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Long");
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true, 2}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Long");
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -624,6 +624,27 @@ TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
|
|||
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor(at::ArrayRef<int>({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor(std::vector<int>({1, 2, 3}), torch::TensorOptions()).dtype(), torch::kLong);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor(at::ArrayRef<double>({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor(std::vector<double>({1., 2., 3.}), torch::TensorOptions()).dtype(), default_dtype);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1.f, 2.f, 3.f}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor(at::ArrayRef<float>({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor(std::vector<float>({1.f, 2.f, 3.f}), torch::TensorOptions()).dtype(), default_dtype);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorWithNonDtypeOptions) {
|
||||
test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
void test_Arange_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,13 @@ namespace torch {
|
|||
/// the largest data type that can represent all of the elements, or by using
|
||||
/// variadic templates.
|
||||
///
|
||||
/// NOTE: C++ `torch::tensor` by default gives a double tensor, which is
|
||||
/// different from Python `torch.tensor` that gives a float tensor by default.
|
||||
/// We are going to fix this discrepancy by making `torch::tensor` give
|
||||
/// a float tensor by default.
|
||||
/// Tracking issue: https://github.com/pytorch/pytorch/issues/28902
|
||||
/// NOTE: C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` /
|
||||
/// (nested) braced-init-list of floating-point types always produces a tensor of dtype
|
||||
/// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior.
|
||||
///
|
||||
/// NOTE: C++ `torch::tensor` with an integer literal or a braced-init-list of
|
||||
/// integer literals always produces a tensor of dtype `at::kLong` (aka. int64_t),
|
||||
/// matching Python `torch.tensor` behavior.
|
||||
/// NOTE: C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` /
|
||||
/// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong`
|
||||
/// (aka. int64_t), matching Python `torch.tensor` behavior.
|
||||
///
|
||||
/// NOTE: The following dtypes are not supported by `torch::tensor` currently:
|
||||
/// - `unsigned int`
|
||||
|
|
|
|||
|
|
@ -24,21 +24,15 @@ inline std::ostream& operator<<(std::ostream& stream, c10::BFloat16 value) {
|
|||
}
|
||||
|
||||
inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
|
||||
// NOTE: the dtype computation in this function only takes effect when the user passes
|
||||
// an integer literal / floating-point literal or a braced-init-list to `torch::tensor`
|
||||
// constructor. It doesn't affect `torch::tensor(at::ArrayRef<T>)` and `torch::tensor(std::vector<T>)`
|
||||
// as the specified dtype `T` is always respected.
|
||||
if (scalar_type == at::kInt || scalar_type == at::kLong) {
|
||||
// In C++, an integer literal without suffix (e.g. `1` instead of `1u`) can be one of
|
||||
// `int` / `long int` / `long long int` types. When we find that `scalar_type` is one
|
||||
// of those types, we always use `torch.int64` type, because In Python `torch.tensor(1)`
|
||||
// always gives a tensor of `torch.int64` dtype.
|
||||
// C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` /
|
||||
// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong`
|
||||
// (aka. int64_t), matching Python `torch.tensor` behavior.
|
||||
return at::kLong;
|
||||
} else if (scalar_type == at::kDouble) {
|
||||
// When `scalar_type == at::kDouble`, we know that the user is passing in
|
||||
// a floating-point literal without specifying its type (e.g. `1.0` instead of `1.0f`).
|
||||
// In Python, the dtype of `torch.tensor(1.0)` depends on the value of
|
||||
// `torch.get_default_dtype()`, and we should do the same for C++ `torch::tensor(1.0)`.
|
||||
} else if (scalar_type == at::kFloat || scalar_type == at::kDouble) {
|
||||
// C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` /
|
||||
// (nested) braced-init-list of floating-point types always produces a tensor of dtype
|
||||
// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior.
|
||||
return at::typeMetaToScalarType(at::get_default_dtype());
|
||||
} else {
|
||||
return scalar_type;
|
||||
|
|
@ -106,7 +100,7 @@ struct TensorDataContainer {
|
|||
#define TENSOR(T, S) \
|
||||
TensorDataContainer(T value) : \
|
||||
sizes_(), \
|
||||
scalar_type_(compute_desired_dtype(at::k##S)), \
|
||||
scalar_type_(at::k##S), \
|
||||
type_(TensorDataContainerType::Scalar), \
|
||||
scalar_(value) {}
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
|
||||
|
|
@ -154,7 +148,7 @@ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
|
|||
|
||||
// NOTE: We need to handle `std::vector` explicitly instead of relying on an implicit conversion
|
||||
// to `at::ArrayRef`, otherwise the following error can be thrown when calling
|
||||
// `torch::tensor(std::vector<double>({1.1, 2.2}))`:
|
||||
// `torch::tensor(std::vector<int>({1, 2}))`:
|
||||
// ```
|
||||
// error: no matching function for call to 'tensor(const std::vector<int>&)'
|
||||
// no known conversion for argument 1 from 'const std::vector<int>' to
|
||||
|
|
@ -211,7 +205,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
|
|||
|
||||
at::Tensor convert_to_tensor(at::TensorOptions options) const {
|
||||
if (!options.has_dtype()) {
|
||||
options = options.dtype(scalar_type_);
|
||||
options = options.dtype(compute_desired_dtype(scalar_type_));
|
||||
}
|
||||
|
||||
if (is_scalar()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user