Add various uninterpreted bit tensor data types (try 2) (#95860)
Summary: This is a retry of https://github.com/pytorch/pytorch/pull/94992, which was reverted due to CI issues. This PR adds a set of uninterpreted data types to PyTorch that can be used to implement experimental functionality out of core (think fp8, int4, int16 quant, etc.).

@bypass-github-export-checks

Test Plan:
```
python test/test_quantization.py -k TestBits
```

Reviewers:
Subscribers:
Tasks:
Tags:

Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95860
Approved by: https://github.com/atalman
This commit is contained in:
parent 5e1067bcc2
commit dc70e8175f
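Before the diff itself, a minimal sketch (not part of the commit) of how out-of-core code is expected to use the new dtypes, based on the tests added below: the bit tensors are storage-only, so meaning is assigned by bit-casting to and from a regular dtype with `Tensor.view(dtype)`.

```python
import torch

# Storage-only allocation; no kernels are registered for the bit dtypes.
raw = torch.empty(20, dtype=torch.bits16)

# Out-of-core code assigns semantics by bit-casting to a regular dtype,
# computing there, and casting back (see test_bits.py further below).
as_bits = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
back = as_bits.view(torch.int16)
```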
```diff
@@ -60,6 +60,13 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::QUInt2x4:
       TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
       break;
+    case ScalarType::Bits1x8:
+    case ScalarType::Bits2x4:
+    case ScalarType::Bits4x2:
+    case ScalarType::Bits8:
+    case ScalarType::Bits16:
+      TORCH_CHECK(false, "Bit types are not supported by dlpack");
+      break;
     case ScalarType::Undefined:
       TORCH_CHECK(false, "Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
```
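The hunk above mirrors the existing quantized-type handling: exporting a bit tensor through DLPack is rejected. A hedged sketch of the expected Python-level behavior, assuming the `TORCH_CHECK` surfaces as a `RuntimeError` as it normally does:

```python
import torch
from torch.utils.dlpack import to_dlpack

t = torch.empty(4, dtype=torch.bits8)
try:
    to_dlpack(t)  # bit types have no DLDataType mapping
except RuntimeError as err:
    print(err)  # expected to contain "Bit types are not supported by dlpack"
```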
```diff
@@ -3,6 +3,7 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Half.h>
+#include <c10/util/bits.h>
 #include <c10/util/complex.h>
 #include <c10/util/qint32.h>
 #include <c10/util/qint8.h>
```
```diff
@@ -43,7 +44,12 @@ namespace c10 {
   _(c10::qint32, QInt32) /* 14 */ \
   _(at::BFloat16, BFloat16) /* 15 */ \
   _(c10::quint4x2, QUInt4x2) /* 16 */ \
-  _(c10::quint2x4, QUInt2x4) /* 17 */
+  _(c10::quint2x4, QUInt2x4) /* 17 */ \
+  _(c10::bits1x8, Bits1x8) /* 18 */ \
+  _(c10::bits2x4, Bits2x4) /* 19 */ \
+  _(c10::bits4x2, Bits4x2) /* 20 */ \
+  _(c10::bits8, Bits8) /* 21 */ \
+  _(c10::bits16, Bits16) /* 22 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
```
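The five new macro entries (indices 18 through 22) are what make the dtypes reachable from Python as `torch.bits*` objects. A quick check, assuming the Python bindings pick up the names from the `getDtypeNames()` change shown near the end of this diff:

```python
import torch

# The five new ScalarType entries surface as regular dtype objects.
new_dtypes = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
assert all(isinstance(d, torch.dtype) for d in new_dtypes)
```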
```diff
@@ -270,6 +276,12 @@ static inline bool isQIntType(ScalarType t) {
       t == ScalarType::QUInt2x4;
 }
 
+static inline bool isBitsType(ScalarType t) {
+  return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
+      t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
+      t == ScalarType::Bits16;
+}
+
 static inline ScalarType toQIntType(ScalarType t) {
   switch (t) {
     case ScalarType::Byte:
```
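For out-of-core Python code that wants the same predicate as the C++ `isBitsType()` above, a hypothetical helper (`is_bits_dtype` is not part of this PR, just an illustration):

```python
import torch

_BITS_DTYPES = {torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16}

def is_bits_dtype(dtype: torch.dtype) -> bool:
    # Mirrors the C++ isBitsType() check.
    return dtype in _BITS_DTYPES

assert is_bits_dtype(torch.bits8)
assert not is_bits_dtype(torch.int8)
```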
```diff
@@ -307,6 +319,12 @@ static inline bool isSignedType(ScalarType t) {
     return std::numeric_limits<ctype>::is_signed;
 
   switch (t) {
+    case ScalarType::Bits1x8:
+    case ScalarType::Bits2x4:
+    case ScalarType::Bits4x2:
+    case ScalarType::Bits8:
+    case ScalarType::Bits16:
+      TORCH_CHECK(false, "Bits types are undefined");
     case ScalarType::ComplexHalf:
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
```
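Signedness is deliberately left undefined for the bit types. Assuming `torch.dtype.is_signed` routes through `isSignedType()`, querying it on a bit dtype is expected to raise rather than return a value:

```python
import torch

print(torch.int8.is_signed)   # True
print(torch.uint8.is_signed)  # False

try:
    torch.bits8.is_signed     # signedness has no meaning for uninterpreted bits
except RuntimeError as err:
    print(err)                # expected: "Bits types are undefined"
```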
```diff
@@ -421,11 +439,24 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
         toString(b));
   }
 
+  if (isBitsType(a) && a == b) {
+    return a;
+  } else if (isBitsType(a) || isBitsType(b)) {
+    return ScalarType::Undefined;
+  }
+
+  // Ignore the 5 bits types, since they are handled by the if statement
+  // above and do not participate in type promotion. The `5` value has to
+  // be consistent with the number of the unique `c10::bits*` types that
+  // exist.
+  const int NUM_PROMOTE_TYPES = static_cast<int>(ScalarType::NumOptions) - 5;
+
   // this matrix has to be consistent with
   // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
   // are not sure about the correct value for type promotion.
-  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
-      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
       // clang-format off
+  static constexpr ScalarType _promoteTypesLookup[
+      NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
   /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
   /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
   /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
@@ -444,6 +475,7 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
   /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
   };
   // clang-format on
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
```
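Type promotion: a bit dtype only "promotes" with itself, and any mix with another dtype maps to `Undefined`, so the lookup table can stay 16x16. A hedged sketch of the expected Python-visible behavior, assuming `torch.promote_types` forwards to `promoteTypes()` and converts `Undefined` into an error:

```python
import torch

# Same bit dtype on both sides: the dtype is returned unchanged.
print(torch.promote_types(torch.bits16, torch.bits16))  # torch.bits16

# Mixing a bit dtype with anything else maps to Undefined in C++,
# which is expected to surface as an error here.
try:
    torch.promote_types(torch.bits16, torch.int16)
except RuntimeError as err:
    print(err)
```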
c10/util/bits.h (new file, 61 lines):

```cpp
#pragma once
#include <cstdint>

#include <c10/macros/Macros.h>

namespace c10 {

/**
 * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits1x8 {
  using underlying = uint8_t;
  uint8_t val_;
  bits1x8() = default;
  C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
};

/**
 * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits2x4 {
  using underlying = uint8_t;
  uint8_t val_;
  bits2x4() = default;
  C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
};

/**
 * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
 * boundary), without any semantics defined.
 */
struct alignas(1) bits4x2 {
  using underlying = uint8_t;
  uint8_t val_;
  bits4x2() = default;
  C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
};

/**
 * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
 * semantics defined.
 */
struct alignas(1) bits8 {
  uint8_t val_;
  bits8() = default;
  C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {}
};

/**
 * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any
 * semantics defined.
 */
struct alignas(2) bits16 {
  uint16_t val_;
  bits16() = default;
  C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {}
};

} // namespace c10
```
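The sub-byte structs pack their bits to a byte boundary (`underlying = uint8_t`), while `bits16` occupies two bytes; this is observable from Python via `Tensor.element_size()`:

```python
import torch

# Sub-byte dtypes are packed to a byte boundary, so storage granularity is 1 byte.
print(torch.empty(4, dtype=torch.bits1x8).element_size())  # 1
print(torch.empty(4, dtype=torch.bits8).element_size())    # 1
print(torch.empty(4, dtype=torch.bits16).element_size())   # 2
```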
test/quantization/core/experimental/test_bits.py (new file, 58 lines):

```python
# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

class Int16Tensor(torch.Tensor):
    def __new__(cls, elem):
        assert elem.dtype == torch.bits16
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    def __init__(self, elem):
        super().__init__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.int16)
            return t
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        with no_dispatch():
            out = func(*args, **kwargs)

        def wrap(t):
            if isinstance(t, torch.Tensor):
                with no_dispatch():
                    return t.view(torch.bits16)
            return t
        out = tree_map(wrap, out)
        return out

    def __repr__(self) -> str:
        with no_dispatch():
            t16 = self.view(torch.int16)
            return f"TensorSubclassDemo{self.view(torch.int16)}"


class TestBits(TestCase):
    def test_types(self):
        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
        for bits_type in bits_types:
            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
            _ = torch.empty(20, dtype=bits_type)

    def test_subclass(self):
        t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
        s = Int16Tensor(t)
        s = s + 1 - 1
        self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))


if __name__ == '__main__':
    run_tests()
```
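The `Int16Tensor` subclass above is the intended usage pattern: no kernels exist for the bit dtypes, so every op is bit-cast to a backed dtype, executed there, and cast back. The same round trip written out by hand, as a sketch:

```python
import torch

t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
# Bit-cast to a dtype that has kernels, compute, then bit-cast back.
result = (t.view(torch.int16) + 1 - 1).view(torch.bits16)
print(result.dtype)  # torch.bits16
```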
```diff
@@ -134,5 +134,8 @@ try:
 except ImportError:
     pass
 
+# Experimental functionality
+from quantization.core.experimental.test_bits import TestBits  # noqa: F401
+
 if __name__ == '__main__':
     run_tests()
```
```diff
@@ -52,6 +52,16 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
       return std::make_pair("quint4x2", "");
     case at::ScalarType::QUInt2x4:
       return std::make_pair("quint2x4", "");
+    case at::ScalarType::Bits1x8:
+      return std::make_pair("bits1x8", "");
+    case at::ScalarType::Bits2x4:
+      return std::make_pair("bits2x4", "");
+    case at::ScalarType::Bits4x2:
+      return std::make_pair("bits4x2", "");
+    case at::ScalarType::Bits8:
+      return std::make_pair("bits8", "");
+    case at::ScalarType::Bits16:
+      return std::make_pair("bits16", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
```
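These are the names the new dtypes go by at the Python level; a quick check, assuming the dtype attributes and their repr are built from `getDtypeNames()`:

```python
import torch

print(torch.bits8)    # torch.bits8
print(torch.bits1x8)  # torch.bits1x8
print(torch.bits16)   # torch.bits16
```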