Fix exception causes all over the codebase (#90271)

This is the continuation to #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
This commit is contained in:
Ram Rachum 2022-12-07 04:28:56 +00:00 committed by PyTorch MergeBot
parent 8f079b895b
commit 351d73b97f
63 changed files with 150 additions and 147 deletions

View File

@ -1001,8 +1001,8 @@ class BenchmarkRunner:
try:
self.model_iter_fn(model, example_inputs)
except Exception:
raise NotImplementedError("Eager model failed to run")
except Exception as e:
raise NotImplementedError("Eager model failed to run") from e
def maybe_cast(self, model, example_inputs):
model = copy.deepcopy(model)

View File

@ -330,8 +330,9 @@ class TransformerModel(nn.Module):
super(TransformerModel, self).__init__()
try:
from torch.nn import TransformerEncoder, TransformerEncoderLayer
except Exception:
raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
except Exception as e:
raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
'lower.') from e
self.model_type = 'Transformer'
self.src_mask = None
self.pos_encoder = PositionalEncoding(ninp, dropout)

View File

@ -210,9 +210,9 @@ class TranslatorRegistry(object):
try:
caffe_ops, params = cls.registry_[layer.type](
layer, pretrained_blobs, is_test, **kwargs)
except KeyError:
except KeyError as e:
raise KeyError('No translator registered for layer: %s yet.' %
str(layer))
str(layer)) from e
if caffe_ops is None:
caffe_ops = []
if type(caffe_ops) is not list:

View File

@ -970,7 +970,7 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
input_name,
err
)
)
) from err
# Finally, let's create the sum operator.
sum_ops, g = self._MakeSumOps(input_name, input_version)
@ -1175,7 +1175,7 @@ class GradientRegistry(object):
raise Exception(
"Exception when creating gradient for [{}]:{}.\nOp: \n{}".
format(op.type, e, str(op))
)
) from e
if gradient_ops is None:
return [], g_input

View File

@ -540,8 +540,8 @@ def ExtractPredictorNet(
'StopGradient'
]
)
except ValueError:
raise Exception("No ops with input={}".format(input_blobs))
except ValueError as e:
raise Exception("No ops with input={}".format(input_blobs)) from e
try:
last_op_with_output = max(
[
@ -549,8 +549,8 @@ def ExtractPredictorNet(
if output_blobs.intersection(ops[j].output)
]
)
except ValueError:
raise Exception("No ops with output={}".format(output_blobs))
except ValueError as e:
raise Exception("No ops with output={}".format(output_blobs)) from e
def validate_op(op):
# Check that the op does not have is_test = 0 set. This is a common

View File

@ -69,10 +69,10 @@ def downloadFromURLToFile(url, filename, show_progress=True):
print("") # New line to fix for progress bar
except HTTPError as e:
raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
.format(code=e.code, reason=e.reason))
.format(code=e.code, reason=e.reason)) from e
except URLError as e:
raise Exception("Could not download model. [URL Error] {reason}."
.format(reason=e.reason))
.format(reason=e.reason)) from e
def getURLFromName(name, filename):

View File

@ -150,9 +150,9 @@ class RoIAlignRotatedOp(hu.HypothesisTestCase):
indexer = [slice(None)] * m.ndim
try:
indexer[axis] = slice(None, None, -1)
except IndexError:
except IndexError as e:
raise ValueError("axis=%i is invalid for the %i-dimensional input array"
% (axis, m.ndim))
% (axis, m.ndim)) from e
return m[tuple(indexer)]
def roialign_ref(X, R):

View File

@ -13,8 +13,8 @@ from caffe2.python import model_helper, workspace
try:
import lmdb
except ImportError:
raise unittest.SkipTest("python-lmdb is not installed")
except ImportError as e:
raise unittest.SkipTest("python-lmdb is not installed") from e
class VideoInputOpTest(unittest.TestCase):

View File

@ -546,8 +546,8 @@ class Struct(Field):
raise AttributeError(item)
try:
return super(Struct, self).__getattribute__("fields")[item]
except KeyError:
raise AttributeError(item)
except KeyError as e:
raise AttributeError(item) from e
def __setattr__(self, key, value):
# Disable setting attributes after initialization to prevent false

