pytorch/torch/testing/_internal/opinfo/utils.py
Edward Z. Yang 9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever we import a module which is not listed as a file to be typechecked by mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.
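
For illustration only, here is a minimal sketch of the mypy.ini setting being flipped (this is not the repo's actual configuration file):

```
[mypy]
# Previously: imported-but-unlisted modules were typechecked with their errors suppressed.
# follow_imports = silent

# Now: imported modules are typechecked normally; files that should be skipped opt out
# with a top-of-file `# mypy: ignore-errors` directive instead.
follow_imports = normal
```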

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00


# mypy: ignore-errors

import collections
import warnings
from functools import partial, wraps
from typing import Sequence

import numpy as np

import torch
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes,
    all_types,
    all_types_and,
    all_types_and_complex,
    all_types_and_complex_and,
    all_types_and_half,
    complex_types,
    floating_and_complex_types,
    floating_and_complex_types_and,
    floating_types,
    floating_types_and,
    floating_types_and_half,
    integral_types,
    integral_types_and,
)
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict


COMPLETE_DTYPES_DISPATCH = (
    all_types,
    all_types_and_complex,
    all_types_and_half,
    floating_types,
    floating_and_complex_types,
    floating_types_and_half,
    integral_types,
    complex_types,
)

EXTENSIBLE_DTYPE_DISPATCH = (
    all_types_and_complex_and,
    floating_types_and,
    floating_and_complex_types_and,
    integral_types_and,
    all_types_and,
)

# Better way to acquire devices?
DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])


class _dynamic_dispatch_dtypes(_dispatch_dtypes):
    # Class to tag the dynamically generated types.
    pass


def get_supported_dtypes(op, sample_inputs_fn, device_type):
    # Returns the supported dtypes for the given operator and device_type pair.
    assert device_type in ["cpu", "cuda"]
    if not TEST_CUDA and device_type == "cuda":
        warnings.warn(
            "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
        )
        return _dynamic_dispatch_dtypes(())

    supported_dtypes = set()
    for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
        try:
            samples = sample_inputs_fn(op, device_type, dtype, False)
        except RuntimeError:
            # If `sample_inputs_fn` doesn't support sampling for a given
            # `dtype`, we assume that the `dtype` is not supported.
            # We raise a warning, so that the user knows that this was the case
            # and can investigate if there was an issue with the `sample_inputs_fn`.
            warnings.warn(
                f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
            )
            continue

        # We assume the dtype is supported
        # only if all samples pass for the given dtype.
        supported = True
        for sample in samples:
            try:
                op(sample.input, *sample.args, **sample.kwargs)
            except RuntimeError:
                # dtype is not supported
                supported = False
                break

        if supported:
            supported_dtypes.add(dtype)

    return _dynamic_dispatch_dtypes(supported_dtypes)
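
# Illustrative usage of `get_supported_dtypes` (a sketch, not part of the module):
# `my_op` and `my_sample_inputs_fn` are hypothetical stand-ins for an OpInfo's
# callable and its sample-input generator, which is called as
# sample_inputs_fn(op, device, dtype, requires_grad).
#
#     dtypes = get_supported_dtypes(my_op, my_sample_inputs_fn, "cpu")
#     # -> a _dynamic_dispatch_dtypes containing every dtype whose samples all ran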


def dtypes_dispatch_hint(dtypes):
    # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH
    # and EXTENSIBLE_DTYPE_DISPATCH) and its string representation for the passed `dtypes`.
    return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")

    # CUDA is not available, dtypes will be empty.
    if len(dtypes) == 0:
        return return_type((), str(tuple()))

    set_dtypes = set(dtypes)
    for dispatch in COMPLETE_DTYPES_DISPATCH:
        # Short circuit if we get an exact match.
        if set(dispatch()) == set_dtypes:
            return return_type(dispatch, dispatch.__name__ + "()")

    chosen_dispatch = None
    chosen_dispatch_score = 0.0
    for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
        dispatch_dtypes = set(dispatch())
        if not dispatch_dtypes.issubset(set_dtypes):
            continue

        score = len(dispatch_dtypes)
        if score > chosen_dispatch_score:
            chosen_dispatch_score = score
            chosen_dispatch = dispatch

    # If user passed dtypes which are lower than the lowest
    # dispatch type available (not likely but possible in code path).
    if chosen_dispatch is None:
        return return_type((), str(dtypes))

    # Build the hint from the best-scoring dispatch found above, passing the
    # leftover dtypes it does not already cover as extra arguments.
    return return_type(
        partial(chosen_dispatch, *tuple(set(dtypes) - set(chosen_dispatch()))),
        chosen_dispatch.__name__ + str(tuple(set(dtypes) - set(chosen_dispatch()))),
    )
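
# Illustrative usage of `dtypes_dispatch_hint` (a sketch): an exact match against
# one of the COMPLETE_DTYPES_DISPATCH entries returns that dispatch function and
# its name, while a superset of an EXTENSIBLE_DTYPE_DISPATCH entry is expressed as
# that dispatch plus the leftover dtypes.
#
#     hint = dtypes_dispatch_hint(floating_types())
#     hint.dispatch_fn_str  # "floating_types()"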


def is_dynamic_dtype_set(op):
    # Detect if the OpInfo entry acquired dtypes dynamically
    # using `get_supported_dtypes`.
    return op.dynamic_dtypes


def str_format_dynamic_dtype(op):
    fmt_str = f"""
        OpInfo({op.name},
               dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str},
               dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str},
        )
        """

    return fmt_str


def np_unary_ufunc_integer_promotion_wrapper(fn):
    # Wrapper that passes PyTorch's default scalar
    # type as an argument to the wrapped NumPy
    # unary ufunc when given an integer input.
    # This mimics PyTorch's integer->floating point
    # type promotion.
    #
    # This is necessary when NumPy promotes
    # integer types to double, since PyTorch promotes
    # integer types to the default scalar type.

    # Helper to determine if promotion is needed
    def is_integral(dtype):
        return dtype in [
            np.bool_,
            bool,
            np.uint8,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
        ]

    @wraps(fn)
    def wrapped_fn(x):
        # As the default dtype can change, acquire it when the function is called.
        # NOTE: Promotion in PyTorch is from integer types to the default dtype.
        np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]

        if is_integral(x.dtype):
            return fn(x.astype(np_dtype))
        return fn(x)

    return wrapped_fn
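
# Illustrative usage of `np_unary_ufunc_integer_promotion_wrapper` (a sketch):
# wrapping np.sin makes integer inputs get computed in PyTorch's default scalar
# type (float32 by default) rather than NumPy's float64.
#
#     np_sin = np_unary_ufunc_integer_promotion_wrapper(np.sin)
#     np_sin(np.array([1, 2, 3]))  # evaluated in the default dtype, e.g. float32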


def reference_reduction_numpy(f, supports_keepdims=True):
    """Wraps a NumPy reduction operator.

    The wrapper function will forward dim, keepdim, mask, and identity
    kwargs to the wrapped function as the NumPy equivalent axis,
    keepdims, where, and initial kwargs, respectively.

    Args:
        f: NumPy reduction operator to wrap
        supports_keepdims (bool, optional): Whether the NumPy operator accepts
            keepdims parameter. If it does not, the wrapper will manually unsqueeze
            the reduced dimensions if it was called with keepdim=True. Defaults to True.

    Returns:
        Wrapped function
    """

    @wraps(f)
    def wrapper(x: np.ndarray, *args, **kwargs):
        # Copy keys into a set
        keys = set(kwargs.keys())

        dim = kwargs.pop("dim", None)
        keepdim = kwargs.pop("keepdim", False)

        if "dim" in keys:
            dim = tuple(dim) if isinstance(dim, Sequence) else dim

            # NumPy reductions don't accept dim=0 for scalar inputs
            # so we convert it to None if and only if dim is equivalent
            if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
                kwargs["axis"] = None
            else:
                kwargs["axis"] = dim

        if "keepdim" in keys and supports_keepdims:
            kwargs["keepdims"] = keepdim

        if "mask" in keys:
            mask = kwargs.pop("mask")
            if mask is not None:
                assert mask.layout == torch.strided
                kwargs["where"] = mask.cpu().numpy()

        if "identity" in keys:
            identity = kwargs.pop("identity")
            if identity is not None:
                if identity.dtype is torch.bfloat16:
                    identity = identity.cpu().to(torch.float32)
                else:
                    identity = identity.cpu()

                kwargs["initial"] = identity.numpy()

        result = f(x, *args, **kwargs)

        # Unsqueeze reduced dimensions if NumPy does not support keepdims
        if keepdim and not supports_keepdims and x.ndim > 0:
            dim = list(range(x.ndim)) if dim is None else dim
            result = np.expand_dims(result, dim)

        return result

    return wrapper
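
# Illustrative usage of `reference_reduction_numpy` (a sketch): building a NumPy
# reference for a sum reduction that understands torch-style dim/keepdim kwargs.
#
#     np_sum = reference_reduction_numpy(np.sum)
#     np_sum(np.ones((2, 3)), dim=1, keepdim=True)  # array of shape (2, 1)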


def prod_numpy(a, *args, **kwargs):
    """
    Calls np.prod after casting signed-integer inputs to np.int64 and
    unsigned-integer inputs to np.uint64 (when no explicit dtype is given).
    This is necessary because np.prod defaults to int32 on Windows while it
    uses int64 on Linux, which can cause integer overflow
    (https://github.com/pytorch/pytorch/issues/77320).

    Returns:
        np.prod of the input
    """
    if "dtype" not in kwargs:
        if np.issubdtype(a.dtype, np.signedinteger):
            a = a.astype(np.int64)
        elif np.issubdtype(a.dtype, np.unsignedinteger):
            a = a.astype(np.uint64)

    fn = reference_reduction_numpy(np.prod)
    return fn(a, *args, **kwargs)
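
# Illustrative usage of `prod_numpy` (a sketch): on platforms where NumPy's default
# integer is 32-bit (e.g. Windows), the upcast above avoids overflow.
#
#     prod_numpy(np.array([2**20, 2**20]))  # 1099511627776 (2**40), computed in int64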