mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Support float4 (#151069)
- Support exporting float4 models (note: currently we use IR version 10 universally in the exporter, which does not include float 4 support. Eventually when onnx runtime and the ecosystem moves to support the new IR version 11 we should bump our version to 11 in the exporter as well) - The shape of the type is set according to https://github.com/pytorch/pytorch/pull/148791#discussion_r2038704986 (added last dim with size 2) - Use ml_dtypes types when converting to numpy for consistency with ONNX IR Fix https://github.com/pytorch/pytorch/issues/150202 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151069 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
8568dbce1d
commit
0e805aad7f
|
|
@ -35,9 +35,13 @@ class TorchTensorTest(common_utils.TestCase):
|
|||
(torch.uint32, np.uint32),
|
||||
(torch.uint64, np.uint64),
|
||||
(torch.uint8, np.uint8),
|
||||
(torch.float4_e2m1fn_x2, ml_dtypes.float4_e2m1fn),
|
||||
],
|
||||
)
|
||||
def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
|
||||
if dtype == torch.float4_e2m1fn_x2:
|
||||
tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(dtype))
|
||||
else:
|
||||
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
|
||||
self.assertEqual(tensor.numpy().dtype, np_dtype)
|
||||
self.assertEqual(tensor.__array__().dtype, np_dtype)
|
||||
|
|
@ -71,6 +75,12 @@ class TorchTensorTest(common_utils.TestCase):
|
|||
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
|
||||
self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())
|
||||
|
||||
def test_tobytes_float4(self):
|
||||
tensor = _core.TorchTensor(
|
||||
torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
)
|
||||
self.assertEqual(tensor.tobytes(), b"\x01")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
|
|
|||
|
|
@ -216,7 +216,19 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
input = input.to(float8_type)
|
||||
return input
|
||||
|
||||
_ = self.export(Float8Module(), (torch.randn(1, 2),))
|
||||
onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
|
||||
self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)
|
||||
|
||||
def test_float4_support(self):
|
||||
class Float4Module(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.empty([1], dtype=torch.float4_e2m1fn_x2)
|
||||
|
||||
onnx_program = self.export(Float4Module())
|
||||
output = onnx_program.model.graph.outputs[0]
|
||||
self.assertEqual(output.dtype, ir.DataType.FLOAT4E2M1)
|
||||
# The shape is [*shape, 2] because ONNX stores the shape of the unpacked tensor
|
||||
self.assertEqual(output.shape.dims, [1, 2])
|
||||
|
||||
def test_bfloat16_support(self):
|
||||
class BfloatModel(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from torch.onnx._internal.exporter import (
|
|||
_registration,
|
||||
_reporting,
|
||||
_tensors,
|
||||
_type_casting,
|
||||
_verification,
|
||||
)
|
||||
|
||||
|
|
@ -61,6 +62,7 @@ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
|
|||
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
|
||||
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
|
||||
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
|
||||
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
|
||||
torch.int16: ir.DataType.INT16,
|
||||
torch.int32: ir.DataType.INT32,
|
||||
torch.int64: ir.DataType.INT64,
|
||||
|
|
@ -109,8 +111,17 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
|
|||
class TorchTensor(ir.Tensor):
|
||||
def __init__(self, tensor: torch.Tensor, name: str | None = None):
|
||||
# Pass the tensor as the raw data to ir.Tensor's constructor
|
||||
if tensor.dtype == torch.float4_e2m1fn_x2:
|
||||
# Change the shape to the unpacked shape
|
||||
shape = ir.Shape(_type_casting.get_float4_shape(tensor), frozen=True)
|
||||
else:
|
||||
# The base class will set the shape to the tensor's shape
|
||||
shape = None
|
||||
super().__init__(
|
||||
tensor, dtype=torch_dtype_to_onnx_dtype(tensor.dtype), name=name
|
||||
tensor,
|
||||
dtype=torch_dtype_to_onnx_dtype(tensor.dtype),
|
||||
shape=shape,
|
||||
name=name,
|
||||
)
|
||||
|
||||
def numpy(self) -> npt.NDArray:
|
||||
|
|
@ -132,6 +143,10 @@ class TorchTensor(ir.Tensor):
|
|||
ir.DataType.FLOAT8E5M2FNUZ,
|
||||
}:
|
||||
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
|
||||
if self.dtype == ir.DataType.FLOAT4E2M1:
|
||||
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
|
||||
self.dtype.numpy()
|
||||
)
|
||||
|
||||
return self.raw.numpy(force=True)
|
||||
|
||||
|
|
@ -213,7 +228,13 @@ def _set_shape_type(
|
|||
logger.warning("Setting shape and type of tensors is not supported yet")
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
dims = []
|
||||
for dim in meta_val.shape:
|
||||
shape: tuple[int, ...]
|
||||
if meta_val.dtype == torch.float4_e2m1fn_x2:
|
||||
# Change the shape to the unpacked shape
|
||||
shape = _type_casting.get_float4_shape(meta_val)
|
||||
else:
|
||||
shape = meta_val.shape
|
||||
for dim in shape:
|
||||
if isinstance(dim, int):
|
||||
dims.append(dim)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ _TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = {
|
|||
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
|
||||
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
|
||||
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
|
||||
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
|
||||
torch.int16: ir.DataType.INT16,
|
||||
torch.int32: ir.DataType.INT32,
|
||||
torch.int64: ir.DataType.INT64,
|
||||
|
|
@ -95,6 +96,7 @@ def _param_type_compatible_with_arg(
|
|||
ir.TensorType(ir.DataType.INT32),
|
||||
ir.TensorType(ir.DataType.INT64),
|
||||
# Int inputs can be casted to a float too
|
||||
ir.TensorType(ir.DataType.FLOAT4E2M1),
|
||||
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
|
||||
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
|
||||
ir.TensorType(ir.DataType.FLOAT8E5M2),
|
||||
|
|
@ -105,6 +107,7 @@ def _param_type_compatible_with_arg(
|
|||
}:
|
||||
return True
|
||||
if isinstance(value, float) and param.type_constraint.allowed_types & {
|
||||
ir.TensorType(ir.DataType.FLOAT4E2M1),
|
||||
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
|
||||
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
|
||||
ir.TensorType(ir.DataType.FLOAT8E5M2),
|
||||
|
|
|
|||
32
torch/onnx/_internal/exporter/_type_casting.py
Normal file
32
torch/onnx/_internal/exporter/_type_casting.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def unpack_float4x2_as_uint8(tensor: torch.Tensor) -> np.ndarray:
|
||||
"""Convert a float4x2 tensor to unpacked uint8 np array."""
|
||||
assert tensor.dtype == torch.float4_e2m1fn_x2
|
||||
data = tensor.view(torch.uint8).numpy(force=True).flatten()
|
||||
result_size = tensor.numel() * 2
|
||||
result = np.empty([result_size], dtype=np.uint8)
|
||||
array_low = data & np.uint8(0x0F)
|
||||
array_high = data & np.uint8(0xF0)
|
||||
array_high >>= np.uint8(4)
|
||||
result[0::2] = array_low
|
||||
result[1::2] = array_high
|
||||
result.resize(get_float4_shape(tensor), refcheck=False)
|
||||
return result
|
||||
|
||||
|
||||
def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]:
|
||||
"""Get the shape of an unpacked float4 tensor.
|
||||
|
||||
The float4_e2m1fn_x2 type is a shell type described in
|
||||
https://github.com/pytorch/pytorch/issues/146414.
|
||||
|
||||
the shell dtype is takes up 1 byte per element and semantically represents
|
||||
two fp4 values packed into 1 byte. Semantically it represents (*tensor.shape, 2)
|
||||
fp4 elements.
|
||||
"""
|
||||
assert tensor.dtype == torch.float4_e2m1fn_x2
|
||||
return (*tensor.shape, 2)
|
||||
|
|
@ -38,6 +38,9 @@ _TORCH_DTYPE_TO_ONNX_DTYPE = {
|
|||
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
|
||||
torch.float8_e5m2: 19, # FLOAT8E5M2
|
||||
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
|
||||
# 21 = UINT4
|
||||
# 22 = INT4
|
||||
torch.float4_e2m1fn_x2: 23, # FLOAT4E2M1
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ _ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
|
|||
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
|
||||
21: torch.uint8, # UINT4
|
||||
22: torch.uint8, # INT4
|
||||
23: torch.uint8, # FLOAT4E2M1
|
||||
23: torch.float4_e2m1fn_x2, # FLOAT4E2M1
|
||||
}
|
||||
|
||||
_INT_TYPE = "i"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user