Summary:
This PR enables bool tensor creation and some basic operations for the CPU backend. This is part of the Bool Tensor feature implementation work; the overall plan is:
1. Storage Implementation [Done]
2. Tensor Creation.
a) CPU (this PR)
b) CUDA
3. Tensor Conversions.
4. Tensor Indexing.
5. Tensor Operations.
6. Backward-compatibility-related changes.
**Change**:
Enable bool CPU tensors and the following operations:
- torch.zeros
- torch.tensor
- torch.ones
- torch.randint
- torch.full
- torch.full_like
- torch.empty
- torch.empty_like
**Tested via**:
1) unit tests
2) manual checks with the following snippets:
torch.zeros(2,2, dtype=torch.bool)
torch.tensor([True, False], dtype=torch.bool)
torch.tensor([-1, -1.1, 0, 1, 1.1, 2], dtype=torch.bool)
torch.ones([1,2], dtype=torch.bool)
torch.randint(10, (2, 2), dtype=torch.bool)
torch.full((2, 3), True, dtype=torch.bool)
torch.empty(4, dtype=torch.bool)
a = torch.tensor([0,0,1])
b = torch.full_like(a, True)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17376
Reviewed By: ezyang
Differential Revision: D14375995
Pulled By: izdeby
fbshipit-source-id: a65490b5360ee0e6e3accc54ce7e32e49ad2d2a8
48 lines · 2.3 KiB · C++
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/Exceptions.h>

namespace torch { namespace utils {

// Writes the value of the Python scalar `obj` into the raw buffer `data`,
// interpreting the buffer according to `scalarType`.
inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
  switch (scalarType) {
    case at::kByte: *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); break;
    case at::kChar: *(char*)data = (char)THPUtils_unpackLong(obj); break;
    case at::kShort: *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); break;
    case at::kInt: *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); break;
    case at::kLong: *(int64_t*)data = THPUtils_unpackLong(obj); break;
    case at::kHalf:
      *(at::Half*)data = at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
      break;
    case at::kFloat: *(float*)data = (float)THPUtils_unpackDouble(obj); break;
    case at::kDouble: *(double*)data = THPUtils_unpackDouble(obj); break;
    case at::kComplexFloat: *(std::complex<float>*)data = (std::complex<float>)THPUtils_unpackComplexDouble(obj); break;
    case at::kComplexDouble: *(std::complex<double>*)data = THPUtils_unpackComplexDouble(obj); break;
    // kBool: the Python object is unpacked as a long and narrowed to bool.
    case at::kBool: *(bool*)data = (bool)THPUtils_unpackLong(obj); break;
    default: throw std::runtime_error("invalid type");
  }
}

// Reads a scalar of type `scalarType` from the raw buffer `data` and packs it
// into a new Python object.
inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
  switch (scalarType) {
    case at::kByte: return THPUtils_packInt64(*(uint8_t*)data);
    case at::kChar: return THPUtils_packInt64(*(char*)data);
    case at::kShort: return THPUtils_packInt64(*(int16_t*)data);
    case at::kInt: return THPUtils_packInt64(*(int32_t*)data);
    case at::kLong: return THPUtils_packInt64(*(int64_t*)data);
    case at::kHalf: return PyFloat_FromDouble(at::convert<double, at::Half>(*(at::Half*)data));
    case at::kFloat: return PyFloat_FromDouble(*(float*)data);
    case at::kDouble: return PyFloat_FromDouble(*(double*)data);
    case at::kComplexFloat: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<float>*)data));
    case at::kComplexDouble: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<double>*)data));
    // kBool: the stored byte is packed into a Python bool object.
    case at::kBool: return PyBool_FromLong(*(uint8_t*)data);
    default: throw std::runtime_error("invalid type");
  }
}

}} // namespace torch::utils
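For orientation, here is a minimal sketch of how these helpers could be exercised from C++ for the at::kBool path. The function name roundtrip_bool_example and the include path are illustrative assumptions, not part of the PR, and the sketch assumes an already-initialized Python interpreter.

// Hypothetical caller, for illustration only (not part of the PR): round-trips
// a Python bool through store_scalar/load_scalar using the at::kBool branch.
// Assumes the header above is reachable as <torch/csrc/utils/python_scalars.h>
// and that the Python interpreter has already been initialized.
#include <torch/csrc/utils/python_scalars.h>

PyObject* roundtrip_bool_example() {
  bool buffer = false;  // one byte of destination storage
  // Python bool is a subtype of int, so THPUtils_unpackLong(Py_True) yields 1,
  // which store_scalar narrows to true for at::kBool.
  torch::utils::store_scalar(&buffer, at::kBool, Py_True);
  // load_scalar reads the byte back and packs it with PyBool_FromLong,
  // returning a new reference to Python True.
  return torch::utils::load_scalar(&buffer, at::kBool);
}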