Fix exception causes all over the codebase (#90271)

This is the continuation of #90134 and hopefully the final PR in this series.
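
For context, the pattern applied throughout: in Python 3, raising inside an `except` block chains the new exception to the original implicitly via `__context__`; adding `from e` sets `__cause__` and marks the chain as deliberate. An illustrative sketch (not code from this commit):

    try:
        {}["missing"]
    except KeyError as e:
        # Without "from e" the traceback says "During handling of the above
        # exception, another exception occurred" (implicit __context__).
        # With "from e" it says "The above exception was the direct cause
        # of the following exception" (explicit __cause__).
        raise ValueError("missing key") from e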

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
Ram Rachum 2022-12-07 04:28:56 +00:00 committed by PyTorch MergeBot
parent 8f079b895b
commit 351d73b97f
63 changed files with 150 additions and 147 deletions


@@ -1001,8 +1001,8 @@ class BenchmarkRunner:
         try:
             self.model_iter_fn(model, example_inputs)
-        except Exception:
-            raise NotImplementedError("Eager model failed to run")
+        except Exception as e:
+            raise NotImplementedError("Eager model failed to run") from e
 
     def maybe_cast(self, model, example_inputs):
         model = copy.deepcopy(model)


@@ -330,8 +330,9 @@ class TransformerModel(nn.Module):
         super(TransformerModel, self).__init__()
         try:
             from torch.nn import TransformerEncoder, TransformerEncoderLayer
-        except Exception:
-            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
+        except Exception as e:
+            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
+                              'lower.') from e
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)


@@ -210,9 +210,9 @@ class TranslatorRegistry(object):
         try:
             caffe_ops, params = cls.registry_[layer.type](
                 layer, pretrained_blobs, is_test, **kwargs)
-        except KeyError:
+        except KeyError as e:
             raise KeyError('No translator registered for layer: %s yet.' %
-                           str(layer))
+                           str(layer)) from e
         if caffe_ops is None:
             caffe_ops = []
         if type(caffe_ops) is not list:


@@ -970,7 +970,7 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
                     input_name,
                     err
                 )
-            )
+            ) from err
 
         # Finally, let's create the sum operator.
         sum_ops, g = self._MakeSumOps(input_name, input_version)
@@ -1175,7 +1175,7 @@ class GradientRegistry(object):
             raise Exception(
                 "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                 format(op.type, e, str(op))
-            )
+            ) from e
         if gradient_ops is None:
             return [], g_input


@@ -540,8 +540,8 @@ def ExtractPredictorNet(
                 'StopGradient'
             ]
         )
-    except ValueError:
-        raise Exception("No ops with input={}".format(input_blobs))
+    except ValueError as e:
+        raise Exception("No ops with input={}".format(input_blobs)) from e
 
     try:
         last_op_with_output = max(
             [
@@ -549,8 +549,8 @@ def ExtractPredictorNet(
                 if output_blobs.intersection(ops[j].output)
             ]
         )
-    except ValueError:
-        raise Exception("No ops with output={}".format(output_blobs))
+    except ValueError as e:
+        raise Exception("No ops with output={}".format(output_blobs)) from e
 
     def validate_op(op):
         # Check that the op does not have is_test = 0 set. This is a common


@@ -69,10 +69,10 @@ def downloadFromURLToFile(url, filename, show_progress=True):
             print("")  # New line to fix for progress bar
     except HTTPError as e:
         raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
-                        .format(code=e.code, reason=e.reason))
+                        .format(code=e.code, reason=e.reason)) from e
     except URLError as e:
         raise Exception("Could not download model. [URL Error] {reason}."
-                        .format(reason=e.reason))
+                        .format(reason=e.reason)) from e
 
 def getURLFromName(name, filename):


@@ -150,9 +150,9 @@ class RoIAlignRotatedOp(hu.HypothesisTestCase):
             indexer = [slice(None)] * m.ndim
             try:
                 indexer[axis] = slice(None, None, -1)
-            except IndexError:
+            except IndexError as e:
                 raise ValueError("axis=%i is invalid for the %i-dimensional input array"
-                                 % (axis, m.ndim))
+                                 % (axis, m.ndim)) from e
             return m[tuple(indexer)]
 
         def roialign_ref(X, R):


@@ -13,8 +13,8 @@ from caffe2.python import model_helper, workspace
 try:
     import lmdb
-except ImportError:
-    raise unittest.SkipTest("python-lmdb is not installed")
+except ImportError as e:
+    raise unittest.SkipTest("python-lmdb is not installed") from e
 
 class VideoInputOpTest(unittest.TestCase):


@@ -546,8 +546,8 @@ class Struct(Field):
             raise AttributeError(item)
         try:
             return super(Struct, self).__getattribute__("fields")[item]
-        except KeyError:
-            raise AttributeError(item)
+        except KeyError as e:
+            raise AttributeError(item) from e
 
     def __setattr__(self, key, value):
         # Disable setting attributes after initialization to prevent false


@@ -29,8 +29,8 @@ def _get_output_shapes(output_value_infos):
 def check_gpu_():
     try:
         C.get_cuda_version()
-    except Exception as _:
-        raise Exception("TensorRT related functions require CUDA support")
+    except Exception as e:
+        raise Exception("TensorRT related functions require CUDA support") from e
 
 def convert_onnx_model_to_trt_op(onnx_model,
                                  max_batch_size=64,


@@ -446,8 +446,8 @@ Please install it via `conda install {module}` or `pip install {module}`
 def check_pydep(importname, module):
     try:
         importlib.import_module(importname)
-    except ImportError:
-        raise RuntimeError(missing_pydep.format(importname=importname, module=module))
+    except ImportError as e:
+        raise RuntimeError(missing_pydep.format(importname=importname, module=module)) from e
 
 class build_ext(setuptools.command.build_ext.build_ext):


@@ -611,8 +611,8 @@ class TestFSDPStateDict(FSDPTest):
     def _state_dict(model: Module, state_dict_type: str):
         try:
             enum_val = STATE_DICT_MAPPING[state_dict_type]
-        except KeyError:
-            raise ValueError(f"No state_dict type for {state_dict_type}")
+        except KeyError as e:
+            raise ValueError(f"No state_dict type for {state_dict_type}") from e
         with FSDP.state_dict_type(model, enum_val):
             return model.state_dict()
@@ -623,8 +623,8 @@ class TestFSDPStateDict(FSDPTest):
     ):
         try:
             enum_val = STATE_DICT_MAPPING[state_dict_type]
-        except KeyError:
-            raise ValueError(f"No state_dict for {state_dict_type}")
+        except KeyError as e:
+            raise ValueError(f"No state_dict for {state_dict_type}") from e
         with FSDP.state_dict_type(model, enum_val):
             return model.load_state_dict(state_dict, strict=True)


@@ -2598,7 +2598,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             try:
                 pg_gloo.barrier().wait()
             except Exception as e:
-                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
+                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}") from e
 
         # Now verify communicators on this rank have
         # been aborted by watchdog.
         self._wait_for_comm_abort(process_group, failed_collective_timeout)


@@ -101,8 +101,8 @@ def get_hf_bert(rank):
     # in a multiprocessing test
     try:
         from transformers import BertConfig, AutoModelForMaskedLM
-    except ImportError:
-        raise unittest.SkipTest("Unable to import transformers")
+    except ImportError as e:
+        raise unittest.SkipTest("Unable to import transformers") from e
 
     batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
     model = AutoModelForMaskedLM.from_config(config).to(device)


@@ -1856,8 +1856,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             def __getattr__(self, item: str):
                 try:
                     return self.data[item]
-                except KeyError:
-                    raise AttributeError
+                except KeyError as e:
+                    raise AttributeError from e
 
         def tokenization(x):
             encoding = BatchEncoding({"key": x})
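
Note on the hunk above: `__getattr__` must raise `AttributeError` for `getattr`/`hasattr` semantics to keep working, so the `KeyError` is converted rather than propagated; chaining keeps the original lookup failure in the traceback. An illustrative standalone sketch (hypothetical `Record` class, not from this commit):

    class Record:
        def __init__(self, data):
            self.data = data

        def __getattr__(self, item):
            # Raising AttributeError preserves the getattr()/hasattr()
            # protocol; "from e" keeps the KeyError visible as the cause.
            try:
                return self.data[item]
            except KeyError as e:
                raise AttributeError(item) from e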


@@ -65,7 +65,7 @@ except (ImportError, AssertionError) as e:
     sys.stderr.write(f"{type(e)}: {e}\n")
     if __name__ == "__main__":
         sys.exit(0)
-    raise unittest.SkipTest("requires sympy/functorch/filelock")
+    raise unittest.SkipTest("requires sympy/functorch/filelock") from e
 
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


@@ -128,9 +128,9 @@ class TestDtypeBase(JitTestCase):
         inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
         try:
             self.assert_dtype_equal_custom_args(fn, inputs)
-        except Exception:
+        except Exception as e:
             fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
-            raise AssertionError(fail_text)
+            raise AssertionError(fail_text) from e
 
     def assert_dtype_equal_custom_args(self, fn, args):
         try:


@@ -141,7 +141,7 @@ def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
                 raise e  # reraise the exception
             exception_message = str(e)
             if not re.search(exception_msg_pattern, exception_message):
-                raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}")
+                raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}") from e
             else:
                 # We are done for the test case that expects an exception
                 return


@@ -110,8 +110,8 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
         try:
             from apex import amp
-        except Exception:
-            raise unittest.SkipTest("Apex is not available")
+        except Exception as e:
+            raise unittest.SkipTest("Apex is not available") from e
 
         input = torch.randn(3, 3, device=torch.device("cuda"))
         model = amp.initialize(LinearModel(), opt_level="O2")
         self.run_test(model, input)


@@ -1703,7 +1703,7 @@ class TestAutograd(TestCase):
                     self.assertTrue(torch.is_grad_enabled())
                     yield (-i if has_raised else i)
-                except UnrecoverableException:
+                except UnrecoverableException as e:
                     self.assertTrue(torch.is_grad_enabled())
-                    raise SecondaryException
+                    raise SecondaryException from e


@@ -3805,7 +3805,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
                   f"unintended, please revert it. If it was intended, check with the FX " \
                   f"team to ensure that the proper deprecation protocols have been followed " \
                   f"and subsequently --accept the change."
-            raise AssertionError(msg)
+            raise AssertionError(msg) from e
 
     def test_public_api_surface(self):
         non_back_compat_objects = {}


@@ -597,7 +597,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_minmax_int_ops(self):
         def apply(fn):
@@ -627,7 +627,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_comparison_eq_ne(self):
         for device in self.devices:
@@ -1288,7 +1288,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])
-                )
+                ) from e
 
     def test_isnan(self):
         x = torch.rand([4])
@@ -1321,7 +1321,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), 'isnan', device])
-                )
+                ) from e
 
     def test_gelu(self):
         def apply(fn):
@@ -1352,7 +1352,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
-                )
+                ) from e
 
     def test_unary_ops(self):
         with torch._jit_internal._disable_emit_hooks():
@@ -1435,7 +1435,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
-                )
+                ) from e
 
     def test_binary_ops(self):
         def apply(fn):
@@ -1488,7 +1488,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_binary_scalar_ops(self):
         def apply(fn):
@@ -1534,7 +1534,7 @@ class TestTEFuser(JitTestCase):
             try:
                 k = torch._C._te.TensorExprKernel(graph)
             except Exception as e:
-                raise RuntimeError(" ".join(["Compilation failed:", device, str(code)]))
+                raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e
 
             # Run the graph
             for x, y in product(values[dtype_x], values[dtype_y]):
@@ -1543,7 +1543,7 @@ class TestTEFuser(JitTestCase):
                     res = k.run((x, y))
                     self.assertEqual(ref, res)
                 except Exception as e:
-                    raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)]))
+                    raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e
 
     def test_matmul(self):
         if self.dynamic_shapes:
@@ -1599,7 +1599,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), device])
-                )
+                ) from e
 
     def test_binary_tensor_scalar_ops(self):
         with torch._jit_internal._disable_emit_hooks():
@@ -1643,7 +1643,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_binary_div_ops(self):
         def apply_with_scalar(fn, scalar):
@@ -1676,7 +1676,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     "Failed: {} {} {} {}".format(dtype, op.__name__, device, scalar)
-                )
+                ) from e
 
     def test_binary_pow(self):
         def apply_with_scalar(fn, scalar):
@@ -1714,7 +1714,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_ternary_ops(self):
         def apply(fn):
@@ -1746,7 +1746,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_ternary_norm_ops(self):
         def apply(fn):
@@ -1777,7 +1777,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure")
@@ -1810,7 +1810,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_where_ops(self):
         def apply(fn):
@@ -1843,7 +1843,7 @@ class TestTEFuser(JitTestCase):
             except Exception as e:
                 raise RuntimeError(
                     " ".join(["Failed:", str(dtype), op.__name__, device])
-                )
+                ) from e
 
     def test_unsupported_dtypes(self):
         for device in self.devices:


@@ -467,7 +467,7 @@ def run_meta_crossref(
             # they're not tested outside of gradcheck which only checks
             # torch.float64 and torch.complex128 (which this second one
             # often skipped as well).
-            raise unittest.SkipTest("Original OpInfo is broken")
+            raise unittest.SkipTest("Original OpInfo is broken") from e
 
     # TODO: also handle cases where func raise an exception


@@ -28,13 +28,13 @@ class UniqueKeyLoader(Loader):
             key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
             try:
                 hash(key)
-            except TypeError:
+            except TypeError as e:
                 raise ConstructorError(
                     "while constructing a mapping",
                     node.start_mark,
                     "found unacceptable key ",
                     key_node.start_mark,
-                )
+                ) from e
             # check for duplicate keys
             if key in mapping:
                 raise ConstructorError(


@@ -12,10 +12,10 @@ try:
         TestCase,
         TestSuite,
     )
-except ImportError:
+except ImportError as e:
     raise ImportError(
         "junitparser not found, please install with 'pip install junitparser'"
-    )
+    ) from e
 
 try:
     import rich


@@ -364,7 +364,7 @@ def helper_for_dump_minify(contents):
             fd.write(contents)
     except OSError as e:
         log.exception(e)
-        raise NotImplementedError("Could not write to {minified_repro_path}")
+        raise NotImplementedError("Could not write to {minified_repro_path}") from e
 
 def dump_to_minify(gm, args, compiler_name: str):


@@ -283,7 +283,7 @@ def break_graph_if_unsupported(*, push):
             if self.has_backedge():
                 msg = "Skipping frame because there is a graph break in a for/while loop"
                 log.debug(msg)
-                raise exc.SkipFrame(msg)
+                raise exc.SkipFrame(msg) from excp
 
             if not self.should_compile_partial_graph():
                 raise
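
Note on the hunk above: the bare `raise` at the end needs no `from` clause, since it re-raises the active exception unchanged, original traceback included. An illustrative sketch of the distinction (hypothetical `parse` helper, not from this commit):

    def parse(text):
        try:
            return int(text)
        except ValueError:
            if text.strip() == "":
                return 0  # recoverable case, swallow the error
            raise         # re-raise the original ValueError as-is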


@@ -348,13 +348,13 @@ def proxy_args_kwargs(args, kwargs):
         proxy_args = tuple(arg.as_proxy() for arg in args)
         proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
         return proxy_args, proxy_kwargs
-    except NotImplementedError:
+    except NotImplementedError as e:
         from .exc import unimplemented
         from .variables.base import typestr
 
         raise unimplemented(
             f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
-        )
+        ) from e
 
 @dataclasses.dataclass
@@ -745,7 +745,7 @@ def wrap_fake_exception(fn):
         msg = f"Unsupported: {e.reason} with fake tensor propagation."
         log.warning(msg)
-        raise unimplemented(msg)
+        raise unimplemented(msg) from e
 
 def wrap_to_fake_tensor(e, fake_mode):


@@ -56,8 +56,8 @@ class ConstantVariable(VariableTracker):
         try:
             options = VariableTracker.propagate([self])
             return [ConstantVariable(x, **options) for x in self.as_python_constant()]
-        except TypeError:
-            raise NotImplementedError()
+        except TypeError as e:
+            raise NotImplementedError from e
 
     def const_getattr(self, tx, name):
         member = getattr(self.value, name)


@@ -504,8 +504,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
         try:
             fn = inspect.getattr_static(self.value_type, "__iter__")
-        except AttributeError:
-            raise NotImplementedError()
+        except AttributeError as e:
+            raise NotImplementedError from e
 
         if fn in (
             torch.nn.ModuleList.__iter__,


@@ -277,8 +277,9 @@ def min_cut_rematerialization_partition(
     """
     try:
         import networkx as nx
-    except ImportError:
-        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
+    except ImportError as e:
+        raise RuntimeError("Need networkx installed to perform smart recomputation "
+                           "heuristics") from e
 
     joint_module.graph.eliminate_dead_code()
     joint_module.recompile()


@@ -435,7 +435,7 @@ class CppCodeCache:
                 try:
                     subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                 except subprocess.CalledProcessError as e:
-                    raise exc.CppCompileError(cmd, e.output)
+                    raise exc.CppCompileError(cmd, e.output) from e
 
             cls.cache[key] = cls._load_library(output_path)
             cls.cache[key].key = key


@@ -184,7 +184,7 @@ class CachingAutotuner(KernelInterface):
                 raise RuntimeError(
                     """Consider updating Triton with
 `pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
-                )
+                ) from e
             else:
                 raise e


@@ -136,5 +136,5 @@ class CrossRefFakeMode(TorchDispatchMode):
                     r_out, fake_out, check_strides=self.check_strides
                 )
             except Exception as e:
-                raise RuntimeError(f"Mismatch on {func}: {e}")
+                raise RuntimeError(f"Mismatch on {func}: {e}") from e
         return r


@@ -819,7 +819,7 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool:
             except RuntimeError as ex:
                 # Rethrow to provide a better error message
                 raise GradcheckError(
-                    f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}')
+                    f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') from ex
 
         for input_idx, (res, exp) in enumerate(zip(result, expected)):
             if torch.allclose(res, exp):
@@ -861,7 +861,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
             # autograd.grad instead of the C++ traceback of what line in the
             # backward formula
             raise GradcheckError(
-                f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}')
+                f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') from ex
 
     for input_idx, (res, exp) in enumerate(zip(result, expected)):
         if torch.allclose(res, exp):
@@ -977,12 +977,12 @@ def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
         try:
             grads_input = torch.autograd.grad(output_to_check, diff_input_list,
                                               grads_output, allow_unused=True)
-        except RuntimeError:
+        except RuntimeError as e:
             warn_bc_breaking()
             raise GradcheckError((
                 'Expected backward function to handle undefined output grads. '
                 'Please look at "Notes about undefined output gradients" in '
-                '"tools/autograd/derivatives.yaml"'))
+                '"tools/autograd/derivatives.yaml"')) from e
 
         for gi, i in zip(grads_input, diff_input_list):
             if (gi is not None) and (not gi.eq(0).all()):


@@ -669,13 +669,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
     """
     try:
         import pynvml  # type: ignore[import]
-    except ModuleNotFoundError:
-        raise ModuleNotFoundError("pynvml module not found, please install pynvml")
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
     from pynvml import NVMLError_DriverNotLoaded
     try:
         pynvml.nvmlInit()
-    except NVMLError_DriverNotLoaded:
-        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
+    except NVMLError_DriverNotLoaded as e:
+        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
     device = _get_device_index(device, optional=True)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
@@ -695,13 +695,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
     """
     try:
         import pynvml  # type: ignore[import]
-    except ModuleNotFoundError:
-        raise ModuleNotFoundError("pynvml module not found, please install pynvml")
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
     from pynvml import NVMLError_DriverNotLoaded
     try:
         pynvml.nvmlInit()
-    except NVMLError_DriverNotLoaded:
-        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
+    except NVMLError_DriverNotLoaded as e:
+        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
     device = _get_device_index(device, optional=True)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu


@@ -1640,10 +1640,10 @@ class DistributedDataParallel(Module, Joinable):
             )
         try:
             overlapped_optim.register_ddp(self)
-        except NotImplementedError:
+        except NotImplementedError as e:
             raise RuntimeError(
                 f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
-            )
+            ) from e
 
     def _distributed_broadcast_coalesced(
         self, tensors, buffer_size, authoritative_rank=0


@@ -195,11 +195,11 @@ class RendezvousParameters:
             return value
         try:
             return int(value)
-        except ValueError:
+        except ValueError as e:
             raise ValueError(
                 f"The rendezvous configuration option '{key}' does not represent a valid integer "
                 "value."
-            )
+            ) from e
 
 RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
@@ -244,11 +244,11 @@ class RendezvousHandlerRegistry:
         """Creates a new :py:class:`RendezvousHandler`."""
         try:
             creator = self._registry[params.backend]
-        except KeyError:
+        except KeyError as e:
             raise ValueError(
                 f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                 f"to call `{self.register.__name__}`?"
-            )
+            ) from e
 
         handler = creator(params)


@@ -464,10 +464,10 @@ class EtcdRendezvous(object):
             version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
             version_counter.value = str(int(version_counter.value) + 1)
             self.client.update(version_counter)
-        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed):
+        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
             raise RendezvousError(
                 "Unexpected state of EtcdRendezvousHandler, worker needs to die."
-            )
+            ) from e
 
         # Any failure below results in declaring a retryable rendezvous failure.
         # The ephemeral /rdzv/active_version will expire and someone can then


@@ -381,8 +381,8 @@ def _get_ignored_modules(
     msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
     try:
         ignored_root_modules = set(_ignored_modules)
-    except TypeError:
-        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}")
+    except TypeError as e:
+        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
     for module in ignored_root_modules:
         if not isinstance(module, torch.nn.Module):
             raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")


@@ -1015,11 +1015,11 @@ def _get_param_id_to_param_from_optim_input(
         return list(model.parameters())
     try:
         params = list(optim_input)
-    except TypeError:
+    except TypeError as e:
         raise TypeError(
             "Optimizer input should be an iterable of Tensors or dicts, "
             f"but got {optim_input}"
-        )
+        ) from e
     if len(params) == 0:
         raise ValueError("Optimizer input should not be empty")


@@ -44,8 +44,9 @@ def register_functional_optim(key, optim):
 def as_functional_optim(optim_cls: Type, *args, **kwargs):
     try:
         functional_cls = functional_optim_map[optim_cls]
-    except KeyError:
-        raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
+    except KeyError as e:
+        raise ValueError(f"Optimizer {optim_cls} does not have a functional "
+                         f"counterpart!") from e
     return _create_functional_optim(functional_cls, *args, **kwargs)


@@ -1323,9 +1323,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                             f"Tensors, but got {torch.typename(params)}")
         try:
             all_params = list(params)
-        except TypeError:
+        except TypeError as e:
             raise TypeError("`params` argument should be an iterable of Tensors"
-                            f" or dicts, but got {torch.typename(params)}")
+                            f" or dicts, but got {torch.typename(params)}") from e
         if len(all_params) == 0:
             raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
                              "list")


@@ -198,8 +198,8 @@ class Skippable(nn.Module):
         for ns, name in self.poppable():
             try:
                 poppable_tensors[name] = skip_tracker.load(batch, ns, name)
-            except KeyError:
-                raise RuntimeError(f"'{name}' has not been stashed")
+            except KeyError as e:
+                raise RuntimeError(f"'{name}' has not been stashed") from e
         input = batch.values
 
         # Handle skip commands.


@@ -1,9 +1,9 @@
 try:
     from urllib.parse import urlparse, urlunparse
-except ImportError:
+except ImportError as e:
     raise ImportError(
         "urllib cannot be found, urlparse from python2 is no longer supported."
-    )
+    ) from e
 
 import numbers
 import os


@@ -228,7 +228,7 @@ def _handle_exception(result):
             raise RuntimeError(  # noqa: B904
                 f"Failed to create original exception type. Error msg was {str(e)}"
                 f" Original exception on remote side was {exception_msg}"
-            )
+            ) from e
     if exc is not None:
         raise exc
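
Note on the hunk above: `B904` is the flake8-bugbear rule that flags a bare `raise` inside an `except` block; it accepts either `from e` or `from None`. An illustrative sketch of the `from None` variant, which deliberately suppresses chaining per PEP 409 (not code from this commit):

    try:
        int("not a number")
    except ValueError:
        # "from None" hides the implicit __context__, so only this
        # exception appears in the traceback.
        raise RuntimeError("bad config value") from None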


@@ -605,13 +605,13 @@ def determine_local_world_size(nproc_per_node: str):
     try:
         logging.info(f"Using nproc_per_node={nproc_per_node}.")
         return int(nproc_per_node)
-    except ValueError:
+    except ValueError as e:
         if nproc_per_node == "cpu":
             num_proc = os.cpu_count()
             device_type = "cpu"
         elif nproc_per_node == "gpu":
             if not torch.cuda.is_available():
-                raise ValueError("Cuda is not available.")
+                raise ValueError("Cuda is not available.") from e
             device_type = "gpu"
             num_proc = torch.cuda.device_count()
         elif nproc_per_node == "auto":
@@ -622,7 +622,7 @@ def determine_local_world_size(nproc_per_node: str):
             num_proc = os.cpu_count()
             device_type = "cpu"
         else:
-            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
+            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
 
         log.info(
             f"Using nproc_per_node={nproc_per_node},"


@@ -20,13 +20,13 @@ def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedD
     # Parse the function declaration
     try:
         py_ast = ast.parse(decl)
-    except SyntaxError:
+    except SyntaxError as e:
         # This should only happen if there's some unforeseeable change
         # in the dataclasses module that makes our synthesized code fail
         raise RuntimeError(
             f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
             "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
-        )
+        ) from e
     fake_filename = _get_fake_filename(cls, name)
     # Parse the function
     return ParsedDef(


@@ -185,7 +185,7 @@ def infer_concrete_type_builder(nn_module, share_types=True):
             except RuntimeError as re:
                 raise RuntimeError(
                     "Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
-                )
+                ) from re
 
         return attr_type, inferred


@@ -292,8 +292,8 @@ def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None,
     if not os.path.isdir(dir_name):
         try:
             os.makedirs(dir_name, exist_ok=True)
-        except Exception:
-            raise RuntimeError("Can't create directory: " + dir_name)
+        except Exception as e:
+            raise RuntimeError("Can't create directory: " + dir_name) from e
     if not worker_name:
         worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))
     file_name = "{}.{}.pt.trace.json".format(worker_name, int(time.time() * 1000))


@@ -677,8 +677,8 @@ class TensorLikePair(Pair):
         try:
             return torch.as_tensor(tensor_like)
-        except Exception:
-            raise UnsupportedInputs()
+        except Exception as e:
+            raise UnsupportedInputs() from e
 
     def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
         if tensor.layout not in {


@@ -949,7 +949,7 @@ class FSDPTest(MultiProcessTestCase):
                 **init_kwargs,
             )
         except Exception as e:
-            raise ValueError(f"Initializing {model_class} raised error {str(e)}")
+            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
         if not isinstance(fsdp_model, FSDP):
             # Enforce that we wrap with top-level FSDP since we are comparing
             # assuming a data parallel reference and some test models may not


@@ -3260,7 +3260,7 @@ def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)):
                 if any(connect_error in str(error) for connect_error in connect_errors):
                     tries_remaining -= 1
                     if tries_remaining == 0:
-                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}")
+                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
                     time.sleep(random.random())
                     continue
                 raise
@@ -4001,8 +4001,8 @@ def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
     """
     try:
         return next(iter(samples))
-    except StopIteration:
-        raise unittest.SkipTest('Skipped! Need at least 1 sample input')
+    except StopIteration as e:
+        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
 
 # this helper method is to recursively
 # clone the tensor-type input of operators tested by OpInfo


@@ -290,8 +290,8 @@ class DTensorConverter(object):
                 tree_unflatten(new_args, self.flatten_args_spec),
                 tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
             )
-        except StopIteration:
-            raise StopIteration
+        except StopIteration as e:
+            raise StopIteration from e
 
     def to_dist_tensor(
         self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
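
Note on the hunk above: re-raising `StopIteration` with an explicit cause is fine in a class-based `__next__`; under PEP 479 a `StopIteration` escaping a generator body would instead surface as a `RuntimeError`. An illustrative sketch (hypothetical `Pairwise` iterator, not from this commit):

    class Pairwise:
        def __init__(self, iterable):
            self._it = iter(iterable)

        def __iter__(self):
            return self

        def __next__(self):
            try:
                return (next(self._it), next(self._it))
            except StopIteration as e:
                # Exhausted: end this iterator too, keeping the cause chained.
                raise StopIteration from e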


@@ -349,10 +349,10 @@ class _DataPipeSerializationWrapper:
     def __len__(self):
         try:
             return len(self._datapipe)
-        except Exception:
+        except Exception as e:
             raise TypeError(
                 "{} instance doesn't have valid length".format(type(self).__name__)
-            )
+            ) from e
 
 class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):


@@ -154,8 +154,8 @@ def _collate_helper(conversion, item):
             try:
                 import torcharrow.pytorch as tap  # type: ignore[import]
                 collation_fn = tap.rec.Default()
-            except Exception:
-                raise Exception("unable to import default collation function from the TorchArrow")
+            except Exception as e:
+                raise Exception("unable to import default collation function from the TorchArrow") from e
         tuple_names.append(str(name))
         value = collation_fn(df[name])


@@ -93,8 +93,8 @@ class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
         for dp in self.datapipes:
             try:
                 res.append(dp[index])
-            except IndexError:
-                raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.")
+            except IndexError as e:
+                raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.") from e
         return tuple(res)
 
     def __len__(self) -> int:


@@ -53,11 +53,11 @@ class BatcherMapDataPipe(MapDataPipe[DataChunk]):
             for i in indices:
                 batch.append(self.datapipe[i])
             return self.wrapper_class(batch)
-        except IndexError:
+        except IndexError as e:
             if not self.drop_last and len(batch) > 0:
                 return self.wrapper_class(batch)
             else:
-                raise IndexError(f"Index {index} is out of bound.")
+                raise IndexError(f"Index {index} is out of bound.") from e
 
     def __len__(self) -> int:
         if self.length is not None:


@@ -148,13 +148,13 @@ class ImageHandler:
             import numpy as np
         except ImportError as e:
             raise ModuleNotFoundError("Package `numpy` is required to be installed for default image decoder."
-                                      "Please use `pip install numpy` to install the package")
+                                      "Please use `pip install numpy` to install the package") from e
         try:
             import PIL.Image
         except ImportError as e:
             raise ModuleNotFoundError("Package `PIL` is required to be installed for default image decoder."
-                                      "Please use `pip install Pillow` to install the package")
+                                      "Please use `pip install Pillow` to install the package") from e
 
         imagespec = self.imagespec
         atype, etype, mode = imagespecs[imagespec]
@@ -200,7 +200,7 @@ def videohandler(extension, data):
     except ImportError as e:
         raise ModuleNotFoundError("Package `torchvision` is required to be installed for default video file loader."
                                   "Please use `pip install torchvision` or `conda install torchvision -c pytorch`"
-                                  "to install the package")
+                                  "to install the package") from e
 
     with tempfile.TemporaryDirectory() as dirname:
         fname = os.path.join(dirname, f"file.{extension}")
@@ -221,7 +221,7 @@ def audiohandler(extension, data):
     except ImportError as e:
         raise ModuleNotFoundError("Package `torchaudio` is required to be installed for default audio file loader."
                                   "Please use `pip install torchaudio` or `conda install torchaudio -c pytorch`"
-                                  "to install the package")
+                                  "to install the package") from e
 
     with tempfile.TemporaryDirectory() as dirname:
         fname = os.path.join(dirname, f"file.{extension}")
@@ -240,7 +240,7 @@ class MatHandler:
     except ImportError as e:
         raise ModuleNotFoundError("Package `scipy` is required to be installed for mat file."
                                   "Please use `pip install scipy` or `conda install scipy`"
-                                  "to install the package")
+                                  "to install the package") from e
     self.sio = sio
     self.loadmat_kwargs = loadmat_kwargs


@@ -47,9 +47,9 @@ def _simple_graph_snapshot_restoration(datapipe: IterDataPipe, n_iterations: int
         try:
             next(it)
             remainder -= 1
-        except StopIteration:
+        except StopIteration as e:
             raise RuntimeError(f"Fast-forward {datapipe} by {n_iterations} iterations "
-                               "exceeds the number of samples available.")
+                               "exceeds the number of samples available.") from e
     datapipe._fast_forward_iterator = it
     # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.


@@ -262,10 +262,10 @@ def error_on_missing_kernels(
     try:
         with open(kernel_defn_file_path, "r") as f:
            backend_defns = f.read()
-    except IOError:
+    except IOError as e:
         raise AssertionError(
             f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
-        )
+        ) from e
 
     if full_codegen is None:
         full_codegen = []


@@ -137,10 +137,10 @@ def validate_shape_inference_header(
         with open(shape_inference_hdr, "r") as f:
             shape_infr_decls = f.read()
             shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
-    except IOError:
+    except IOError as e:
         raise AssertionError(
             f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
-        )
+        ) from e
 
     shape_infr_regex = r"compute_shape_(\w+)"
     actual_shape_infr_name_counts = Counter(


@@ -1719,8 +1719,8 @@ class Type:
             return CustomClassType(m.group(1))
         try:
             return BaseType(BaseTy[t])
-        except KeyError:
-            raise RuntimeError(f"unrecognized type {t}")
+        except KeyError as e:
+            raise RuntimeError(f"unrecognized type {t}") from e
 
     def __str__(self) -> str:
         raise NotImplementedError