"""This module exist to be able to deprecate functions publicly without doing so internally. The deprecated
|
|
public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal
|
|
versions should be imported from torch.testing._internal
|
|
"""
|
|
|
|
from typing import List
|
|
|
|
import torch

__all_dtype_getters__ = [
    "_validate_dtypes",
    "_dispatch_dtypes",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "all_types_and_half",
    "complex_types",
    "empty_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "double_types",
    "floating_types_and_half",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
]

__all__ = [
    *__all_dtype_getters__,
    "get_all_device_types",
]


# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros


# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    for dtype in dtypes:
        assert isinstance(dtype, torch.dtype)
    return dtypes


# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    def __add__(self, other):
        assert isinstance(other, tuple)
        return _dispatch_dtypes(tuple.__add__(self, other))
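
# Illustrative note (not in the original source): because __add__ re-wraps the
# result, adding a plain tuple to a _dispatch_dtypes preserves the type. This
# is what lets the *_and helpers below extend a base dtype tuple, e.g.
#
#   _floating_types + (torch.half,)
#   # -> _dispatch_dtypes((torch.float32, torch.float64, torch.float16))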

_empty_types = _dispatch_dtypes(())
def empty_types():
    return _empty_types


_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    return _floating_types


_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    return _floating_types_and_half


def floating_types_and(*dtypes):
    return _floating_types + _validate_dtypes(*dtypes)
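
# For example (illustrative):
#
#   floating_types_and(torch.half, torch.bfloat16)
#   # -> (torch.float32, torch.float64, torch.float16, torch.bfloat16)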

_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    return _floating_and_complex_types


def floating_and_complex_types_and(*dtypes):
    return _floating_and_complex_types + _validate_dtypes(*dtypes)


_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
def double_types():
    return _double_types


_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    return _integral_types


def integral_types_and(*dtypes):
    return _integral_types + _validate_dtypes(*dtypes)


_all_types = _floating_types + _integral_types
def all_types():
    return _all_types


def all_types_and(*dtypes):
    return _all_types + _validate_dtypes(*dtypes)


_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
    return _complex_types


_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    return _all_types_and_complex


def all_types_and_complex_and(*dtypes):
    return _all_types_and_complex + _validate_dtypes(*dtypes)
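
# For example (illustrative of how test OpInfo-style definitions tend to
# request the broadest dtype coverage):
#
#   all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)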

_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    return _all_types_and_half


# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro

def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False
                   ) -> List[torch.dtype]:
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    return dtypes
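
# For example (illustrative), with the defaults above:
#
#   get_all_dtypes()
#   # -> [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
#   #     torch.float32, torch.float64, torch.float16, torch.bfloat16,
#   #     torch.bool, torch.complex64, torch.complex128]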


def get_all_math_dtypes(device) -> List[torch.dtype]:
    return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                    include_bfloat16=False) + get_all_complex_dtypes()
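
# Note (added for clarity): `device` here is a device string such as 'cpu',
# 'cuda', or 'cuda:0'; half-precision dtypes are included only when it starts
# with 'cuda'.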


def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]


def get_all_int_dtypes() -> List[torch.dtype]:
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]


def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes


def get_all_device_types() -> List[str]:
    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
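
# A typical (illustrative) use in a test:
#
#   for device in get_all_device_types():
#       t = torch.ones(3, device=device)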