View File

@ -29,8 +29,8 @@ def _get_output_shapes(output_value_infos):
def check_gpu_():
try:
C.get_cuda_version()
except Exception as _:
raise Exception("TensorRT related functions require CUDA support")
except Exception as e:
raise Exception("TensorRT related functions require CUDA support") from e
def convert_onnx_model_to_trt_op(onnx_model,
max_batch_size=64,

View File

@ -446,8 +446,8 @@ Please install it via `conda install {module}` or `pip install {module}`
def check_pydep(importname, module):
try:
importlib.import_module(importname)
except ImportError:
raise RuntimeError(missing_pydep.format(importname=importname, module=module))
except ImportError as e:
raise RuntimeError(missing_pydep.format(importname=importname, module=module)) from e
class build_ext(setuptools.command.build_ext.build_ext):

View File

@ -611,8 +611,8 @@ class TestFSDPStateDict(FSDPTest):
def _state_dict(model: Module, state_dict_type: str):
try:
enum_val = STATE_DICT_MAPPING[state_dict_type]
except KeyError:
raise ValueError(f"No state_dict type for {state_dict_type}")
except KeyError as e:
raise ValueError(f"No state_dict type for {state_dict_type}") from e
with FSDP.state_dict_type(model, enum_val):
return model.state_dict()
@ -623,8 +623,8 @@ class TestFSDPStateDict(FSDPTest):
):
try:
enum_val = STATE_DICT_MAPPING[state_dict_type]
except KeyError:
raise ValueError(f"No state_dict for {state_dict_type}")
except KeyError as e:
raise ValueError(f"No state_dict for {state_dict_type}") from e
with FSDP.state_dict_type(model, enum_val):
return model.load_state_dict(state_dict, strict=True)

View File

@ -2598,7 +2598,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
try:
pg_gloo.barrier().wait()
except Exception as e:
raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}") from e
# Now verify communicators on this rank have
# been aborted by watchdog.
self._wait_for_comm_abort(process_group, failed_collective_timeout)

View File

@ -101,8 +101,8 @@ def get_hf_bert(rank):
# in a multiprocessing test
try:
from transformers import BertConfig, AutoModelForMaskedLM
except ImportError:
raise unittest.SkipTest("Unable to import transformers")
except ImportError as e:
raise unittest.SkipTest("Unable to import transformers") from e
batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
model = AutoModelForMaskedLM.from_config(config).to(device)

View File

@ -1856,8 +1856,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
def __getattr__(self, item: str):
try:
return self.data[item]
except KeyError:
raise AttributeError
except KeyError as e:
raise AttributeError from e
def tokenization(x):
encoding = BatchEncoding({"key": x})

View File

@ -65,7 +65,7 @@ except (ImportError, AssertionError) as e:
sys.stderr.write(f"{type(e)}: {e}\n")
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
raise unittest.SkipTest("requires sympy/functorch/filelock") from e
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

View File

@ -128,9 +128,9 @@ class TestDtypeBase(JitTestCase):
inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
try:
self.assert_dtype_equal_custom_args(fn, inputs)
except Exception:
except Exception as e:
fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
raise AssertionError(fail_text)
raise AssertionError(fail_text) from e
def assert_dtype_equal_custom_args(self, fn, args):
try:

View File

@ -141,7 +141,7 @@ def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
raise e # reraise the exception
exception_message = str(e)
if not re.search(exception_msg_pattern, exception_message):
raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}")
raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}") from e
else:
# We are done for the test case that expects an exception
return

View File

@ -110,8 +110,8 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
try:
from apex import amp
except Exception:
raise unittest.SkipTest("Apex is not available")
except Exception as e:
raise unittest.SkipTest("Apex is not available") from e
input = torch.randn(3, 3, device=torch.device("cuda"))
model = amp.initialize(LinearModel(), opt_level="O2")
self.run_test(model, input)

View File

