[CUDA]: Add frexp CUDA bfloat16 support (#133313)
Fixes #133263

Add CUDA bfloat16 support to cuda_frexp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133313
Approved by: https://github.com/ezyang, https://github.com/eqy
This commit is contained in:
parent 59e33cd1f4
commit ec49ce5f8e
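A minimal usage sketch of what this commit enables, assuming a CUDA device and a build that includes the change: torch.frexp now accepts bfloat16 tensors on CUDA, returning a bfloat16 mantissa and an int32 exponent.

import torch

# Minimal sketch, assuming a CUDA device and a build that includes this commit.
x = torch.tensor([0.5, 3.0, -6.25], dtype=torch.bfloat16, device="cuda")

# frexp decomposes x into mantissa * 2**exponent.
mantissa, exponent = torch.frexp(x)

print(mantissa.dtype)  # torch.bfloat16 (matches the input dtype)
print(exponent.dtype)  # torch.int32

# Reconstruct x in float32 to verify the decomposition.
print(mantissa.float() * torch.exp2(exponent.float()))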
@@ -262,7 +262,7 @@ void nan_to_num_kernel_cuda(
 }

 void frexp_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half,
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
     // The iter.dtype() here is the dtype of mantissa output.
     // It's a floating point type and must be the same as the input's dtype.
     iter.dtype(),
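On the Python side, this dispatch change means CUDA frexp no longer raises a dtype error for bfloat16 inputs. A hedged sketch of a fallback one might use on builds that predate this commit (the helper name frexp_bf16 is hypothetical, not part of PyTorch):

import torch

def frexp_bf16(x: torch.Tensor):
    # Hypothetical helper: on builds without this commit, the CUDA kernel only
    # dispatches for float/double/half, so a bfloat16 input raises a
    # RuntimeError; fall back to float32, decompose, and cast the mantissa back.
    try:
        return torch.frexp(x)
    except RuntimeError:
        mantissa, exponent = torch.frexp(x.float())
        return mantissa.to(x.dtype), exponent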
@@ -13661,8 +13661,6 @@ op_db: List[OpInfo] = [
                    op=torch.frexp,
                    ref=np.frexp,
                    dtypes=floating_types_and(torch.half, torch.bfloat16),
-                   dtypesIfCUDA=floating_types_and(torch.half),
-                   # skip testing torch.frexp as it is not supported by ROCm platform yet
                    decorators=[],
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
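With dtypesIfCUDA removed, the OpInfo entry exercises the same floating dtypes, including bfloat16, on both CPU and CUDA. A small parity check along the same lines, assuming a CUDA device is available:

import torch

# Sketch of a CPU/CUDA parity check for bfloat16 frexp, assuming CUDA is available.
x = torch.randn(16, dtype=torch.bfloat16)
m_cpu, e_cpu = torch.frexp(x)
m_gpu, e_gpu = torch.frexp(x.cuda())

torch.testing.assert_close(m_cpu, m_gpu.cpu())
torch.testing.assert_close(e_cpu, e_gpu.cpu())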