Enable formatting in all of testing/_internal/opinfo (#83559)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83559 Approved by: https://github.com/albanD
parent b4bc0d249f
commit ae68e455be
@@ -719,7 +719,7 @@ include_patterns = [
     'torch/_refs/**/*.py',
     'torch/_subclasses/**/*.py',
     'torch/_*.py',
-    'torch/testing/_internal/opinfo/definitions/*.py',
+    'torch/testing/_internal/opinfo/**/*.py',
     'torchgen/**/*.py',
     'functorch/functorch/_src/aot_autograd.py',
     'functorch/functorch/_src/compilers.py',
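The only functional change in this hunk is the widened include pattern: every Python file under torch/testing/_internal/opinfo/ is now covered by the formatting linters, not just the definitions/ subpackage. A minimal sketch of the difference using pathlib glob semantics (hypothetical file names; lintrunner's own matcher may differ in corner cases):

    import tempfile
    from pathlib import Path

    # Hypothetical file list; the real opinfo/ package has more files.
    files = [
        "torch/testing/_internal/opinfo/core.py",
        "torch/testing/_internal/opinfo/utils.py",
        "torch/testing/_internal/opinfo/definitions/linalg.py",
    ]

    with tempfile.TemporaryDirectory() as root:
        for f in files:
            path = Path(root, f)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.touch()

        old = set(Path(root).glob("torch/testing/_internal/opinfo/definitions/*.py"))
        new = set(Path(root).glob("torch/testing/_internal/opinfo/**/*.py"))
        print(sorted(p.name for p in new - old))  # ['core.py', 'utils.py']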
File diff suppressed because it is too large
@@ -3,22 +3,22 @@ import warnings
 from functools import partial
 
 import torch
-from torch.testing._internal.common_cuda import (TEST_CUDA)
+from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_dtype import (
-    all_types_and_complex_and,
-    all_types_and_complex,
-    all_types_and_half,
-    all_types,
+    _dispatch_dtypes,
+    all_types,
+    all_types_and,
+    all_types_and_complex,
+    all_types_and_complex_and,
+    all_types_and_half,
     complex_types,
     floating_and_complex_types,
-    floating_types_and_half,
-    floating_types,
-    integral_types,
-    floating_types_and,
-    floating_and_complex_types_and,
+    floating_and_complex_types_and,
+    floating_types,
+    floating_types_and,
+    floating_types_and_half,
+    integral_types,
     integral_types_and,
-    all_types_and,
-    _dispatch_dtypes,
 )
 
 COMPLETE_DTYPES_DISPATCH = (
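The import list keeps the same fourteen names; they are only reordered into plain case-sensitive string order, which is why _dispatch_dtypes now leads the block (ASCII '_' sorts before any lowercase letter). A quick self-check of the new ordering:

    # Names exactly as they appear in the reordered import block above.
    names = [
        "_dispatch_dtypes",
        "all_types",
        "all_types_and",
        "all_types_and_complex",
        "all_types_and_complex_and",
        "all_types_and_half",
        "complex_types",
        "floating_and_complex_types",
        "floating_and_complex_types_and",
        "floating_types",
        "floating_types_and",
        "floating_types_and_half",
        "integral_types",
        "integral_types_and",
    ]
    # '_' (0x5F) sorts before 'a' (0x61), so _dispatch_dtypes comes first.
    assert names == sorted(names)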
@@ -41,7 +41,8 @@ EXTENSIBLE_DTYPE_DISPATCH = (
 )
 
 # Better way to acquire devices?
-DEVICES = ['cpu'] + (['cuda'] if TEST_CUDA else [])
+DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])
+
 
 class _dynamic_dispatch_dtypes(_dispatch_dtypes):
     # Class to tag the dynamically generated types.
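For context, TEST_CUDA in common_cuda.py is defined from torch.cuda.is_available(), so the DEVICES line above behaves like this standalone sketch:

    import torch

    # Roughly what the DEVICES line computes, without the test-internal flag.
    DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    print(DEVICES)  # ['cpu'] on CPU-only machines, ['cpu', 'cuda'] with a GPU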
@@ -50,9 +51,11 @@ class _dynamic_dispatch_dtypes(_dispatch_dtypes):
 
 def get_supported_dtypes(op, sample_inputs_fn, device_type):
     # Returns the supported dtypes for the given operator and device_type pair.
-    assert device_type in ['cpu', 'cuda']
-    if not TEST_CUDA and device_type == 'cuda':
-        warnings.warn("WARNING: CUDA is not available, empty_dtypes dispatch will be returned!")
+    assert device_type in ["cpu", "cuda"]
+    if not TEST_CUDA and device_type == "cuda":
+        warnings.warn(
+            "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
+        )
         return _dynamic_dispatch_dtypes(())
 
     supported_dtypes = set()
@@ -64,7 +67,9 @@ def get_supported_dtypes(op, sample_inputs_fn, device_type):
             # `dtype`, we assume that the `dtype` is not supported.
             # We raise a warning, so that user knows that this was the case
             # and can investigate if there was an issue with the `sample_inputs_fn`.
-            warnings.warn(f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}")
+            warnings.warn(
+                f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
+            )
             continue
 
         # We assume the dtype is supported
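These hunks only reflow warnings inside get_supported_dtypes, which discovers dtype support empirically: it runs the operator on a sample for each candidate dtype and treats any failure as unsupported. A simplified, self-contained sketch of that probing idea (probe_supported_dtypes is a hypothetical stand-in for the real OpInfo-driven loop):

    import warnings

    import torch


    def probe_supported_dtypes(fn, device, candidate_dtypes):
        # Simplified stand-in for get_supported_dtypes: try each dtype once
        # and treat any exception as "this dtype is unsupported".
        supported = set()
        for dtype in candidate_dtypes:
            try:
                x = torch.ones(2, 2, device=device, dtype=dtype)
                fn(x)
            except Exception:
                warnings.warn(f"dtype {dtype} appears unsupported on {device}")
                continue
            supported.add(dtype)
        return supported


    # torch.mean rejects integer inputs, so only float32 survives the probe:
    print(probe_supported_dtypes(torch.mean, "cpu", [torch.float32, torch.int64]))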
@@ -87,7 +92,7 @@ def get_supported_dtypes(op, sample_inputs_fn, device_type):
 def dtypes_dispatch_hint(dtypes):
     # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH)
     # and its string representation for the passed `dtypes`.
-    return_type = collections.namedtuple('return_type', 'dispatch_fn dispatch_fn_str')
+    return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")
 
     # CUDA is not available, dtypes will be empty.
     if len(dtypes) == 0:
@@ -100,7 +105,7 @@ def dtypes_dispatch_hint(dtypes):
             return return_type(dispatch, dispatch.__name__ + "()")
 
     chosen_dispatch = None
-    chosen_dispatch_score = 0.
+    chosen_dispatch_score = 0.0
     for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
         dispatch_dtypes = set(dispatch())
         if not dispatch_dtypes.issubset(set_dtypes):
@@ -116,8 +121,10 @@ def dtypes_dispatch_hint(dtypes):
     if chosen_dispatch is None:
         return return_type((), str(dtypes))
 
-    return return_type(partial(dispatch, *tuple(set(dtypes) - set(dispatch()))),
-                       dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))))
+    return return_type(
+        partial(dispatch, *tuple(set(dtypes) - set(dispatch()))),
+        dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))),
+    )
 
 
 def is_dynamic_dtype_set(op):
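The reflowed return pairs a partial application of the chosen dispatch helper (extended with the leftover dtypes) with a human-readable string built from its __name__. The same pattern in isolation, with floating_types_and as a stand-in for the real helper:

    from functools import partial


    def floating_types_and(*extra):
        # Stand-in for the real dispatch helper: base dtypes plus any extras.
        return ("float32", "float64") + extra


    extra = ("bfloat16",)
    dispatch_fn = partial(floating_types_and, *extra)
    dispatch_fn_str = floating_types_and.__name__ + str(extra)
    print(dispatch_fn())    # ('float32', 'float64', 'bfloat16')
    print(dispatch_fn_str)  # floating_types_and('bfloat16',)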
@@ -132,8 +139,10 @@ def str_format_dynamic_dtype(op):
               dtypes={dtypes},
               dtypesIfCUDA={dtypesIfCUDA},
        )
-        """.format(name=op.name,
-                   dtypes=dtypes_dispatch_hint(op.dtypes).dispatch_fn_str,
-                   dtypesIfCUDA=dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str)
+        """.format(
+        name=op.name,
+        dtypes=dtypes_dispatch_hint(op.dtypes).dispatch_fn_str,
+        dtypesIfCUDA=dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str,
+    )
 
     return fmt_str
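This last hunk reflows the .format() call that fills the triple-quoted OpInfo template in str_format_dynamic_dtype; the template's opening line sits above the hunk. A standalone sketch of the same fill-in pattern (the OpInfo opening line and the values are assumptions):

    # The opening line of the template is an assumption; only its tail is
    # visible in the hunk. Values are illustrative.
    fmt_str = """
    OpInfo('{name}',
           dtypes={dtypes},
           dtypesIfCUDA={dtypesIfCUDA},
    )
    """.format(
        name="add",
        dtypes="all_types_and_complex_and(torch.bool)",
        dtypesIfCUDA="all_types_and_complex_and(torch.bool, torch.half)",
    )
    print(fmt_str)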