@ -1703,7 +1703,7 @@ class TestAutograd(TestCase):
self.assertTrue(torch.is_grad_enabled())
yield (-i if has_raised else i)
except UnrecoverableException:
except UnrecoverableException :
self.assertTrue(torch.is_grad_enabled())
raise SecondaryException

View File

@ -3805,7 +3805,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
f"unintended, please revert it. If it was intended, check with the FX " \
f"team to ensure that the proper deprecation protocols have been followed " \
f"and subsequently --accept the change."
raise AssertionError(msg)
raise AssertionError(msg) from e
def test_public_api_surface(self):
non_back_compat_objects = {}

View File

@ -597,7 +597,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_minmax_int_ops(self):
def apply(fn):
@ -627,7 +627,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_comparison_eq_ne(self):
for device in self.devices:
@ -1288,7 +1288,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])
)
) from e
def test_isnan(self):
x = torch.rand([4])
@ -1321,7 +1321,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), 'isnan', device])
)
) from e
def test_gelu(self):
def apply(fn):
@ -1352,7 +1352,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
)
) from e
def test_unary_ops(self):
with torch._jit_internal._disable_emit_hooks():
@ -1435,7 +1435,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
)
) from e
def test_binary_ops(self):
def apply(fn):
@ -1488,7 +1488,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_binary_scalar_ops(self):
def apply(fn):
@ -1534,7 +1534,7 @@ class TestTEFuser(JitTestCase):
try:
k = torch._C._te.TensorExprKernel(graph)
except Exception as e:
raise RuntimeError(" ".join(["Compilation failed:", device, str(code)]))
raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e
# Run the graph
for x, y in product(values[dtype_x], values[dtype_y]):
@ -1543,7 +1543,7 @@ class TestTEFuser(JitTestCase):
res = k.run((x, y))
self.assertEqual(ref, res)
except Exception as e:
raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)]))
raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e
def test_matmul(self):
if self.dynamic_shapes:
@ -1599,7 +1599,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), device])
)
) from e
def test_binary_tensor_scalar_ops(self):
with torch._jit_internal._disable_emit_hooks():
@ -1643,7 +1643,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
@ -1676,7 +1676,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
"Failed: {} {} {} {}".format(dtype, op.__name__, device, scalar)
)
) from e
def test_binary_pow(self):
def apply_with_scalar(fn, scalar):
@ -1714,7 +1714,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_ternary_ops(self):
def apply(fn):
@ -1746,7 +1746,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_ternary_norm_ops(self):
def apply(fn):
@ -1777,7 +1777,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
@unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure")
@ -1810,7 +1810,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_where_ops(self):
def apply(fn):
@ -1843,7 +1843,7 @@ class TestTEFuser(JitTestCase):
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device])
)
) from e
def test_unsupported_dtypes(self):
for device in self.devices:

View File

@ -467,7 +467,7 @@ def run_meta_crossref(
# they're not tested outside of gradcheck which only checks
# torch.float64 and torch.complex128 (which this second one
# often skipped as well).
raise unittest.SkipTest("Original OpInfo is broken")
raise unittest.SkipTest("Original OpInfo is broken") from e
# TODO: also handle cases where func raise an exception

View File

@ -28,13 +28,13 @@ class UniqueKeyLoader(Loader):
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
try:
hash(key)
except TypeError:
except TypeError as e:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key ",
key_node.start_mark,
)
) from e
# check for duplicate keys
if key in mapping:
raise ConstructorError(

View File

@ -12,10 +12,10 @@ try:
TestCase,
TestSuite,
)
except ImportError:
except ImportError as e:
raise ImportError(
"junitparser not found, please install with 'pip install junitparser'"
)
) from e
try:
import rich

View File

@ -364,7 +364,7 @@ def helper_for_dump_minify(contents):
fd.write(contents)
except OSError as e:
log.exception(e)
raise NotImplementedError("Could not write to {minified_repro_path}")
raise NotImplementedError("Could not write to {minified_repro_path}") from e
def dump_to_minify(gm, args, compiler_name: str):

View File

