[ONNX] Add BFloat16 type support when export to ONNX (#66788)

Summary:
- PyTorch and ONNX both already support BFloat16; add it to the ONNX exporter to unblock mixed-precision training models (a minimal export sketch follows below).
- Lets the PyTorch TNLG model use BFloat16 tensors for the inputs/outputs of the layers that run on the NPU.
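
A minimal sketch of what this enables, mirroring the new test_cast_type_as_with_bfloat16 test added below; the module name and shapes are illustrative only, not part of this PR's diff:

    import io
    import torch

    class CastBFloat16Module(torch.nn.Module):  # illustrative module, not from this PR
        def forward(self, x):
            y = torch.ones((3, 4), dtype=torch.bfloat16)
            x = x.type_as(y)                  # compute in bfloat16
            return x.to(dtype=torch.float16)  # cast back for the output

    x = torch.ones(3, 4, dtype=torch.float16)
    f = io.BytesIO()
    # opset_version must be >= 13: ONNX only defines bfloat16 for Cast from opset 13 on
    torch.onnx.export(CastBFloat16Module(), x, f, opset_version=13)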

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66788

Reviewed By: jansel

Differential Revision: D32283510

Pulled By: malfet

fbshipit-source-id: 150d69b1465b2b917dd6554505eca58042c1262a
Author: hwangdeyu (2021-12-14 12:13:31 -08:00), committed by Facebook GitHub Bot
Parent: 800a457b6f
Commit: c76c6e9bd3
9 changed files with 70 additions and 25 deletions


@@ -39,6 +39,9 @@ skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(),
 skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"),
                         "Skip In Travis")
 
+skipIfNoBFloat16Cuda = _skipper(lambda: not torch.cuda.is_bf16_supported(),
+                                "BFloat16 CUDA is not available")
+
 # skips tests for all versions below min_opset_version.
 # if exporting the op is only supported after a specific version,
 # add this wrapper to prevent running the test for opset_versions


@@ -6450,6 +6450,19 @@ class TestONNXRuntime(unittest.TestCase):
         model = MyModule()
         self.run_test(model, (x, y))
 
+    # ONNX supports bfloat16 for opsets >= 13
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_cast_type_as_with_bfloat16(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ones((3, 4), dtype=torch.bfloat16)
+                x = x.type_as(y)
+                return x.to(dtype=torch.float16)
+
+        x = torch.ones(3, 4, dtype=torch.float16)
+        model = MyModule()
+        self.run_test(model, x)
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_type_as(self):
         class MyModule(torch.nn.Module):


@@ -7,7 +7,7 @@ import torch
 from torch.cuda.amp import autocast
 from test_pytorch_common import disableScriptTest, skipIfUnsupportedMinOpsetVersion
-from test_pytorch_common import skipIfNoCuda
+from test_pytorch_common import skipIfNoCuda, skipIfNoBFloat16Cuda
 from test_pytorch_onnx_onnxruntime import TestONNXRuntime
@@ -85,6 +85,20 @@ class TestONNXRuntime_cuda(unittest.TestCase):
         model = amp.initialize(LinearModel(), opt_level="O2")
         self.run_test(model, input)
 
+    # ONNX supports bfloat16 for opsets >= 13
+    # Add, Sub and Mul ops don't support bfloat16 cpu in onnxruntime.
+    @skipIfUnsupportedMinOpsetVersion(13)
+    @skipIfNoBFloat16Cuda
+    def test_arithmetic_bfp16(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                y = torch.ones(3, 4, dtype=torch.bfloat16, device=torch.device("cuda"))
+                x = x.type_as(y)
+                return torch.mul(torch.add(x, y), torch.sub(x, y)).to(dtype=torch.float16)
+
+        x = torch.ones(3, 4, requires_grad=True, dtype=torch.float16, device=torch.device("cuda"))
+        self.run_test(MyModule(), x, rtol=1e-3, atol=1e-5)
+
 TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
 TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test


@@ -22,6 +22,7 @@ class TensorProtoDataType(Enum):
     UINT64 = ...
     COMPLEX64 = ...
     COMPLEX128 = ...
+    BFLOAT16 = ...
 
 class OperatorExportTypes(Enum):
     ONNX = ...


@@ -31,6 +31,10 @@ static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
     {c10::kBool, 9},
     {c10::kHalf, 10},
     {c10::kDouble, 11},
+    {c10::kQInt8, 12},
+    {c10::kQUInt8, 13},
+    {c10::kQInt32, 14},
+    {c10::kBFloat16, 15},
 };
 
 static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {


@@ -392,6 +392,8 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
       return onnx::TensorProto_DataType_UINT8;
     case at::kQInt32:
       return onnx::TensorProto_DataType_INT32;
+    case at::kBFloat16:
+      return onnx::TensorProto_DataType_BFLOAT16;
     default:
       AT_ERROR("unexpected tensor scalar type");
   }


@@ -24,7 +24,8 @@ void initONNXBindings(PyObject* module) {
       .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
       .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
       .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
-      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128);
+      .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
+      .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);
 
   py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
       .value("ONNX", OperatorExportTypes::ONNX)


@@ -317,18 +317,12 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False)
 def _is_fp(value):
     if value:
         if isinstance(value, torch.Tensor):
-            type = value.dtype
-            return (type == "torch.float32") or (type == "torch.float64") or (type == "torch.float16")
+            return value.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16)
         else:
             type = value.type().scalarType()
             if type is None:
                 warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")
-            return (type == "Float") or (type == "Double") or (type == "Half")
-    return False
-
-def _dtype_is_fp(type_value):
-    if type_value:
-        return (type_value == torch.float16) or (type_value == torch.float32) or (type_value == torch.float64)
+            return type in ("Float", "Double", "Half", "BFloat16")
     return False
 
 def _generate_wrapped_number(g, scalar):
@@ -893,6 +887,7 @@ cast_pytorch_to_onnx = {
     "Bool": torch.onnx.TensorProtoDataType.BOOL,
     "ComplexFloat": torch.onnx.TensorProtoDataType.COMPLEX64,
     "ComplexDouble": torch.onnx.TensorProtoDataType.COMPLEX128,
+    "BFloat16": torch.onnx.TensorProtoDataType.BFLOAT16,
     "Undefined": torch.onnx.TensorProtoDataType.UNDEFINED,
 }
@@ -907,7 +902,11 @@ scalar_name_to_pytorch = {
     "int16_t": "Short",
     "bool": "Bool",
     "complex64": "ComplexFloat",
-    "complex128": "ComplexDouble"
+    "complex128": "ComplexDouble",
+    "qint8": "QInt8",
+    "quint8": "QUInt8",
+    "qint32": "QInt32",
+    "bfloat16": "BFloat16",
 }
@@ -934,7 +933,7 @@ class ScalarType(enum.IntEnum):
 # This indicates each scalar type's corresponding
 # torch type. Related source:
-# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
+# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
 scalar_type_to_pytorch_type = [
     torch.uint8,        # 0
     torch.int8,         # 1
@@ -948,6 +947,10 @@ scalar_type_to_pytorch_type = [
     torch.complex64,    # 9
     torch.complex128,   # 10
     torch.bool,         # 11
+    torch.qint8,        # 12
+    torch.quint8,       # 13
+    torch.qint32,       # 14
+    torch.bfloat16,     # 15
 ]
 
 def _cast_func_template(to_i, g, input, non_blocking):
@@ -955,18 +958,22 @@ def _cast_func_template(to_i, g, input, non_blocking):
 scalar_type_to_onnx = [
-    cast_pytorch_to_onnx["Byte"],
-    cast_pytorch_to_onnx["Char"],
-    cast_pytorch_to_onnx["Short"],
-    cast_pytorch_to_onnx["Int"],
-    cast_pytorch_to_onnx["Long"],
-    cast_pytorch_to_onnx["Half"],
-    cast_pytorch_to_onnx["Float"],
-    cast_pytorch_to_onnx["Double"],
-    cast_pytorch_to_onnx["Undefined"],
-    cast_pytorch_to_onnx["ComplexFloat"],
-    cast_pytorch_to_onnx["ComplexDouble"],
-    cast_pytorch_to_onnx["Bool"],
+    cast_pytorch_to_onnx["Byte"],          # 0
+    cast_pytorch_to_onnx["Char"],          # 1
+    cast_pytorch_to_onnx["Short"],         # 2
+    cast_pytorch_to_onnx["Int"],           # 3
+    cast_pytorch_to_onnx["Long"],          # 4
+    cast_pytorch_to_onnx["Half"],          # 5
+    cast_pytorch_to_onnx["Float"],         # 6
+    cast_pytorch_to_onnx["Double"],        # 7
+    cast_pytorch_to_onnx["Undefined"],     # 8
+    cast_pytorch_to_onnx["ComplexFloat"],  # 9
+    cast_pytorch_to_onnx["ComplexDouble"], # 10
+    cast_pytorch_to_onnx["Bool"],          # 11
+    cast_pytorch_to_onnx["Char"],          # 12
+    cast_pytorch_to_onnx["Byte"],          # 13
+    cast_pytorch_to_onnx["Int"],           # 14
+    cast_pytorch_to_onnx["BFloat16"],      # 15
 ]
 
 # Global set to store the list of quantized operators in the network.


@@ -3196,7 +3196,7 @@ def linear(g, input, weight, bias):
 def hann_window(g, window_length, periodic=True, dtype=None, layout=None, device=None, pin_memory=None, requires_grad=False):
     if dtype is None:
         dtype = torch.get_default_dtype()
-        if sym_help._dtype_is_fp(dtype) is False:
+        if not dtype or not dtype.is_floating_point:
             dtype = torch.float
         dtype = sym_help.scalar_type_to_pytorch_type.index(dtype)