[CUDA]: Add frexp CUDA bfloat16 support (#133313)

Fixes #133263. Adds CUDA bfloat16 support to frexp_kernel_cuda.
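
A minimal sketch of the user-visible behavior, assuming a CUDA-capable build and device:

import torch

# bfloat16 CUDA input; before this change the frexp CUDA dispatch did not cover it.
x = torch.tensor([0.5, 3.0, -6.25], dtype=torch.bfloat16, device="cuda")
mantissa, exponent = torch.frexp(x)
print(mantissa.dtype, exponent.dtype)  # torch.bfloat16 torch.int32
# Sanity check: mantissa * 2**exponent should reconstruct the input.
torch.testing.assert_close(torch.ldexp(mantissa, exponent), x)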

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133313
Approved by: https://github.com/ezyang, https://github.com/eqy
Aaron Gokaslan 2024-08-15 15:20:00 +00:00 committed by PyTorch MergeBot
parent 59e33cd1f4
commit ec49ce5f8e
2 changed files with 1 addition and 3 deletions


@@ -262,7 +262,7 @@ void nan_to_num_kernel_cuda(
 }
 void frexp_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half,
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
     // The iter.dtype() here is the dtype of mantissa output.
     // It's a floating point type and must be the same as the input's dtype.
     iter.dtype(),


@@ -13661,8 +13661,6 @@ op_db: List[OpInfo] = [
            op=torch.frexp,
            ref=np.frexp,
            dtypes=floating_types_and(torch.half, torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.half),
-           # skip testing torch.frexp as it is not supported by ROCm platform yet
            decorators=[],
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
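
With dtypesIfCUDA removed, the OpInfo falls back to its dtypes list for CUDA as well, so torch.frexp is now exercised with float16 and bfloat16 on CUDA in the OpInfo tests. A rough sketch of the resulting dtype list (floating_types_and is an internal test helper from torch.testing._internal.common_dtype):

import torch
from torch.testing._internal.common_dtype import floating_types_and

# The dtype tuple the frexp OpInfo now uses on both CPU and CUDA.
print(floating_types_and(torch.half, torch.bfloat16))
# Expected to contain torch.float32, torch.float64, torch.float16, torch.bfloat16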