[ONNX] Add BFloat16 type support when exporting to ONNX (#66788)
Summary:
- Both PyTorch and ONNX support BFloat16; add it to the exporter to unblock some mixed-precision training models.
- Lets the PyTorch TNLG model use BFloat16 tensors for the inputs/outputs of the layers that run on the NPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66788
Reviewed By: jansel
Differential Revision: D32283510
Pulled By: malfet
fbshipit-source-id: 150d69b1465b2b917dd6554505eca58042c1262a
This commit is contained in:
parent 800a457b6f
commit c76c6e9bd3
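As a rough sketch of what this change enables (not part of the commit itself): a module that casts to bfloat16 internally can now be exported, provided the target opset is 13 or higher. The module name, tensor shapes, and output filename below are invented for illustration.

import torch

class BF16Layer(torch.nn.Module):
    # Mirrors the pattern exercised by the new tests: cast the input to
    # bfloat16, do some arithmetic, then cast back to a narrower float type.
    def forward(self, x):
        y = torch.ones(3, 4, dtype=torch.bfloat16)
        return (x.type_as(y) + y).to(dtype=torch.float16)

x = torch.ones(3, 4, dtype=torch.float16)
# ONNX only represents bfloat16 in opset 13 and later, so request that explicitly.
torch.onnx.export(BF16Layer(), x, "bf16_layer.onnx", opset_version=13)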
@@ -39,6 +39,9 @@ skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(),
 skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"),
                         "Skip In Travis")

+skipIfNoBFloat16Cuda = _skipper(lambda: not torch.cuda.is_bf16_supported(),
+                                "BFloat16 CUDA is not available")
+
 # skips tests for all versions below min_opset_version.
 # if exporting the op is only supported after a specific version,
 # add this wrapper to prevent running the test for opset_versions
@@ -6450,6 +6450,19 @@ class TestONNXRuntime(unittest.TestCase):
         model = MyModule()
         self.run_test(model, (x, y))

+    # ONNX supports bfloat16 for opsets >= 13
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_cast_type_as_with_bfloat16(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ones((3, 4), dtype=torch.bfloat16)
+                x = x.type_as(y)
+                return x.to(dtype=torch.float16)
+
+        x = torch.ones(3, 4, dtype=torch.float16)
+        model = MyModule()
+        self.run_test(model, x)
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_type_as(self):
         class MyModule(torch.nn.Module):
@@ -7,7 +7,7 @@ import torch
 from torch.cuda.amp import autocast

 from test_pytorch_common import disableScriptTest, skipIfUnsupportedMinOpsetVersion
-from test_pytorch_common import skipIfNoCuda
+from test_pytorch_common import skipIfNoCuda, skipIfNoBFloat16Cuda

 from test_pytorch_onnx_onnxruntime import TestONNXRuntime

@@ -85,6 +85,20 @@ class TestONNXRuntime_cuda(unittest.TestCase):
         model = amp.initialize(LinearModel(), opt_level="O2")
         self.run_test(model, input)

+    # ONNX supports bfloat16 for opsets >= 13
+    # Add, Sub and Mul ops don't support bfloat16 cpu in onnxruntime.
+    @skipIfUnsupportedMinOpsetVersion(13)
+    @skipIfNoBFloat16Cuda
+    def test_arithmetic_bfp16(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ones(3, 4, dtype=torch.bfloat16, device=torch.device("cuda"))
+                x = x.type_as(y)
+                return torch.mul(torch.add(x, y), torch.sub(x, y)).to(dtype=torch.float16)
+
+        x = torch.ones(3, 4, requires_grad=True, dtype=torch.float16, device=torch.device("cuda"))
+        self.run_test(MyModule(), x, rtol=1e-3, atol=1e-5)
+
 TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
 TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test

@@ -22,6 +22,7 @@ class TensorProtoDataType(Enum):
     UINT64 = ...
     COMPLEX64 = ...
     COMPLEX128 = ...
+    BFLOAT16 = ...

 class OperatorExportTypes(Enum):
     ONNX = ...
@@ -31,6 +31,10 @@ static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
     {c10::kBool, 9},
     {c10::kHalf, 10},
     {c10::kDouble, 11},
+    {c10::kQInt8, 12},
+    {c10::kQUInt8, 13},
+    {c10::kQInt32, 14},
+    {c10::kBFloat16, 15},
 };

 static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {
@@ -392,6 +392,8 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
       return onnx::TensorProto_DataType_UINT8;
     case at::kQInt32:
       return onnx::TensorProto_DataType_INT32;
+    case at::kBFloat16:
+      return onnx::TensorProto_DataType_BFLOAT16;
     default:
       AT_ERROR("unexpected tensor scalar type");
   }
@@ -24,7 +24,8 @@ void initONNXBindings(PyObject* module) {
       .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
      .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
       .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
-      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128);
+      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
+      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);

   py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
       .value("ONNX", OperatorExportTypes::ONNX)
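With the binding above in place, the new enum member becomes visible from Python. A quick sanity check (a sketch, assuming a build that contains this commit):

import torch

# TensorProtoDataType is the Python-facing name of the pybind11 enum registered above;
# BFLOAT16 should now be one of its members.
print(torch.onnx.TensorProtoDataType.BFLOAT16)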
@@ -317,18 +317,12 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)
 def _is_fp(value):
     if value:
         if isinstance(value, torch.Tensor):
-            type = value.dtype
-            return (type == "torch.float32") or (type == "torch.float64") or (type == "torch.float16")
+            return value.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16)
         else:
             type = value.type().scalarType()
             if type is None:
                 warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")
-            return (type == "Float") or (type == "Double") or (type == "Half")
+            return type in ("Float", "Double", "Half", "BFloat16")
     return False

-def _dtype_is_fp(type_value):
-    if type_value:
-        return (type_value == torch.float16) or (type_value == torch.float32) or (type_value == torch.float64)
-    return False
-
 def _generate_wrapped_number(g, scalar):
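The tensor branch of the updated helper can be exercised directly; a minimal sketch (symbolic_helper is an internal module, so the import path simply reflects the layout at the time of this commit). Scalar tensors are used because the helper truth-tests its argument.

import torch
from torch.onnx import symbolic_helper

# bfloat16 now counts as a floating-point type for the exporter.
print(symbolic_helper._is_fp(torch.ones(1, dtype=torch.bfloat16)))  # expected: True
print(symbolic_helper._is_fp(torch.ones(1, dtype=torch.int32)))     # expected: False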
@@ -893,6 +887,7 @@ cast_pytorch_to_onnx = {
     "Bool": torch.onnx.TensorProtoDataType.BOOL,
     "ComplexFloat": torch.onnx.TensorProtoDataType.COMPLEX64,
     "ComplexDouble": torch.onnx.TensorProtoDataType.COMPLEX128,
+    "BFloat16": torch.onnx.TensorProtoDataType.BFLOAT16,
     "Undefined": torch.onnx.TensorProtoDataType.UNDEFINED,
 }

@@ -907,7 +902,11 @@ scalar_name_to_pytorch = {
     "int16_t": "Short",
     "bool": "Bool",
     "complex64": "ComplexFloat",
-    "complex128": "ComplexDouble"
+    "complex128": "ComplexDouble",
+    "qint8": "QInt8",
+    "quint8": "QUInt8",
+    "qint32": "QInt32",
+    "bfloat16": "BFloat16",
 }

@@ -934,7 +933,7 @@ class ScalarType(enum.IntEnum):

 # This indicates each scalar type's corresponding
 # torch type. Related source:
-# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
+# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
 scalar_type_to_pytorch_type = [
     torch.uint8,  # 0
     torch.int8,  # 1
@@ -948,6 +947,10 @@ scalar_type_to_pytorch_type = [
     torch.complex64,  # 9
     torch.complex128,  # 10
     torch.bool,  # 11
+    torch.qint8,  # 12
+    torch.quint8,  # 13
+    torch.qint32,  # 14
+    torch.bfloat16,  # 15
 ]

 def _cast_func_template(to_i, g, input, non_blocking):
@@ -955,18 +958,22 @@ def _cast_func_template(to_i, g, input, non_blocking):


 scalar_type_to_onnx = [
-    cast_pytorch_to_onnx["Byte"],
-    cast_pytorch_to_onnx["Char"],
-    cast_pytorch_to_onnx["Short"],
-    cast_pytorch_to_onnx["Int"],
-    cast_pytorch_to_onnx["Long"],
-    cast_pytorch_to_onnx["Half"],
-    cast_pytorch_to_onnx["Float"],
-    cast_pytorch_to_onnx["Double"],
-    cast_pytorch_to_onnx["Undefined"],
-    cast_pytorch_to_onnx["ComplexFloat"],
-    cast_pytorch_to_onnx["ComplexDouble"],
-    cast_pytorch_to_onnx["Bool"],
+    cast_pytorch_to_onnx["Byte"],  # 0
+    cast_pytorch_to_onnx["Char"],  # 1
+    cast_pytorch_to_onnx["Short"],  # 2
+    cast_pytorch_to_onnx["Int"],  # 3
+    cast_pytorch_to_onnx["Long"],  # 4
+    cast_pytorch_to_onnx["Half"],  # 5
+    cast_pytorch_to_onnx["Float"],  # 6
+    cast_pytorch_to_onnx["Double"],  # 7
+    cast_pytorch_to_onnx["Undefined"],  # 8
+    cast_pytorch_to_onnx["ComplexFloat"],  # 9
+    cast_pytorch_to_onnx["ComplexDouble"],  # 10
+    cast_pytorch_to_onnx["Bool"],  # 11
+    cast_pytorch_to_onnx["Char"],  # 12
+    cast_pytorch_to_onnx["Byte"],  # 13
+    cast_pytorch_to_onnx["Int"],  # 14
+    cast_pytorch_to_onnx["BFloat16"],  # 15
 ]

 # Global set to store the list of quantized operators in the network.
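The two tables are kept index-aligned (indexed by ScalarType, per the comments above), so the exporter can go from a torch dtype to its ONNX enum in two steps. A minimal sketch of that lookup, using the internal names added above (module layout as of this commit):

import torch
from torch.onnx import symbolic_helper as sym_help

# torch.bfloat16 sits at ScalarType index 15, and the new entry at that index
# maps it to TensorProtoDataType.BFLOAT16.
idx = sym_help.scalar_type_to_pytorch_type.index(torch.bfloat16)
print(idx, sym_help.scalar_type_to_onnx[idx])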
@@ -3196,7 +3196,7 @@ def linear(g, input, weight, bias):
 def hann_window(g, window_length, periodic=True, dtype=None, layout=None, device=None, pin_memory=None, requires_grad=False):
     if dtype is None:
         dtype = torch.get_default_dtype()
-        if sym_help._dtype_is_fp(dtype) is False:
+        if not dtype or not dtype.is_floating_point:
             dtype = torch.float
         dtype = sym_help.scalar_type_to_pytorch_type.index(dtype)
