Summary:

## Motivation
Fixes https://github.com/pytorch/pytorch/issues/43770.

## Description of the change
This PR fixes exception chaining only in files under `torch/` where appropriate. To fix exception chaining, I used either:
1. `raise new_exception from old_exception`, where `new_exception` alone does not seem descriptive enough for debugging, or where `old_exception` carries valuable information.
2. `raise new_exception from None`, where raising both `new_exception` and `old_exception` would be noisy and redundant.

I subjectively chose which one to use from the above options.

## List of lines containing raise in except clause
I wrote [this simple script](https://gist.github.com/akihironitta/4223c1b32404b36c1b349d70c4c93b4d) using [ast](https://docs.python.org/3.8/library/ast.html#module-ast) to list the lines that `raise` inside an `except` clause.

- [x] 000739c31a/torch/jit/annotations.py (L35)
- [x] 000739c31a/torch/jit/annotations.py (L150)
- [x] 000739c31a/torch/jit/annotations.py (L158)
- [x] 000739c31a/torch/jit/annotations.py (L231)
- [x] 000739c31a/torch/jit/_trace.py (L432)
- [x] 000739c31a/torch/nn/utils/prune.py (L192)
- [x] 000739c31a/torch/cuda/nvtx.py (L7)
- [x] 000739c31a/torch/utils/cpp_extension.py (L1537)
- [x] 000739c31a/torch/utils/tensorboard/_pytorch_graph.py (L292)
- [x] 000739c31a/torch/utils/data/dataloader.py (L835)
- [x] 000739c31a/torch/utils/data/dataloader.py (L849)
- [x] 000739c31a/torch/utils/data/dataloader.py (L856)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L186)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L189)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L424)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1279)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1283)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1356)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1388)
- [x] 000739c31a/torch/testing/_internal/common_utils.py (L1391)
- [ ] 000739c31a/torch/testing/_internal/common_utils.py (L1412)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L310)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L329)
- [x] 000739c31a/torch/testing/_internal/codegen/random_topo_test.py (L332)
- [x] 000739c31a/torch/testing/_internal/jit_utils.py (L183)
- [x] 000739c31a/torch/testing/_internal/common_nn.py (L4789)
- [x] 000739c31a/torch/onnx/utils.py (L367)
- [x] 000739c31a/torch/onnx/utils.py (L659)
- [x] 000739c31a/torch/onnx/utils.py (L892)
- [x] 000739c31a/torch/onnx/utils.py (L897)
- [x] 000739c31a/torch/serialization.py (L108)
- [x] 000739c31a/torch/serialization.py (L754)
- [x] 000739c31a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py (L76)
- [x] 000739c31a/torch/distributed/rpc/backend_registry.py (L260)
- [x] 000739c31a/torch/distributed/distributed_c10d.py (L184)
- [x] 000739c31a/torch/_utils_internal.py (L57)
- [x] 000739c31a/torch/hub.py (L494)
- [x] 000739c31a/torch/contrib/_tensorboard_vis.py (L16)
- [x] 000739c31a/torch/distributions/lowrank_multivariate_normal.py (L100)
- [x] 000739c31a/torch/distributions/constraint_registry.py (L142)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43836
Reviewed By: ailzhang
Differential Revision: D23431212
Pulled By: malfet
fbshipit-source-id: 5f7f41b391164a5ad0efc06e55cd58c23408a921
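A minimal sketch of the two chaining forms described above, using hypothetical `ConfigError`/`load`/`lookup` names that are not from the PR:

```python
class ConfigError(Exception):
    pass

def load(text):
    # 1. `raise ... from err`: keep the original exception when its
    #    traceback adds debugging context.
    try:
        return int(text)
    except ValueError as err:
        raise ConfigError("invalid config value: {!r}".format(text)) from err

def lookup(table, key):
    # 2. `raise ... from None`: suppress the original exception when
    #    chaining it would just be noise.
    try:
        return table[key]
    except KeyError:
        raise ConfigError("unknown key: {!r}".format(key)) from None
```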
384 lines · 14 KiB · Python
import ast
import enum
import inspect
import warnings
import os
import re
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
    BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
    is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn
from torch._C import TensorType, TupleType, FloatType, IntType, \
    ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
    DeviceObjType, FutureType, EnumType

from textwrap import dedent
from torch._six import builtins
from torch._utils_internal import get_source_lines_and_file

if torch.distributed.rpc.is_available():
    from .._jit_internal import RRef, is_rref
    from torch._C import RRefType


class Module(object):
    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        try:
            return self.members[name]
        except KeyError:
            raise RuntimeError("Module {} has no member called {}".format(self.name, name)) from None
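

# Illustrative sketch (not part of the original module): the `from None` above
# suppresses the internal KeyError, so callers of a failed lookup see only the
# RuntimeError. `_demo_module_lookup` is a hypothetical helper for
# documentation purposes only.
def _demo_module_lookup():
    m = Module('torch', {'Tensor': torch.Tensor})
    try:
        m.missing_member
    except RuntimeError as e:
        # `raise ... from None` sets __suppress_context__ and leaves
        # __cause__ unset.
        assert e.__cause__ is None and e.__suppress_context__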


class EvalEnv(object):
    env = {
        'torch': Module('torch', {'Tensor': torch.Tensor}),
        'Tensor': torch.Tensor,
        'typing': Module('typing', {'Tuple': Tuple}),
        'Tuple': Tuple,
        'List': List,
        'Dict': Dict,
        'Optional': Optional,
        'Future': Future,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            self.env['RRef'] = RRef

    def __getitem__(self, name):
        if name in self.env:
            return self.env[name]
        if self.rcb is not None:
            return self.rcb(name)
        return getattr(builtins, name, None)


def get_signature(fn, rcb, loc, is_method):
    signature = try_real_annotations(fn, loc)
    if signature is not None and is_method:
        # If this is a method, then the signature will include a type for
        # `self`, but type comments do not contain a `self`. So strip it
        # away here so everything is consistent (`inspect.ismethod` does
        # not work here since `fn` is unbound at this point)
        param_types, return_type = signature
        param_types = param_types[1:]
        signature = (param_types, return_type)

    if signature is None:
        type_line, source = None, None
        try:
            source = dedent(''.join(get_source_lines_and_file(fn)[0]))
            type_line = get_type_line(source)
        except TypeError:
            pass
        # This might happen either because we failed to get the source of fn,
        # or because it didn't have any annotations.
        if type_line is not None:
            signature = parse_type_line(type_line, rcb, loc)

    return signature


def is_function_or_method(the_callable):
    # A stricter version of `inspect.isroutine` that does not pass for built-in
    # functions
    return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)


def is_vararg(the_callable):
    if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        the_callable = the_callable.__call__

    if is_function_or_method(the_callable):
        return inspect.getfullargspec(the_callable).varargs is not None
    else:
        return False


def get_param_names(fn, n_args):
    if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__):  # noqa: B004
        # De-sugar calls to classes
        fn = fn.__call__

    if is_function_or_method(fn):
        if is_ignored_fn(fn):
            fn = inspect.unwrap(fn)
        return inspect.getfullargspec(fn).args
    else:
        # The `fn` was not a method or function (maybe a class with a
        # __call__ method), so use a default param name list
        return [str(i) for i in range(n_args)]
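

# Illustrative sketch (not part of the original module): how the callable
# introspection helpers above behave. `_demo_callable_introspection` is a
# hypothetical helper for documentation purposes only.
def _demo_callable_introspection():
    def f(x, y, *rest):
        return x
    assert is_vararg(f)
    assert not is_vararg(lambda x: x)
    assert get_param_names(f, 3) == ['x', 'y']
    # Non-callables fall back to positional names '0', '1', ...
    assert get_param_names(object(), 2) == ['0', '1']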


def check_fn(fn, loc):
    # Make sure the function definition is not a class instantiation
    try:
        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
    except (TypeError, IOError):
        return
    if source is None:
        return

    py_ast = ast.parse(source)
    if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Cannot instantiate class '{}' in a script function".format(py_ast.body[0].name))
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")


def parse_type_line(type_line, rcb, loc):
    """Parses a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    try:
        arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb))  # noqa: P204
    except (NameError, SyntaxError) as e:
        raise RuntimeError("Failed to parse the argument list of a type annotation") from e

    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    try:
        ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb))  # noqa: P204
    except (NameError, SyntaxError) as e:
        raise RuntimeError("Failed to parse the return type of a type annotation") from e

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)
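

# Illustrative sketch (not part of the original module): a full round trip
# through `parse_type_line` with a resolution callback of None, so names
# resolve via `EvalEnv` and builtins. `_demo_parse_type_line` is a
# hypothetical helper for documentation purposes only; it assumes JIT types
# expose `.kind()`.
def _demo_parse_type_line():
    arg_types, ret_type = parse_type_line("# type: (Tensor, int) -> Tensor", None, None)
    assert [t.kind() for t in arg_types] == ['TensorType', 'IntType']
    assert ret_type.kind() == 'TensorType'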


def get_type_line(source):
    """Tries to find the line containing a comment with the type annotation."""
    type_comment = '# type:'

    lines = source.split('\n')
    lines = [(line_num, line) for line_num, line in enumerate(lines)]
    type_lines = list(filter(lambda line: type_comment in line[1], lines))
    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.
    type_lines = list(filter(lambda line: not line[1].endswith("# type: ignore"),
                             type_lines))
    lines_with_type = list(filter(lambda line: 'type' in line[1], lines))

    if len(type_lines) == 0:
        type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')
        wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
        if len(wrong_type_lines) > 0:
            raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
                               + " is probably invalid.\nIt must be '# type:'"
                               + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa
                               + "\nfor examples")
        return None
    elif len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if '# type: (...) -> ' in line:
            return_line = (line_num, line)
            break
        elif type_comment in line:
            parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n" +
            '\n'.join([line[1] for line in type_lines]) +
            "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")  # noqa

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment):]
        return item_type.strip()

    types = map(get_parameter_type, parameter_type_lines)
    parameter_types = ", ".join(types)

    return return_line[1].replace("...", parameter_types)
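

# Illustrative sketch (not part of the original module): `get_type_line`
# merges per-parameter `# type:` comments into the `(...) ->` return line.
# `_demo_get_type_line` is a hypothetical helper for documentation purposes
# only.
def _demo_get_type_line():
    src = (
        "def f(x,  # type: Tensor\n"
        "      y,  # type: int\n"
        "      ):\n"
        "    # type: (...) -> Tensor\n"
        "    return x\n"
    )
    assert "(Tensor, int) -> Tensor" in get_type_line(src)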


def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    """
    start_offset = len('# type:')
    try:
        arrow_pos = type_line.index('->')
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (couldn't find `->`)") from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()


def try_real_annotations(fn, loc):
    """Tries to use the Py3.5+ annotation syntax to get the type."""
    try:
        sig = inspect.signature(fn)
    except ValueError:
        return None

    all_annots = [sig.return_annotation] + [p.annotation for p in sig.parameters.values()]
    if all(ann is sig.empty for ann in all_annots):
        return None

    def as_ann(ann):
        # sig.empty is really annoying so convert it to None
        return ann if ann is not sig.empty else None

    arg_types = [ann_to_type(as_ann(p.annotation), loc)
                 for p in sig.parameters.values()]
    return_type = ann_to_type(as_ann(sig.return_annotation), loc)
    return arg_types, return_type
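

# Illustrative sketch (not part of the original module): extracting JIT types
# from real Py3 annotations. `_demo_try_real_annotations` is a hypothetical
# helper for documentation purposes only; `loc` is unused for builtin
# annotations, so None suffices here.
def _demo_try_real_annotations():
    def f(x: int) -> str:
        return str(x)
    arg_types, ret_type = try_real_annotations(f, None)
    assert [t.kind() for t in arg_types] == ['IntType']
    assert ret_type.kind() == 'StringType'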


# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: enum.Enum, loc):
    enum_values = list(e)
    if not enum_values:
        raise ValueError("No enum values defined for: '{}'".format(e.__class__))

    types = set([type(v.value) for v in enum_values])
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose not to implement it to
    # avoid overcomplicating the logic here for a rare use case. Please report
    # a feature request if you find it necessary.
    return torch._C.unify_type_list(ir_types)
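

# Illustrative sketch (not part of the original module): for an Enum whose
# values share one Python type, the unified IR type is that type. `_DemoColor`
# and `_demo_enum_value_type` are hypothetical names for documentation
# purposes only; this assumes `torch._C.unify_type_list` returns the unified
# type, as used above.
class _DemoColor(enum.Enum):
    RED = 1
    GREEN = 2


def _demo_enum_value_type():
    assert get_enum_value_type(_DemoColor, loc=None).kind() == 'IntType'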


# Guards against using Enum support in JIT before the feature is complete.
# TODO(gmagogsfm): remove this check once Enum support is complete.
def is_enum_support_enabled() -> bool:
    return os.environ.get('EXPERIMENTAL_ENUM_SUPPORT', "0") == "1"


def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            valid_type = try_ann_to_type(ann.__args__[0], loc)
        else:
            valid_type = try_ann_to_type(ann.__args__[1], loc)
        assert valid_type, "Unsupported annotation {} could not be resolved.".format(repr(ann))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if not is_enum_support_enabled():
            warnings.warn("Enum support is work in progress, enum class {}"
                          " is not compiled".format(ann))
            return None
        if not hasattr(ann, "__torch_script_class__"):
            torch.jit._script._recursive_compile_class(ann, loc)
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)


def ann_to_type(ann, loc):
    the_type = try_ann_to_type(ann, loc)
    if the_type is not None:
        return the_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
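

# Illustrative sketch (not part of the original module): a few annotation-to-
# IR-type mappings from `try_ann_to_type` / `ann_to_type`. `_demo_ann_to_type`
# is a hypothetical helper for documentation purposes only.
def _demo_ann_to_type():
    assert ann_to_type(int, None).kind() == 'IntType'
    assert ann_to_type(Optional[int], None).kind() == 'OptionalType'
    # A missing annotation defaults to Tensor.
    assert ann_to_type(None, None).kind() == 'TensorType'
    # torch.dtype is currently mapped to int (see try_ann_to_type above).
    assert ann_to_type(torch.dtype, None).kind() == 'IntType'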


__all__ = [
    'Any',
    'List',
    'BroadcastingList1',
    'BroadcastingList2',
    'BroadcastingList3',
    'Tuple',
    'is_tuple',
    'is_list',
    'Dict',
    'is_dict',
    'TensorType',
    'TupleType',
    'FloatType',
    'IntType',
    'ListType',
    'StringType',
    'DictType',
    'AnyType',
    'Module',
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    'get_signature',
    'check_fn',
    'get_param_names',
    'parse_type_line',
    'get_type_line',
    'split_type_line',
    'try_real_annotations',
    'try_ann_to_type',
    'ann_to_type',
]