@ -283,7 +283,7 @@ def break_graph_if_unsupported(*, push):
if self.has_backedge():
msg = "Skipping frame because there is a graph break in a for/while loop"
log.debug(msg)
raise exc.SkipFrame(msg)
raise exc.SkipFrame(msg) from excp
if not self.should_compile_partial_graph():
raise

View File

@ -348,13 +348,13 @@ def proxy_args_kwargs(args, kwargs):
proxy_args = tuple(arg.as_proxy() for arg in args)
proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
return proxy_args, proxy_kwargs
except NotImplementedError:
except NotImplementedError as e:
from .exc import unimplemented
from .variables.base import typestr
raise unimplemented(
f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
)
) from e
@dataclasses.dataclass
@ -745,7 +745,7 @@ def wrap_fake_exception(fn):
msg = f"Unsupported: {e.reason} with fake tensor propagation."
log.warning(msg)
raise unimplemented(msg)
raise unimplemented(msg) from e
def wrap_to_fake_tensor(e, fake_mode):

View File

@ -56,8 +56,8 @@ class ConstantVariable(VariableTracker):
try:
options = VariableTracker.propagate([self])
return [ConstantVariable(x, **options) for x in self.as_python_constant()]
except TypeError:
raise NotImplementedError()
except TypeError as e:
raise NotImplementedError from e
def const_getattr(self, tx, name):
member = getattr(self.value, name)

View File

@ -504,8 +504,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
try:
fn = inspect.getattr_static(self.value_type, "__iter__")
except AttributeError:
raise NotImplementedError()
except AttributeError as e:
raise NotImplementedError from e
if fn in (
torch.nn.ModuleList.__iter__,

View File

@ -277,8 +277,9 @@ def min_cut_rematerialization_partition(
"""
try:
import networkx as nx
except ImportError:
raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
except ImportError as e:
raise RuntimeError("Need networkx installed to perform smart recomputation "
"heuristics") from e
joint_module.graph.eliminate_dead_code()
joint_module.recompile()

View File

@ -435,7 +435,7 @@ class CppCodeCache:
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
raise exc.CppCompileError(cmd, e.output)
raise exc.CppCompileError(cmd, e.output) from e
cls.cache[key] = cls._load_library(output_path)
cls.cache[key].key = key

View File

@ -184,7 +184,7 @@ class CachingAutotuner(KernelInterface):
raise RuntimeError(
"""Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
)
) from e
else:
raise e

View File

@ -136,5 +136,5 @@ class CrossRefFakeMode(TorchDispatchMode):
r_out, fake_out, check_strides=self.check_strides
)
except Exception as e:
raise RuntimeError(f"Mismatch on {func}: {e}")
raise RuntimeError(f"Mismatch on {func}: {e}") from e
return r

View File

@ -819,7 +819,7 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool:
except RuntimeError as ex:
# Rethrow to provide a better error message
raise GradcheckError(
f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}')
f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') from ex
for input_idx, (res, exp) in enumerate(zip(result, expected)):
if torch.allclose(res, exp):
@ -861,7 +861,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
# autograd.grad instead of the C++ traceback of what line in the
# backward formula
raise GradcheckError(
f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}')
f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') from ex
for input_idx, (res, exp) in enumerate(zip(result, expected)):
if torch.allclose(res, exp):
@ -977,12 +977,12 @@ def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
try:
grads_input = torch.autograd.grad(output_to_check, diff_input_list,
grads_output, allow_unused=True)
except RuntimeError:
except RuntimeError as e:
warn_bc_breaking()
raise GradcheckError((
'Expected backward function to handle undefined output grads. '
'Please look at "Notes about undefined output gradients" in '
'"tools/autograd/derivatives.yaml"'))
'"tools/autograd/derivatives.yaml"')) from e
for gi, i in zip(grads_input, diff_input_list):
if (gi is not None) and (not gi.eq(0).all()):

View File

