This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
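A minimal sketch of the reverse edit described above (the helper name and the way the directive is stripped are illustrative, not a script that ships in the repo):

```
def remove_ignore_errors(filepath):
    # Strip a leading "# mypy: ignore-errors" directive (and the blank line
    # the codemod added after it) so mypy starts reporting errors again.
    with open(filepath, 'r+') as f:
        lines = f.readlines()
        if lines and lines[0].strip() == '# mypy: ignore-errors':
            lines = lines[1:]
            if lines and lines[0].strip() == '':
                lines = lines[1:]
        f.seek(0)
        f.writelines(lines)
        f.truncate()
```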
# mypy: ignore-errors

from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    OpInfo,
    ReductionOpInfo,
    UnaryUfuncInfo,
)

# NOTE [Python References]
# Python References emulate existing PyTorch operations, but can ultimately
# be expressed in terms of "primitive" operations from torch._prims.
#
# These references are experimental.
# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577
# for additional context.
#
# Python Reference OpInfos should be added to the python_ref_db list below.
# Tests can opt into running on these references by including
# that list in the Sequence they pass to the @ops decorator.
#
# When a Python Reference OpInfo is constructed, a pointer to an
# existing OpInfo must be provided using the torch_opinfo_name kwarg.
# The existing OpInfo with that name and no variant will be found
# to inherit from.
#
# Instead of just inheriting the existing OpInfo's metadata, the
# Python Reference OpInfos inherit the existing OpInfo's
# construction arguments. These arguments can be overridden
# by adding kwargs to the constructor.
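#
# As an illustration (a sketch, not an entry defined in this file), a Python
# Reference OpInfo for torch.abs might be declared roughly as:
#
#     ElementwiseUnaryPythonRefInfo(
#         "_refs.abs",
#         torch_opinfo_name="abs",
#         # any kwarg given here overrides the value inherited from the "abs" OpInfo
#         supports_out=True,
#     ),
#
# where "_refs.abs" is the name of the callable reference and
# torch_opinfo_name="abs" points at the existing OpInfo whose construction
# arguments are inherited.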
def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None):
    """
    Finds the OpInfo with the given name and variant name.
    """
    # NOTE: searching the global op_db doesn't work when OpInfos are split into
    # different modules, as otherwise the op_db will not be fully constructed
    # yet. So, instead the local op_db must be passed in explicitly.
    if op_db is None:
        from torch.testing._internal.common_methods_invocations import op_db

    for opinfo in op_db:
        if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name:
            return opinfo

def _inherit_constructor_args(name, op, inherited, overrides):
    # inherits metadata
    common_kwargs = {
        "name": name,
        "op": op,
        "aliases": None,  # TODO add a check for alias coverage
        "method_variant": None,
        "inplace_variant": None,  # TODO: add a check for inplace coverage
        "supports_scripting": False,
    }

    # Acquires inherited kwargs
    kwargs = inherited.copy()

    # Fixes metadata
    if "kwargs" in kwargs:
        kwargs.update(kwargs["kwargs"])
        del kwargs["kwargs"]
    if "self" in kwargs:
        del kwargs["self"]
    if "__class__" in kwargs:
        del kwargs["__class__"]
    if "skips" in kwargs:
        del kwargs["skips"]
    if "decorators" in kwargs:
        del kwargs["decorators"]

    # Overrides metadata
    kwargs.update(common_kwargs)
    kwargs.update(overrides)

    # At the moment no prims support autograd, so we must not run autograd
    # tests e.g. when testing dtype support. Once we start writing autograd
    # formulas for prims this can be removed.
    kwargs["supports_autograd"] = False
    kwargs["supports_gradgrad"] = False
    kwargs["supports_fwgrad_bwgrad"] = False
    kwargs["supports_inplace_autograd"] = False
    kwargs["supports_forward_ad"] = False

    return kwargs

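# A worked sketch of the helper above (the dict contents are illustrative, not
# taken from a real OpInfo):
#
#     inherited = {"self": parent_opinfo, "name": "abs", "skips": (...,), "dtypes": ...}
#     _inherit_constructor_args("_refs.abs", None, inherited, {"supports_out": True})
#
# copies `inherited`, drops the bookkeeping keys ("self", "__class__", "skips",
# "decorators", and a nested "kwargs" dict after flattening it), overwrites the
# identity fields with common_kwargs and the caller's overrides, and forces all
# autograd-support flags to False. The classes below pass the result straight
# to their parent OpInfo constructor.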
class PythonRefInfo(OpInfo):
    """
    An OpInfo for a Python reference of an OpInfo base class operation.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # the database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for the corresponding torch opinfo
        validate_view_consistency=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        assert isinstance(self.torch_opinfo, OpInfo)

        inherited = self.torch_opinfo._original_opinfo_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
        super().__init__(**ukwargs)

class ReductionPythonRefInfo(ReductionOpInfo):
    """
    An OpInfo for a Python reference of a reduction operation.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # the database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for the corresponding torch opinfo
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        assert isinstance(self.torch_opinfo, ReductionOpInfo)

        inherited = self.torch_opinfo._original_reduction_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        # See https://github.com/pytorch/pytorch/issues/77216
        self.validate_view_consistency = False

        super().__init__(**ukwargs)

class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
    """
    An OpInfo for a Python reference of an elementwise unary operation.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # the database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for the corresponding torch opinfo
        validate_view_consistency=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        assert isinstance(self.torch_opinfo, UnaryUfuncInfo)

        inherited = self.torch_opinfo._original_unary_ufunc_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)

class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
    """
    An OpInfo for a Python reference of an elementwise binary operation.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # the database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for the corresponding torch opinfo
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        assert isinstance(self.torch_opinfo, BinaryUfuncInfo)

        inherited = self.torch_opinfo._original_binary_ufunc_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)
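
# How tests opt in (a sketch of the pattern described in the NOTE above; the
# test name is illustrative, and python_ref_db itself is defined elsewhere):
#
#     from torch.testing._internal.common_methods_invocations import python_ref_db
#     from torch.testing._internal.common_device_type import ops
#
#     @ops(python_ref_db)
#     def test_python_ref(self, device, dtype, op):
#         ...  # `op` is one of the *PythonRefInfo instances constructed above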