[ONNX] Support float4 (#151069)
- Support exporting float4 models. (Note: the exporter currently uses IR version 10 universally, which does not include float4 support. Once ONNX Runtime and the ecosystem move to the new IR version 11, we should bump the exporter's version to 11 as well.)
- The shape of the type is set according to https://github.com/pytorch/pytorch/pull/148791#discussion_r2038704986 (a last dimension of size 2 is appended).
- Use ml_dtypes types when converting to numpy, for consistency with ONNX IR.

Fixes https://github.com/pytorch/pytorch/issues/150202
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151069
Approved by: https://github.com/titaiwangms
Parent: 8568dbce1d
Commit: 0e805aad7f
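As an illustration of the packing convention this PR adopts (a sketch, not part of the diff; it assumes a PyTorch build that defines the torch.float4_e2m1fn_x2 shell dtype):

    import torch

    # Each float4_e2m1fn_x2 element occupies one byte and packs two fp4 values,
    # so a packed tensor of shape [3, 4] semantically holds [3, 4, 2] fp4 elements.
    packed = torch.zeros(3, 4, dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
    print(packed.shape)  # torch.Size([3, 4]) -- packed, one byte per element
    # The exporter reports the unpacked ONNX shape: (*packed.shape, 2) -> [3, 4, 2]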
@@ -35,9 +35,13 @@ class TorchTensorTest(common_utils.TestCase):
             (torch.uint32, np.uint32),
             (torch.uint64, np.uint64),
             (torch.uint8, np.uint8),
+            (torch.float4_e2m1fn_x2, ml_dtypes.float4_e2m1fn),
         ],
     )
     def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
-        tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
+        if dtype == torch.float4_e2m1fn_x2:
+            tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(dtype))
+        else:
+            tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
         self.assertEqual(tensor.numpy().dtype, np_dtype)
         self.assertEqual(tensor.__array__().dtype, np_dtype)
@@ -71,6 +75,12 @@ class TorchTensorTest(common_utils.TestCase):
         tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
         self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())

+    def test_tobytes_float4(self):
+        tensor = _core.TorchTensor(
+            torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
+        )
+        self.assertEqual(tensor.tobytes(), b"\x01")
+

 if __name__ == "__main__":
     common_utils.run_tests()
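Why the expected bytes are b"\x01": tobytes() serializes the raw packed byte unchanged, while numpy() (tested above) unpacks each byte into two fp4 nibbles, low nibble first. A minimal numpy illustration of the nibble split (variable names here are illustrative only):

    import numpy as np

    data = np.array([0x01], dtype=np.uint8)
    low, high = data & 0x0F, data >> 4  # low nibble first, then high nibble
    print(low[0], high[0])              # 1 0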
@@ -216,7 +216,19 @@ class DynamoExporterTest(common_utils.TestCase):
                 input = input.to(float8_type)
                 return input

-        _ = self.export(Float8Module(), (torch.randn(1, 2),))
+        onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
+        self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)
+
+    def test_float4_support(self):
+        class Float4Module(torch.nn.Module):
+            def forward(self):
+                return torch.empty([1], dtype=torch.float4_e2m1fn_x2)
+
+        onnx_program = self.export(Float4Module())
+        output = onnx_program.model.graph.outputs[0]
+        self.assertEqual(output.dtype, ir.DataType.FLOAT4E2M1)
+        # The shape is [*shape, 2] because ONNX stores the shape of the unpacked tensor
+        self.assertEqual(output.shape.dims, [1, 2])

     def test_bfloat16_support(self):
         class BfloatModel(torch.nn.Module):
@@ -38,6 +38,7 @@ from torch.onnx._internal.exporter import (
     _registration,
     _reporting,
     _tensors,
+    _type_casting,
     _verification,
 )
@@ -61,6 +62,7 @@ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
+   torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
    torch.int16: ir.DataType.INT16,
    torch.int32: ir.DataType.INT32,
    torch.int64: ir.DataType.INT64,
@@ -109,8 +111,17 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
 class TorchTensor(ir.Tensor):
     def __init__(self, tensor: torch.Tensor, name: str | None = None):
         # Pass the tensor as the raw data to ir.Tensor's constructor
+        if tensor.dtype == torch.float4_e2m1fn_x2:
+            # Change the shape to the unpacked shape
+            shape = ir.Shape(_type_casting.get_float4_shape(tensor), frozen=True)
+        else:
+            # The base class will set the shape to the tensor's shape
+            shape = None
         super().__init__(
-            tensor, dtype=torch_dtype_to_onnx_dtype(tensor.dtype), name=name
+            tensor,
+            dtype=torch_dtype_to_onnx_dtype(tensor.dtype),
+            shape=shape,
+            name=name,
         )

     def numpy(self) -> npt.NDArray:
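A usage sketch mirroring the tests above (assuming the internal module path from this PR; internal APIs may change):

    import torch
    from torch.onnx._internal.exporter import _core

    packed = torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
    t = _core.TorchTensor(packed)
    print(t.dtype)      # ir.DataType.FLOAT4E2M1
    print(t.shape)      # frozen unpacked shape [1, 2]
    print(t.tobytes())  # b"\x01" -- the raw data stays packed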
@@ -132,6 +143,10 @@ class TorchTensor(ir.Tensor):
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
             return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+        if self.dtype == ir.DataType.FLOAT4E2M1:
+            return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
+                self.dtype.numpy()
+            )
         return self.raw.numpy(force=True)
@@ -213,7 +228,13 @@ def _set_shape_type(
         logger.warning("Setting shape and type of tensors is not supported yet")
     if isinstance(meta_val, torch.Tensor):
         dims = []
-        for dim in meta_val.shape:
+        shape: tuple[int, ...]
+        if meta_val.dtype == torch.float4_e2m1fn_x2:
+            # Change the shape to the unpacked shape
+            shape = _type_casting.get_float4_shape(meta_val)
+        else:
+            shape = meta_val.shape
+        for dim in shape:
             if isinstance(dim, int):
                 dims.append(dim)
             else:
@@ -27,6 +27,7 @@ _TORCH_DTYPE_TO_ONNX_COMPATIBLE: dict[torch.dtype, ir.DataType] = {
     torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
     torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
     torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
+    torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
     torch.int16: ir.DataType.INT16,
     torch.int32: ir.DataType.INT32,
     torch.int64: ir.DataType.INT64,
@@ -95,6 +96,7 @@ def _param_type_compatible_with_arg(
         ir.TensorType(ir.DataType.INT32),
         ir.TensorType(ir.DataType.INT64),
         # Int inputs can be casted to a float too
+        ir.TensorType(ir.DataType.FLOAT4E2M1),
         ir.TensorType(ir.DataType.FLOAT8E4M3FN),
         ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
         ir.TensorType(ir.DataType.FLOAT8E5M2),
@@ -105,6 +107,7 @@ def _param_type_compatible_with_arg(
     }:
         return True
     if isinstance(value, float) and param.type_constraint.allowed_types & {
+        ir.TensorType(ir.DataType.FLOAT4E2M1),
         ir.TensorType(ir.DataType.FLOAT8E4M3FN),
         ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
         ir.TensorType(ir.DataType.FLOAT8E5M2),
torch/onnx/_internal/exporter/_type_casting.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import numpy as np
+
+import torch
+
+
+def unpack_float4x2_as_uint8(tensor: torch.Tensor) -> np.ndarray:
+    """Convert a float4x2 tensor to an unpacked uint8 np array."""
+    assert tensor.dtype == torch.float4_e2m1fn_x2
+    data = tensor.view(torch.uint8).numpy(force=True).flatten()
+    result_size = tensor.numel() * 2
+    result = np.empty([result_size], dtype=np.uint8)
+    array_low = data & np.uint8(0x0F)
+    array_high = data & np.uint8(0xF0)
+    array_high >>= np.uint8(4)
+    result[0::2] = array_low
+    result[1::2] = array_high
+    result.resize(get_float4_shape(tensor), refcheck=False)
+    return result
+
+
+def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]:
+    """Get the shape of an unpacked float4 tensor.
+
+    The float4_e2m1fn_x2 type is a shell type described in
+    https://github.com/pytorch/pytorch/issues/146414.
+
+    The shell dtype takes up 1 byte per element and packs two fp4 values
+    into that byte, so it semantically represents (*tensor.shape, 2) fp4 elements.
+    """
+    assert tensor.dtype == torch.float4_e2m1fn_x2
+    return (*tensor.shape, 2)
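A worked example of the two helpers (a sketch; it assumes the module is importable at this internal path):

    import torch
    from torch.onnx._internal.exporter import _type_casting

    # 0x2B packs two fp4 payloads: low nibble 0xB (11) and high nibble 0x2 (2).
    packed = torch.tensor([0x2B], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
    print(_type_casting.get_float4_shape(packed))          # (1, 2)
    print(_type_casting.unpack_float4x2_as_uint8(packed))  # [[11  2]] -- low nibble first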
@@ -38,6 +38,9 @@ _TORCH_DTYPE_TO_ONNX_DTYPE = {
    torch.float8_e4m3fnuz: 18,  # FLOAT8E4M3FNUZ
    torch.float8_e5m2: 19,  # FLOAT8E5M2
    torch.float8_e5m2fnuz: 20,  # FLOAT8E5M2FNUZ
+   # 21 = UINT4
+   # 22 = INT4
+   torch.float4_e2m1fn_x2: 23,  # FLOAT4E2M1
 }
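The hardcoded IDs follow the onnx.TensorProto.DataType enum; a quick sanity check, assuming an onnx release new enough to define FLOAT4E2M1:

    import onnx

    assert onnx.TensorProto.UINT4 == 21
    assert onnx.TensorProto.INT4 == 22
    assert onnx.TensorProto.FLOAT4E2M1 == 23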
@@ -40,7 +40,7 @@ _ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
    20: torch.float8_e5m2fnuz,  # FLOAT8E5M2FNUZ
    21: torch.uint8,  # UINT4
    22: torch.uint8,  # INT4
-   23: torch.uint8,  # FLOAT4E2M1
+   23: torch.float4_e2m1fn_x2,  # FLOAT4E2M1
 }

 _INT_TYPE = "i"