@ -669,13 +669,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
"""
try:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
raise ModuleNotFoundError("pynvml module not found, please install pynvml")
except ModuleNotFoundError as e:
raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
from pynvml import NVMLError_DriverNotLoaded
try:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
except NVMLError_DriverNotLoaded as e:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
device = _get_device_index(device, optional=True)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
@ -695,13 +695,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
"""
try:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
raise ModuleNotFoundError("pynvml module not found, please install pynvml")
except ModuleNotFoundError as e:
raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
from pynvml import NVMLError_DriverNotLoaded
try:
pynvml.nvmlInit()
except NVMLError_DriverNotLoaded:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
except NVMLError_DriverNotLoaded as e:
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
device = _get_device_index(device, optional=True)
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu

View File

@ -1640,10 +1640,10 @@ class DistributedDataParallel(Module, Joinable):
)
try:
overlapped_optim.register_ddp(self)
except NotImplementedError:
except NotImplementedError as e:
raise RuntimeError(
f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
)
) from e
def _distributed_broadcast_coalesced(
self, tensors, buffer_size, authoritative_rank=0

View File

@ -195,11 +195,11 @@ class RendezvousParameters:
return value
try:
return int(value)
except ValueError:
except ValueError as e:
raise ValueError(
f"The rendezvous configuration option '{key}' does not represent a valid integer "
"value."
)
) from e
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
@ -244,11 +244,11 @@ class RendezvousHandlerRegistry:
"""Creates a new :py:class:`RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError:
except KeyError as e:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call `{self.register.__name__}`?"
)
) from e
handler = creator(params)

View File

@ -464,10 +464,10 @@ class EtcdRendezvous(object):
version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
version_counter.value = str(int(version_counter.value) + 1)
self.client.update(version_counter)
except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed):
except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
raise RendezvousError(
"Unexpected state of EtcdRendezvousHandler, worker needs to die."
)
) from e
# Any failure below results in declaring a retryable rendezvous failure.
# The ephemeral /rdzv/active_version will expire and someone can then

View File

@ -381,8 +381,8 @@ def _get_ignored_modules(
msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
try:
ignored_root_modules = set(_ignored_modules)
except TypeError:
raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}")
except TypeError as e:
raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
for module in ignored_root_modules:
if not isinstance(module, torch.nn.Module):
raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")

View File

@ -1015,11 +1015,11 @@ def _get_param_id_to_param_from_optim_input(
return list(model.parameters())
try:
params = list(optim_input)
except TypeError:
except TypeError as e:
raise TypeError(
"Optimizer input should be an iterable of Tensors or dicts, "
f"but got {optim_input}"
)
) from e
if len(params) == 0:
raise ValueError("Optimizer input should not be empty")

View File

@ -44,8 +44,9 @@ def register_functional_optim(key, optim):
def as_functional_optim(optim_cls: Type, *args, **kwargs):
try:
functional_cls = functional_optim_map[optim_cls]
except KeyError:
raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
except KeyError as e:
raise ValueError(f"Optimizer {optim_cls} does not have a functional "
f"counterpart!") from e
return _create_functional_optim(functional_cls, *args, **kwargs)

View File

@ -1323,9 +1323,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
f"Tensors, but got {torch.typename(params)}")
try:
all_params = list(params)
except TypeError:
except TypeError as e:
raise TypeError("`params` argument should be an iterable of Tensors"
f" or dicts, but got {torch.typename(params)}")
f" or dicts, but got {torch.typename(params)}") from e
if len(all_params) == 0:
raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
"list")

View File

@ -198,8 +198,8 @@ class Skippable(nn.Module):
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
except KeyError:
raise RuntimeError(f"'{name}' has not been stashed")
except KeyError as e:
raise RuntimeError(f"'{name}' has not been stashed") from e
input = batch.values
# Handle skip commands.

View File

@ -1,9 +1,9 @@
try:
from urllib.parse import urlparse, urlunparse
except ImportError:
except ImportError as e:
raise ImportError(
"urllib cannot be found, urlparse from python2 is no longer supported."
)
) from e
import numbers
import os

View File

@ -228,7 +228,7 @@ def _handle_exception(result):
raise RuntimeError( # noqa: B904
f"Failed to create original exception type. Error msg was {str(e)}"
f" Original exception on remote side was {exception_msg}"
)
) from e
if exc is not None:
raise exc

View File

@ -605,13 +605,13 @@ def determine_local_world_size(nproc_per_node: str):
try:
logging.info(f"Using nproc_per_node={nproc_per_node}.")
return int(nproc_per_node)
except ValueError:
except ValueError as e:
if nproc_per_node == "cpu":
num_proc = os.cpu_count()
device_type = "cpu"
elif nproc_per_node == "gpu":
if not torch.cuda.is_available():
raise ValueError("Cuda is not available.")
raise ValueError("Cuda is not available.") from e
device_type = "gpu"
num_proc = torch.cuda.device_count()
elif nproc_per_node == "auto":
@ -622,7 +622,7 @@ def determine_local_world_size(nproc_per_node: str):
num_proc = os.cpu_count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
log.info(
f"Using nproc_per_node={nproc_per_node},"

View File

@ -20,13 +20,13 @@ def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedD
# Parse the function declaration
try:
py_ast = ast.parse(decl)
except SyntaxError:
except SyntaxError as e:
# This should only happen if there's some unforeseeable change
# in the dataclasses module that makes our synthesized code fail
raise RuntimeError(
f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
"Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
)
) from e
fake_filename = _get_fake_filename(cls, name)
# Parse the function
return ParsedDef(

View File

@ -185,7 +185,7 @@ def infer_concrete_type_builder(nn_module, share_types=True):
except RuntimeError as re:
raise RuntimeError(
"Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
)
) from re
return attr_type, inferred

View File

@ -292,8 +292,8 @@ def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None,
if not os.path.isdir(dir_name):
try:
os.makedirs(dir_name, exist_ok=True)
except Exception:
raise RuntimeError("Can't create directory: " + dir_name)
except Exception as e:
raise RuntimeError("Can't create directory: " + dir_name) from e
if not worker_name:
worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))
file_name = "{}.{}.pt.trace.json".format(worker_name, int(time.time() * 1000))

View File

@ -677,8 +677,8 @@ class TensorLikePair(Pair):
try:
return torch.as_tensor(tensor_like)
except Exception:
raise UnsupportedInputs()
except Exception as e:
raise UnsupportedInputs() from e
def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
if tensor.layout not in {

View File

@ -949,7 +949,7 @@ class FSDPTest(MultiProcessTestCase):
**init_kwargs,
)
except Exception as e:
raise ValueError(f"Initializing {model_class} raised error {str(e)}")
raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
if not isinstance(fsdp_model, FSDP):
# Enforce that we wrap with top-level FSDP since we are comparing
# assuming a data parallel reference and some test models may not

View File

@ -3260,7 +3260,7 @@ def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)):
if any(connect_error in str(error) for connect_error in connect_errors):
tries_remaining -= 1
if tries_remaining == 0:
raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}")
raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
time.sleep(random.random())
continue
raise
@ -4001,8 +4001,8 @@ def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
"""
try:
return next(iter(samples))
except StopIteration:
raise unittest.SkipTest('Skipped! Need at least 1 sample input')
except StopIteration as e:
raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
# this helper method is to recursively
# clone the tensor-type input of operators tested by OpInfo

