Add various uninterpreted bit tensor data types (try 2) (#95860)

Summary:

This is a retry of https://github.com/pytorch/pytorch/pull/94992 which was reverted due to CI issues.

This PR adds a set of uninterpreted data types to PyTorch, which can be used to implement experimental functionality out of core (think fp8, int4/int16 quantization, etc.).
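For example, a minimal sketch of how out-of-core code might use one of these dtypes (the `view` round-trip mirrors the test included in this PR; what the bits mean is entirely up to the user):

```
import torch

# Reinterpret int16 storage as the uninterpreted bits16 dtype. PyTorch
# attaches no meaning to the payload; out-of-core code decides what the
# bits represent (a custom fp8 format, packed int4 values, etc.).
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
print(t.dtype)  # torch.bits16
```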

@bypass-github-export-checks

Test Plan:

```
python test/test_quantization.py -k TestBits
```


Pull Request resolved: https://github.com/pytorch/pytorch/pull/95860
Approved by: https://github.com/atalman
Author: vasiliy
Committed: 2023-03-04 03:35:59 +00:00 by PyTorch MergeBot
Parent: 5e1067bcc2
Commit: dc70e8175f
6 changed files with 174 additions and 3 deletions

aten/src/ATen/DLConvertor.cpp

@@ -60,6 +60,13 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::QUInt2x4:
       TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
       break;
+    case ScalarType::Bits1x8:
+    case ScalarType::Bits2x4:
+    case ScalarType::Bits4x2:
+    case ScalarType::Bits8:
+    case ScalarType::Bits16:
+      TORCH_CHECK(false, "Bit types are not supported by dlpack");
+      break;
     case ScalarType::Undefined:
       TORCH_CHECK(false, "Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
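Consequently, exporting a bits tensor through dlpack should fail at this check. A quick sketch (assuming the `TORCH_CHECK` surfaces as a `RuntimeError` on the Python side):

```
import torch
from torch.utils.dlpack import to_dlpack

t = torch.zeros(8, dtype=torch.int16).view(torch.bits16)
try:
    to_dlpack(t)
except RuntimeError as e:
    print(e)  # expected to mention "Bit types are not supported by dlpack"
```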

c10/core/ScalarType.h

@@ -3,6 +3,7 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Half.h>
+#include <c10/util/bits.h>
 #include <c10/util/complex.h>
 #include <c10/util/qint32.h>
 #include <c10/util/qint8.h>
@@ -43,7 +44,12 @@ namespace c10 {
   _(c10::qint32, QInt32) /* 14 */     \
   _(at::BFloat16, BFloat16) /* 15 */  \
   _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(c10::quint2x4, QUInt2x4) /* 17 */ \
+  _(c10::bits1x8, Bits1x8) /* 18 */   \
+  _(c10::bits2x4, Bits2x4) /* 19 */   \
+  _(c10::bits4x2, Bits4x2) /* 20 */   \
+  _(c10::bits8, Bits8) /* 21 */       \
+  _(c10::bits16, Bits16) /* 22 */

 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -270,6 +276,12 @@ static inline bool isQIntType(ScalarType t) {
       t == ScalarType::QUInt2x4;
 }

+static inline bool isBitsType(ScalarType t) {
+  return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
+      t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
+      t == ScalarType::Bits16;
+}
+
 static inline ScalarType toQIntType(ScalarType t) {
   switch (t) {
     case ScalarType::Byte:
@@ -307,6 +319,12 @@ static inline bool isSignedType(ScalarType t) {
     return std::numeric_limits<ctype>::is_signed;

   switch (t) {
+    case ScalarType::Bits1x8:
+    case ScalarType::Bits2x4:
+    case ScalarType::Bits4x2:
+    case ScalarType::Bits8:
+    case ScalarType::Bits16:
+      TORCH_CHECK(false, "Bits types are undefined");
     case ScalarType::ComplexHalf:
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
@@ -421,11 +439,24 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         toString(b));
   }

+  if (isBitsType(a) && a == b) {
+    return a;
+  } else if (isBitsType(a) || isBitsType(b)) {
+    return ScalarType::Undefined;
+  }
+
+  // Ignore the 5 bits types, since they are handled by the if statement
+  // above and do not participate in type promotion. The `5` value has to
+  // be consistent with the number of the unique `c10::bits*` types that
+  // exist.
+  const int NUM_PROMOTE_TYPES = static_cast<int>(ScalarType::NumOptions) - 5;
+
   // this matrix has to be consistent with
   // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
   // are not sure about the correct value for type promotion.
-  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
-      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
+  // clang-format off
+  static constexpr ScalarType _promoteTypesLookup[
+      NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
       /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
       /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
       /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
@@ -444,6 +475,7 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
       /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
       /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
   };
+  // clang-format on
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
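In Python terms, the promotion rule added above amounts to the following sketch (`promote_with_bits` is a hypothetical helper for illustration, not an actual binding):

```
import torch

BITS_TYPES = {torch.bits1x8, torch.bits2x4, torch.bits4x2,
              torch.bits8, torch.bits16}

def promote_with_bits(a, b):
    # Identical bits dtypes promote to themselves; mixing a bits dtype
    # with anything else has no defined promotion (Undefined in C++).
    if a in BITS_TYPES and a is b:
        return a
    if a in BITS_TYPES or b in BITS_TYPES:
        return None  # stands in for ScalarType::Undefined
    return torch.promote_types(a, b)
```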

c10/util/bits.h (new file, 61 lines)

@@ -0,0 +1,61 @@
#pragma once
#include <cstdint>

#include <c10/macros/Macros.h>

namespace c10 {

/**
 * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits1x8 {
  using underlying = uint8_t;
  uint8_t val_;
  bits1x8() = default;
  C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
};

/**
 * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits2x4 {
  using underlying = uint8_t;
  uint8_t val_;
  bits2x4() = default;
  C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
};

/**
 * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits4x2 {
  using underlying = uint8_t;
  uint8_t val_;
  bits4x2() = default;
  C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
};

/**
 * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
 * semantics defined.
 */
struct alignas(1) bits8 {
  uint8_t val_;
  bits8() = default;
  C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {}
};

/**
 * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any
 * semantics defined.
 */
struct alignas(2) bits16 {
  uint16_t val_;
  bits16() = default;
  C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {}
};

} // namespace c10
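To make the packed layouts concrete, here is a sketch of packing two 4-bit values per byte and viewing the result as `bits4x2` (the nibble ordering is an assumption of this example; these dtypes deliberately define no packing semantics):

```
import torch

# Eight 4-bit values packed into four bytes, high nibble first (a
# convention chosen here for illustration only).
nibbles = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.uint8)
packed = (nibbles[0::2] << 4) | nibbles[1::2]
bits = packed.view(torch.bits4x2)  # same itemsize, so the view is cheap
```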

test/quantization/core/experimental/test_bits.py (new file)

@@ -0,0 +1,58 @@
# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map


class Int16Tensor(torch.Tensor):
    def __new__(cls, elem):
        assert elem.dtype == torch.bits16
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    def __init__(self, elem):
        super().__init__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.int16)
            return t
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        with no_dispatch():
            out = func(*args, **kwargs)

        def wrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.bits16)
            return t
        out = tree_map(wrap, out)
        return out

    def __repr__(self) -> str:
        with no_dispatch():
            t16 = self.view(torch.int16)
            return f"TensorSubclassDemo{t16}"


class TestBits(TestCase):
    def test_types(self):
        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
        for bits_type in bits_types:
            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
            _ = torch.empty(20, dtype=bits_type)

    def test_subclass(self):
        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
        s = Int16Tensor(t)
        s = s + 1 - 1
        self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))


if __name__ == '__main__':
    run_tests()
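Because no kernels are registered for the bits dtypes, ops on a raw `bits16` tensor are expected to fail; the subclass above makes them usable by round-tripping through `int16`. A hedged sketch:

```
import torch

t = torch.zeros(4, dtype=torch.int16).view(torch.bits16)
try:
    _ = t + 1  # no kernels exist for bits16, so this should raise
except Exception as e:
    print(type(e).__name__)

s = Int16Tensor(t)  # the subclass defined in the test above
print(s + 1)        # works: dispatch unwraps to int16 and rewraps
```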

test/test_quantization.py

@@ -134,5 +134,8 @@ try:
 except ImportError:
     pass

+# Experimental functionality
+from quantization.core.experimental.test_bits import TestBits  # noqa: F401
+
 if __name__ == '__main__':
     run_tests()

torch/csrc/utils/tensor_dtypes.cpp

@@ -52,6 +52,16 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
       return std::make_pair("quint4x2", "");
     case at::ScalarType::QUInt2x4:
       return std::make_pair("quint2x4", "");
+    case at::ScalarType::Bits1x8:
+      return std::make_pair("bits1x8", "");
+    case at::ScalarType::Bits2x4:
+      return std::make_pair("bits2x4", "");
+    case at::ScalarType::Bits4x2:
+      return std::make_pair("bits4x2", "");
+    case at::ScalarType::Bits8:
+      return std::make_pair("bits8", "");
+    case at::ScalarType::Bits16:
+      return std::make_pair("bits16", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
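With these names registered, the new dtypes are addressable from Python as `torch.*` attributes:

```
import torch

for name in ("bits1x8", "bits2x4", "bits4x2", "bits8", "bits16"):
    print(getattr(torch, name))  # e.g. torch.bits1x8
```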