mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18165 ghimport-source-id: 55cb3fb63a25c2faab1725b4ec14c688bf45bd38 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18166 Bool Tensor for CUDA * **#18165 Resolved comments from Bool Tensor for CPU PR** ------- ------------ This is a follow-up PR that resolves some additional feedback on one of the previous Bool Tensor PRs. gchanan, here is a list of almost all the comments from the original PR with respective fixes and replies: **[utils/python_scalars.h]** why is this converting from uint8_t and not bool? (comment?) When I was adding this, I was testing by creating a tensor and then calling its .tolist(). It worked equally well for bool and uint8_t, so I left uint8_t, as I thought it made more sense given that we are calling PyBool_FromLong. Changing it to bool. **[ATen/Dispatch.h]** better name? Fixed. **[test/test_torch.py]** what about other factories, such as full? (and more). There is a test that goes through the factory methods - test_tensor_factories_empty. I added some bool cases above it and added a comment that once CUDA is done, I will unite them so that it iterates not just over CUDA and CPU but also over all types. Adding all bool cases now. Will unite in the CUDA PR. **[generic/THTensorMath.h]** any changes in this file actually needed? Bad merge. Fixed. **[TH/THTensor.h]** this generates code for random, clampedRandom, and cappedRandom -- do we have tests for all of these with bool? Added. **[c10/core/ScalarType.h]** I'm not very confident about the lack of Bool here -- can you look at the call sites and see what makes sense to do here?
Added bool to the macro and created a similar macro without bool for the single call site where including it fails the build with these errors: _./torch/csrc/jit/symbolic_variable.h:79:20: error: ambiguous overload for ‘operator*’ (operand types are ‘const torch::jit::SymbolicVariable’ and ‘torch::jit::Value*’) return (*this) * insertConstant(rhs);_ Differential Revision: D14605105 fbshipit-source-id: abf82d50e8f8c50b386545ac068268651b28496d
48 lines
2.3 KiB
C++
48 lines
2.3 KiB
C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <torch/csrc/utils/python_numbers.h>
|
|
#include <torch/csrc/Exceptions.h>
|
|
|
|
namespace torch { namespace utils {
|
|
|
|
inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
|
|
switch (scalarType) {
|
|
case at::kByte: *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); break;
|
|
case at::kChar: *(char*)data = (char)THPUtils_unpackLong(obj); break;
|
|
case at::kShort: *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); break;
|
|
case at::kInt: *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); break;
|
|
case at::kLong: *(int64_t*)data = THPUtils_unpackLong(obj); break;
|
|
case at::kHalf:
|
|
*(at::Half*)data = at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
|
|
break;
|
|
case at::kFloat: *(float*)data = (float)THPUtils_unpackDouble(obj); break;
|
|
case at::kDouble: *(double*)data = THPUtils_unpackDouble(obj); break;
|
|
case at::kComplexFloat: *(std::complex<float>*)data = (std::complex<float>)THPUtils_unpackComplexDouble(obj); break;
|
|
case at::kComplexDouble: *(std::complex<double>*)data = THPUtils_unpackComplexDouble(obj); break;
|
|
case at::kBool: *(bool*)data = (bool)THPUtils_unpackLong(obj); break;
|
|
default: throw std::runtime_error("invalid type");
|
|
}
|
|
}
|
|
|
|
// Reads a single element of type `scalarType` from `data` and boxes it as a
// new Python object.
//
// Integer types become Python ints (THPUtils_packInt64); half, float and
// double become Python floats; complex types become Python complex numbers;
// bool becomes a Python bool.
//
// data:       pointer to exactly one element; the caller must ensure its C++
//             type matches `scalarType`.
// scalarType: element type stored at `data`.
// Returns the PyObject produced by the corresponding Python C-API helper.
// Throws std::runtime_error("invalid type") for scalar types not listed.
inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
  switch (scalarType) {
    case at::kByte: return THPUtils_packInt64(*static_cast<uint8_t*>(data));
    // kChar is a signed 8-bit integer. Reading through plain `char` would
    // return 128..255 for negative values on platforms where char is unsigned.
    case at::kChar: return THPUtils_packInt64(*static_cast<int8_t*>(data));
    case at::kShort: return THPUtils_packInt64(*static_cast<int16_t*>(data));
    case at::kInt: return THPUtils_packInt64(*static_cast<int32_t*>(data));
    case at::kLong: return THPUtils_packInt64(*static_cast<int64_t*>(data));
    case at::kHalf:
      return PyFloat_FromDouble(
          at::convert<double, at::Half>(*static_cast<at::Half*>(data)));
    case at::kFloat: return PyFloat_FromDouble(*static_cast<float*>(data));
    case at::kDouble: return PyFloat_FromDouble(*static_cast<double*>(data));
    case at::kComplexFloat: {
      // Py_complex is a pair of doubles, but std::complex<float> is a pair of
      // floats. Reinterpreting one as the other would read 16 bytes from an
      // 8-byte element, so widen each component explicitly.
      const std::complex<float> value = *static_cast<std::complex<float>*>(data);
      return PyComplex_FromDoubles(value.real(), value.imag());
    }
    case at::kComplexDouble: {
      // Copy the components out instead of type-punning to Py_complex.
      const std::complex<double> value = *static_cast<std::complex<double>*>(data);
      return PyComplex_FromDoubles(value.real(), value.imag());
    }
    case at::kBool: return PyBool_FromLong(*static_cast<bool*>(data));
    default: throw std::runtime_error("invalid type");
  }
}
|
|
|
|
}} // namespace torch::utils
|