View File

@ -290,8 +290,8 @@ class DTensorConverter(object):
tree_unflatten(new_args, self.flatten_args_spec),
tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
)
except StopIteration:
raise StopIteration
except StopIteration as e:
raise StopIteration from e
def to_dist_tensor(
self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]

View File

@ -349,10 +349,10 @@ class _DataPipeSerializationWrapper:
def __len__(self):
try:
return len(self._datapipe)
except Exception:
except Exception as e:
raise TypeError(
"{} instance doesn't have valid length".format(type(self).__name__)
)
) from e
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):

View File

@ -154,8 +154,8 @@ def _collate_helper(conversion, item):
try:
import torcharrow.pytorch as tap # type: ignore[import]
collation_fn = tap.rec.Default()
except Exception:
raise Exception("unable to import default collation function from the TorchArrow")
except Exception as e:
raise Exception("unable to import default collation function from the TorchArrow") from e
tuple_names.append(str(name))
value = collation_fn(df[name])

View File

@ -93,8 +93,8 @@ class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
for dp in self.datapipes:
try:
res.append(dp[index])
except IndexError:
raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.")
except IndexError as e:
raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.") from e
return tuple(res)
def __len__(self) -> int:

View File

@ -53,11 +53,11 @@ class BatcherMapDataPipe(MapDataPipe[DataChunk]):
for i in indices:
batch.append(self.datapipe[i])
return self.wrapper_class(batch)
except IndexError:
except IndexError as e:
if not self.drop_last and len(batch) > 0:
return self.wrapper_class(batch)
else:
raise IndexError(f"Index {index} is out of bound.")
raise IndexError(f"Index {index} is out of bound.") from e
def __len__(self) -> int:
if self.length is not None:

