Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72405

Fixes #71674. This shouldn't segfault now:

```
import torch
d = torch.complex64
torch.set_default_dtype(d)
```

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D34423660

Pulled By: anjali411

fbshipit-source-id: cac92a6f56846f2c0727a120b5f568aa75baa21e
(cherry picked from commit eaab813a0fddced24303b3bd50e4fcdba1516e46)

"""This module exist to be able to deprecate functions publicly without doing so internally. The deprecated
|
|
public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal
|
|
versions should be imported from torch.testing._internal
|
|
"""
|
|
|
|
from typing import List

import torch

__all_dtype_getters__ = [
    "_validate_dtypes",
    "_dispatch_dtypes",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "all_types_and_half",
    "complex_types",
    "empty_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "double_types",
    "floating_types_and_half",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
]

__all__ = [
    *__all_dtype_getters__,
    "get_all_device_types",
]

# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    for dtype in dtypes:
        assert isinstance(dtype, torch.dtype)
    return dtypes

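# Example (illustrative, not part of the original file): _validate_dtypes
# returns its inputs unchanged when they are all torch.dtype instances, and
# raises an AssertionError otherwise.
#
#   _validate_dtypes(torch.int32, torch.float32)  # -> (torch.int32, torch.float32)
#   _validate_dtypes("float32")                   # AssertionError
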
# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    def __add__(self, other):
        assert isinstance(other, tuple)
        return _dispatch_dtypes(tuple.__add__(self, other))

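# Illustrative note (a sketch, not part of the original file): because
# __add__ returns a _dispatch_dtypes rather than a plain tuple, the *_and()
# helpers below can chain a base tuple with extra dtypes and still hand back
# the same wrapper type, e.g.:
#
#   floating_types() + (torch.half,)  # -> a _dispatch_dtypes, not a bare tuple
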
_empty_types = _dispatch_dtypes(())
def empty_types():
    return _empty_types

_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    return _floating_types

_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    return _floating_types_and_half

def floating_types_and(*dtypes):
    return _floating_types + _validate_dtypes(*dtypes)

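# Example (illustrative only): extending the float32/float64 base with extra
# dtypes; _validate_dtypes asserts every argument is a torch.dtype first.
#
#   floating_types_and(torch.half, torch.bfloat16)
#   # -> (torch.float32, torch.float64, torch.float16, torch.bfloat16)
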
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    return _floating_and_complex_types

def floating_and_complex_types_and(*dtypes):
    return _floating_and_complex_types + _validate_dtypes(*dtypes)

_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
def double_types():
    return _double_types

_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    return _integral_types

def integral_types_and(*dtypes):
    return _integral_types + _validate_dtypes(*dtypes)

_all_types = _floating_types + _integral_types
def all_types():
    return _all_types

def all_types_and(*dtypes):
    return _all_types + _validate_dtypes(*dtypes)

_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
    return _complex_types

_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    return _all_types_and_complex

def all_types_and_complex_and(*dtypes):
    return _all_types_and_complex + _validate_dtypes(*dtypes)

_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    return _all_types_and_half

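# Quick reference (added for clarity; the identities follow from the
# compositions above, value-wise):
#
#   all_types()             == floating_types() + integral_types()
#   all_types_and_complex() == all_types() + complex_types()
#   all_types_and_half()    == all_types() + (torch.half,)
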
# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro

# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False,
                   include_qint=False,
                   ) -> List[torch.dtype]:
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    if include_qint:
        dtypes += get_all_qint_dtypes()
    return dtypes

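# Example (illustrative, not part of the original file): a typical call in a
# test might exclude the low-precision float dtypes:
#
#   get_all_dtypes(include_half=False, include_bfloat16=False)
#   # -> ints + [torch.float32, torch.float64, torch.bool,
#   #            torch.complex64, torch.complex128]
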
def get_all_math_dtypes(device) -> List[torch.dtype]:
    return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                    include_bfloat16=False) + get_all_complex_dtypes()

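# Note (added for clarity): torch.float16 is included only for CUDA device
# strings, since device.startswith('cuda') gates include_half above, e.g.:
#
#   get_all_math_dtypes('cpu')     # no torch.float16
#   get_all_math_dtypes('cuda:0')  # includes torch.float16
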
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]

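# Example (illustrative): torch.complex32 is opt-in.
#
#   get_all_complex_dtypes()      # [torch.complex64, torch.complex128]
#   get_all_complex_dtypes(True)  # prepends torch.complex32
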
def get_all_int_dtypes() -> List[torch.dtype]:
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]

def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes

def get_all_qint_dtypes() -> List[torch.dtype]:
    return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]

def get_all_device_types() -> List[str]:
    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
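
# Example (illustrative): parametrizing a test over the available devices.
#
#   for device in get_all_device_types():
#       torch.ones(2, 2, device=device)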