View File

@ -148,13 +148,13 @@ class ImageHandler:
import numpy as np
except ImportError as e:
raise ModuleNotFoundError("Package `numpy` is required to be installed for default image decoder."
"Please use `pip install numpy` to install the package")
"Please use `pip install numpy` to install the package") from e
try:
import PIL.Image
except ImportError as e:
raise ModuleNotFoundError("Package `PIL` is required to be installed for default image decoder."
"Please use `pip install Pillow` to install the package")
"Please use `pip install Pillow` to install the package") from e
imagespec = self.imagespec
atype, etype, mode = imagespecs[imagespec]
@ -200,7 +200,7 @@ def videohandler(extension, data):
except ImportError as e:
raise ModuleNotFoundError("Package `torchvision` is required to be installed for default video file loader."
"Please use `pip install torchvision` or `conda install torchvision -c pytorch`"
"to install the package")
"to install the package") from e
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
@ -221,7 +221,7 @@ def audiohandler(extension, data):
except ImportError as e:
raise ModuleNotFoundError("Package `torchaudio` is required to be installed for default audio file loader."
"Please use `pip install torchaudio` or `conda install torchaudio -c pytorch`"
"to install the package")
"to install the package") from e
with tempfile.TemporaryDirectory() as dirname:
fname = os.path.join(dirname, f"file.{extension}")
@ -240,7 +240,7 @@ class MatHandler:
except ImportError as e:
raise ModuleNotFoundError("Package `scipy` is required to be installed for mat file."
"Please use `pip install scipy` or `conda install scipy`"
"to install the package")
"to install the package") from e
self.sio = sio
self.loadmat_kwargs = loadmat_kwargs

View File

@ -47,9 +47,9 @@ def _simple_graph_snapshot_restoration(datapipe: IterDataPipe, n_iterations: int
try:
next(it)
remainder -= 1
except StopIteration:
except StopIteration as e:
raise RuntimeError(f"Fast-forward {datapipe} by {n_iterations} iterations "
"exceeds the number of samples available.")
"exceeds the number of samples available.") from e
datapipe._fast_forward_iterator = it
# While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.

View File

@ -262,10 +262,10 @@ def error_on_missing_kernels(
try:
with open(kernel_defn_file_path, "r") as f:
backend_defns = f.read()
except IOError:
except IOError as e:
raise AssertionError(
f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
)
) from e
if full_codegen is None:
full_codegen = []

View File

@ -137,10 +137,10 @@ def validate_shape_inference_header(
with open(shape_inference_hdr, "r") as f:
shape_infr_decls = f.read()
shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
except IOError:
except IOError as e:
raise AssertionError(
f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
)
) from e
shape_infr_regex = r"compute_shape_(\w+)"
actual_shape_infr_name_counts = Counter(

View File

@ -1719,8 +1719,8 @@ class Type:
return CustomClassType(m.group(1))
try:
return BaseType(BaseTy[t])
except KeyError:
raise RuntimeError(f"unrecognized type {t}")
except KeyError as e:
raise RuntimeError(f"unrecognized type {t}") from e
def __str__(self) -> str:
raise NotImplementedError