Fix exception causes all over the codebase (#90271)

This is the continuation of #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980

parent 8f079b895b
commit 351d73b97f
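The entire commit applies one pattern, PEP 3134 exception chaining: when an `except` block raises a different exception, an explicit `raise ... from e` records the original exception as `__cause__`, so the traceback reads "The above exception was the direct cause of the following exception" instead of the misleading implicit-context message; this is also what flake8-bugbear's B904 rule (visible in a `# noqa: B904` below) enforces. A minimal, self-contained sketch of the behavior; the function names and messages are illustrative, not taken from the diff:

    # Illustration of the `raise ... from e` pattern this commit applies.
    # `parse_port`/`load_port` and their messages are hypothetical examples.
    def parse_port(raw: str) -> int:
        return int(raw)  # raises ValueError on non-numeric input

    def load_port(raw: str) -> int:
        try:
            return parse_port(raw)
        except ValueError as e:
            # Without `from e`, Python only sets __context__ and prints
            # "During handling of the above exception, another exception
            # occurred", which reads like a bug in the handler itself.
            raise RuntimeError(f"invalid port setting: {raw!r}") from e

    try:
        load_port("eighty")
    except RuntimeError as err:
        assert isinstance(err.__cause__, ValueError)  # original error preserved

A bare re-raise inside `except` still records the original as `__context__`, so no information is lost either way; `from e` marks the chaining as deliberate and silences the lint.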
@@ -1001,8 +1001,8 @@ class BenchmarkRunner:
         try:
             self.model_iter_fn(model, example_inputs)
-        except Exception:
-            raise NotImplementedError("Eager model failed to run")
+        except Exception as e:
+            raise NotImplementedError("Eager model failed to run") from e

     def maybe_cast(self, model, example_inputs):
         model = copy.deepcopy(model)
@@ -330,8 +330,9 @@ class TransformerModel(nn.Module):
         super(TransformerModel, self).__init__()
         try:
             from torch.nn import TransformerEncoder, TransformerEncoderLayer
-        except Exception:
-            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
+        except Exception as e:
+            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
+                              'lower.') from e
         self.model_type = 'Transformer'
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
@@ -210,9 +210,9 @@ class TranslatorRegistry(object):
         try:
             caffe_ops, params = cls.registry_[layer.type](
                 layer, pretrained_blobs, is_test, **kwargs)
-        except KeyError:
+        except KeyError as e:
             raise KeyError('No translator registered for layer: %s yet.' %
-                           str(layer))
+                           str(layer)) from e
         if caffe_ops is None:
             caffe_ops = []
         if type(caffe_ops) is not list:
@@ -970,7 +970,7 @@ StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
                         input_name,
                         err
                     )
-                )
+                ) from err

         # Finally, let's create the sum operator.
         sum_ops, g = self._MakeSumOps(input_name, input_version)
@@ -1175,7 +1175,7 @@ class GradientRegistry(object):
             raise Exception(
                 "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                 format(op.type, e, str(op))
-            )
+            ) from e

         if gradient_ops is None:
             return [], g_input
@@ -540,8 +540,8 @@ def ExtractPredictorNet(
                 'StopGradient'
             ]
         )
-    except ValueError:
-        raise Exception("No ops with input={}".format(input_blobs))
+    except ValueError as e:
+        raise Exception("No ops with input={}".format(input_blobs)) from e
     try:
         last_op_with_output = max(
             [
@@ -549,8 +549,8 @@ def ExtractPredictorNet(
                 if output_blobs.intersection(ops[j].output)
             ]
         )
-    except ValueError:
-        raise Exception("No ops with output={}".format(output_blobs))
+    except ValueError as e:
+        raise Exception("No ops with output={}".format(output_blobs)) from e

     def validate_op(op):
         # Check that the op does not have is_test = 0 set. This is a common
@@ -69,10 +69,10 @@ def downloadFromURLToFile(url, filename, show_progress=True):
             print("")  # New line to fix for progress bar
     except HTTPError as e:
         raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
-                        .format(code=e.code, reason=e.reason))
+                        .format(code=e.code, reason=e.reason)) from e
     except URLError as e:
         raise Exception("Could not download model. [URL Error] {reason}."
-                        .format(reason=e.reason))
+                        .format(reason=e.reason)) from e


 def getURLFromName(name, filename):
@@ -150,9 +150,9 @@ class RoIAlignRotatedOp(hu.HypothesisTestCase):
             indexer = [slice(None)] * m.ndim
             try:
                 indexer[axis] = slice(None, None, -1)
-            except IndexError:
+            except IndexError as e:
                 raise ValueError("axis=%i is invalid for the %i-dimensional input array"
-                                 % (axis, m.ndim))
+                                 % (axis, m.ndim)) from e
             return m[tuple(indexer)]

         def roialign_ref(X, R):
@@ -13,8 +13,8 @@ from caffe2.python import model_helper, workspace

 try:
     import lmdb
-except ImportError:
-    raise unittest.SkipTest("python-lmdb is not installed")
+except ImportError as e:
+    raise unittest.SkipTest("python-lmdb is not installed") from e


 class VideoInputOpTest(unittest.TestCase):
@@ -546,8 +546,8 @@ class Struct(Field):
             raise AttributeError(item)
         try:
             return super(Struct, self).__getattribute__("fields")[item]
-        except KeyError:
-            raise AttributeError(item)
+        except KeyError as e:
+            raise AttributeError(item) from e

     def __setattr__(self, key, value):
         # Disable setting attributes after initialization to prevent false
@@ -29,8 +29,8 @@ def _get_output_shapes(output_value_infos):
 def check_gpu_():
     try:
         C.get_cuda_version()
-    except Exception as _:
-        raise Exception("TensorRT related functions require CUDA support")
+    except Exception as e:
+        raise Exception("TensorRT related functions require CUDA support") from e

 def convert_onnx_model_to_trt_op(onnx_model,
                                  max_batch_size=64,
setup.py
@@ -446,8 +446,8 @@ Please install it via `conda install {module}` or `pip install {module}`
 def check_pydep(importname, module):
     try:
         importlib.import_module(importname)
-    except ImportError:
-        raise RuntimeError(missing_pydep.format(importname=importname, module=module))
+    except ImportError as e:
+        raise RuntimeError(missing_pydep.format(importname=importname, module=module)) from e


 class build_ext(setuptools.command.build_ext.build_ext):
@@ -611,8 +611,8 @@ class TestFSDPStateDict(FSDPTest):
         def _state_dict(model: Module, state_dict_type: str):
             try:
                 enum_val = STATE_DICT_MAPPING[state_dict_type]
-            except KeyError:
-                raise ValueError(f"No state_dict type for {state_dict_type}")
+            except KeyError as e:
+                raise ValueError(f"No state_dict type for {state_dict_type}") from e

             with FSDP.state_dict_type(model, enum_val):
                 return model.state_dict()
@@ -623,8 +623,8 @@ class TestFSDPStateDict(FSDPTest):
         ):
             try:
                 enum_val = STATE_DICT_MAPPING[state_dict_type]
-            except KeyError:
-                raise ValueError(f"No state_dict for {state_dict_type}")
+            except KeyError as e:
+                raise ValueError(f"No state_dict for {state_dict_type}") from e

             with FSDP.state_dict_type(model, enum_val):
                 return model.load_state_dict(state_dict, strict=True)
@@ -2598,7 +2598,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             try:
                 pg_gloo.barrier().wait()
             except Exception as e:
-                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
+                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}") from e
             # Now verify communicators on this rank have
             # been aborted by watchdog.
             self._wait_for_comm_abort(process_group, failed_collective_timeout)
@@ -101,8 +101,8 @@ def get_hf_bert(rank):
     # in a multiprocessing test
     try:
         from transformers import BertConfig, AutoModelForMaskedLM
-    except ImportError:
-        raise unittest.SkipTest("Unable to import transformers")
+    except ImportError as e:
+        raise unittest.SkipTest("Unable to import transformers") from e

     batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
     model = AutoModelForMaskedLM.from_config(config).to(device)
@@ -1856,8 +1856,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             def __getattr__(self, item: str):
                 try:
                     return self.data[item]
-                except KeyError:
-                    raise AttributeError
+                except KeyError as e:
+                    raise AttributeError from e

         def tokenization(x):
             encoding = BatchEncoding({"key": x})
@@ -65,7 +65,7 @@ except (ImportError, AssertionError) as e:
     sys.stderr.write(f"{type(e)}: {e}\n")
     if __name__ == "__main__":
         sys.exit(0)
-    raise unittest.SkipTest("requires sympy/functorch/filelock")
+    raise unittest.SkipTest("requires sympy/functorch/filelock") from e

 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

@@ -128,9 +128,9 @@ class TestDtypeBase(JitTestCase):
         inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
         try:
             self.assert_dtype_equal_custom_args(fn, inputs)
-        except Exception:
+        except Exception as e:
             fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
-            raise AssertionError(fail_text)
+            raise AssertionError(fail_text) from e

     def assert_dtype_equal_custom_args(self, fn, args):
         try:
@@ -141,7 +141,7 @@ def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
                 raise e  # reraise the exception
             exception_message = str(e)
             if not re.search(exception_msg_pattern, exception_message):
-                raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}")
+                raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}") from e
             else:
                 # We are done for the test case that expects an exception
                 return
@@ -110,8 +110,8 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):

         try:
             from apex import amp
-        except Exception:
-            raise unittest.SkipTest("Apex is not available")
+        except Exception as e:
+            raise unittest.SkipTest("Apex is not available") from e
         input = torch.randn(3, 3, device=torch.device("cuda"))
         model = amp.initialize(LinearModel(), opt_level="O2")
         self.run_test(model, input)
@@ -1703,7 +1703,7 @@ class TestAutograd(TestCase):
                     self.assertTrue(torch.is_grad_enabled())
                     yield (-i if has_raised else i)

-                except UnrecoverableException:
+                except UnrecoverableException as e:
                     self.assertTrue(torch.is_grad_enabled())
-                    raise SecondaryException
+                    raise SecondaryException from e
@@ -3805,7 +3805,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
                 f"unintended, please revert it. If it was intended, check with the FX " \
                 f"team to ensure that the proper deprecation protocols have been followed " \
                 f"and subsequently --accept the change."
-            raise AssertionError(msg)
+            raise AssertionError(msg) from e

     def test_public_api_surface(self):
         non_back_compat_objects = {}
@@ -597,7 +597,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_minmax_int_ops(self):
         def apply(fn):
@@ -627,7 +627,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_comparison_eq_ne(self):
         for device in self.devices:
@@ -1288,7 +1288,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)])
-                    )
+                    ) from e

     def test_isnan(self):
         x = torch.rand([4])
@@ -1321,7 +1321,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), 'isnan', device])
-                    )
+                    ) from e

     def test_gelu(self):
         def apply(fn):
@@ -1352,7 +1352,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
-                    )
+                    ) from e

     def test_unary_ops(self):
         with torch._jit_internal._disable_emit_hooks():
@@ -1435,7 +1435,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
-                    )
+                    ) from e

     def test_binary_ops(self):
         def apply(fn):
@@ -1488,7 +1488,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_binary_scalar_ops(self):
         def apply(fn):
@@ -1534,7 +1534,7 @@ class TestTEFuser(JitTestCase):
             try:
                 k = torch._C._te.TensorExprKernel(graph)
             except Exception as e:
-                raise RuntimeError(" ".join(["Compilation failed:", device, str(code)]))
+                raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e

             # Run the graph
             for x, y in product(values[dtype_x], values[dtype_y]):
@@ -1543,7 +1543,7 @@ class TestTEFuser(JitTestCase):
                     res = k.run((x, y))
                     self.assertEqual(ref, res)
                 except Exception as e:
-                    raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)]))
+                    raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e

     def test_matmul(self):
         if self.dynamic_shapes:
@@ -1599,7 +1599,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), device])
-                    )
+                    ) from e

     def test_binary_tensor_scalar_ops(self):
         with torch._jit_internal._disable_emit_hooks():
@@ -1643,7 +1643,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_binary_div_ops(self):
         def apply_with_scalar(fn, scalar):
@@ -1676,7 +1676,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         "Failed: {} {} {} {}".format(dtype, op.__name__, device, scalar)
-                    )
+                    ) from e

     def test_binary_pow(self):
         def apply_with_scalar(fn, scalar):
@@ -1714,7 +1714,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_ternary_ops(self):
         def apply(fn):
@@ -1746,7 +1746,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_ternary_norm_ops(self):
         def apply(fn):
@@ -1777,7 +1777,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e


     @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure")
@@ -1810,7 +1810,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_where_ops(self):
         def apply(fn):
@@ -1843,7 +1843,7 @@ class TestTEFuser(JitTestCase):
                 except Exception as e:
                     raise RuntimeError(
                         " ".join(["Failed:", str(dtype), op.__name__, device])
-                    )
+                    ) from e

     def test_unsupported_dtypes(self):
         for device in self.devices:
@@ -467,7 +467,7 @@ def run_meta_crossref(
             # they're not tested outside of gradcheck which only checks
             # torch.float64 and torch.complex128 (which this second one
             # often skipped as well).
-            raise unittest.SkipTest("Original OpInfo is broken")
+            raise unittest.SkipTest("Original OpInfo is broken") from e


 # TODO: also handle cases where func raise an exception
@@ -28,13 +28,13 @@ class UniqueKeyLoader(Loader):
             key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
             try:
                 hash(key)
-            except TypeError:
+            except TypeError as e:
                 raise ConstructorError(
                     "while constructing a mapping",
                     node.start_mark,
                     "found unacceptable key ",
                     key_node.start_mark,
-                )
+                ) from e
             # check for duplicate keys
             if key in mapping:
                 raise ConstructorError(
@@ -12,10 +12,10 @@ try:
         TestCase,
         TestSuite,
     )
-except ImportError:
+except ImportError as e:
     raise ImportError(
         "junitparser not found, please install with 'pip install junitparser'"
-    )
+    ) from e

 try:
     import rich
@@ -364,7 +364,7 @@ def helper_for_dump_minify(contents):
             fd.write(contents)
     except OSError as e:
         log.exception(e)
-        raise NotImplementedError("Could not write to {minified_repro_path}")
+        raise NotImplementedError("Could not write to {minified_repro_path}") from e


 def dump_to_minify(gm, args, compiler_name: str):
@@ -283,7 +283,7 @@ def break_graph_if_unsupported(*, push):
             if self.has_backedge():
                 msg = "Skipping frame because there is a graph break in a for/while loop"
                 log.debug(msg)
-                raise exc.SkipFrame(msg)
+                raise exc.SkipFrame(msg) from excp

             if not self.should_compile_partial_graph():
                 raise
@@ -348,13 +348,13 @@ def proxy_args_kwargs(args, kwargs):
         proxy_args = tuple(arg.as_proxy() for arg in args)
         proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
         return proxy_args, proxy_kwargs
-    except NotImplementedError:
+    except NotImplementedError as e:
         from .exc import unimplemented
         from .variables.base import typestr

         raise unimplemented(
             f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
-        )
+        ) from e


 @dataclasses.dataclass
@@ -745,7 +745,7 @@ def wrap_fake_exception(fn):

         msg = f"Unsupported: {e.reason} with fake tensor propagation."
         log.warning(msg)
-        raise unimplemented(msg)
+        raise unimplemented(msg) from e


 def wrap_to_fake_tensor(e, fake_mode):
@@ -56,8 +56,8 @@ class ConstantVariable(VariableTracker):
         try:
             options = VariableTracker.propagate([self])
             return [ConstantVariable(x, **options) for x in self.as_python_constant()]
-        except TypeError:
-            raise NotImplementedError()
+        except TypeError as e:
+            raise NotImplementedError from e

     def const_getattr(self, tx, name):
         member = getattr(self.value, name)
@@ -504,8 +504,8 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):

         try:
             fn = inspect.getattr_static(self.value_type, "__iter__")
-        except AttributeError:
-            raise NotImplementedError()
+        except AttributeError as e:
+            raise NotImplementedError from e

         if fn in (
             torch.nn.ModuleList.__iter__,
@@ -277,8 +277,9 @@ def min_cut_rematerialization_partition(
     """
     try:
         import networkx as nx
-    except ImportError:
-        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
+    except ImportError as e:
+        raise RuntimeError("Need networkx installed to perform smart recomputation "
+                           "heuristics") from e

     joint_module.graph.eliminate_dead_code()
     joint_module.recompile()
@@ -435,7 +435,7 @@ class CppCodeCache:
                 try:
                     subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                 except subprocess.CalledProcessError as e:
-                    raise exc.CppCompileError(cmd, e.output)
+                    raise exc.CppCompileError(cmd, e.output) from e

             cls.cache[key] = cls._load_library(output_path)
             cls.cache[key].key = key
@@ -184,7 +184,7 @@ class CachingAutotuner(KernelInterface):
                 raise RuntimeError(
                     """Consider updating Triton with
 `pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
-                )
+                ) from e
             else:
                 raise e
@@ -136,5 +136,5 @@ class CrossRefFakeMode(TorchDispatchMode):
                     r_out, fake_out, check_strides=self.check_strides
                 )
             except Exception as e:
-                raise RuntimeError(f"Mismatch on {func}: {e}")
+                raise RuntimeError(f"Mismatch on {func}: {e}") from e
         return r
@@ -819,7 +819,7 @@ def _test_batched_grad_forward_ad(func, inputs) -> bool:
         except RuntimeError as ex:
             # Rethrow to provide a better error message
             raise GradcheckError(
-                f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}')
+                f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') from ex

     for input_idx, (res, exp) in enumerate(zip(result, expected)):
         if torch.allclose(res, exp):
@@ -861,7 +861,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
         # autograd.grad instead of the C++ traceback of what line in the
         # backward formula
         raise GradcheckError(
-            f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}')
+            f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') from ex

     for input_idx, (res, exp) in enumerate(zip(result, expected)):
         if torch.allclose(res, exp):
@@ -977,12 +977,12 @@ def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
         try:
             grads_input = torch.autograd.grad(output_to_check, diff_input_list,
                                               grads_output, allow_unused=True)
-        except RuntimeError:
+        except RuntimeError as e:
             warn_bc_breaking()
             raise GradcheckError((
                 'Expected backward function to handle undefined output grads. '
                 'Please look at "Notes about undefined output gradients" in '
-                '"tools/autograd/derivatives.yaml"'))
+                '"tools/autograd/derivatives.yaml"')) from e

         for gi, i in zip(grads_input, diff_input_list):
             if (gi is not None) and (not gi.eq(0).all()):
@@ -669,13 +669,13 @@ def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
     """
     try:
         import pynvml  # type: ignore[import]
-    except ModuleNotFoundError:
-        raise ModuleNotFoundError("pynvml module not found, please install pynvml")
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
     from pynvml import NVMLError_DriverNotLoaded
     try:
         pynvml.nvmlInit()
-    except NVMLError_DriverNotLoaded:
-        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
+    except NVMLError_DriverNotLoaded as e:
+        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
     device = _get_device_index(device, optional=True)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
@@ -695,13 +695,13 @@ def utilization(device: Optional[Union[Device, int]] = None) -> int:
     """
     try:
         import pynvml  # type: ignore[import]
-    except ModuleNotFoundError:
-        raise ModuleNotFoundError("pynvml module not found, please install pynvml")
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("pynvml module not found, please install pynvml") from e
     from pynvml import NVMLError_DriverNotLoaded
     try:
         pynvml.nvmlInit()
-    except NVMLError_DriverNotLoaded:
-        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?")
+    except NVMLError_DriverNotLoaded as e:
+        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
     device = _get_device_index(device, optional=True)
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
@@ -1640,10 +1640,10 @@ class DistributedDataParallel(Module, Joinable):
         )
         try:
             overlapped_optim.register_ddp(self)
-        except NotImplementedError:
+        except NotImplementedError as e:
             raise RuntimeError(
                 f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
-            )
+            ) from e

     def _distributed_broadcast_coalesced(
         self, tensors, buffer_size, authoritative_rank=0
@@ -195,11 +195,11 @@ class RendezvousParameters:
             return value
         try:
             return int(value)
-        except ValueError:
+        except ValueError as e:
             raise ValueError(
                 f"The rendezvous configuration option '{key}' does not represent a valid integer "
                 "value."
-            )
+            ) from e


 RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
@@ -244,11 +244,11 @@ class RendezvousHandlerRegistry:
         """Creates a new :py:class:`RendezvousHandler`."""
         try:
             creator = self._registry[params.backend]
-        except KeyError:
+        except KeyError as e:
             raise ValueError(
                 f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                 f"to call `{self.register.__name__}`?"
-            )
+            ) from e

         handler = creator(params)

@@ -464,10 +464,10 @@ class EtcdRendezvous(object):
             version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
             version_counter.value = str(int(version_counter.value) + 1)
             self.client.update(version_counter)
-        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed):
+        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
             raise RendezvousError(
                 "Unexpected state of EtcdRendezvousHandler, worker needs to die."
-            )
+            ) from e

         # Any failure below results in declaring a retryable rendezvous failure.
         # The ephemeral /rdzv/active_version will expire and someone can then
@@ -381,8 +381,8 @@ def _get_ignored_modules(
     msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
     try:
         ignored_root_modules = set(_ignored_modules)
-    except TypeError:
-        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}")
+    except TypeError as e:
+        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
     for module in ignored_root_modules:
         if not isinstance(module, torch.nn.Module):
             raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
@@ -1015,11 +1015,11 @@ def _get_param_id_to_param_from_optim_input(
         return list(model.parameters())
     try:
         params = list(optim_input)
-    except TypeError:
+    except TypeError as e:
         raise TypeError(
             "Optimizer input should be an iterable of Tensors or dicts, "
             f"but got {optim_input}"
-        )
+        ) from e
     if len(params) == 0:
         raise ValueError("Optimizer input should not be empty")

@@ -44,8 +44,9 @@ def register_functional_optim(key, optim):
 def as_functional_optim(optim_cls: Type, *args, **kwargs):
     try:
         functional_cls = functional_optim_map[optim_cls]
-    except KeyError:
-        raise ValueError(f"Optimizer {optim_cls} does not have a functional counterpart!")
+    except KeyError as e:
+        raise ValueError(f"Optimizer {optim_cls} does not have a functional "
+                         f"counterpart!") from e

     return _create_functional_optim(functional_cls, *args, **kwargs)

@@ -1323,9 +1323,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
                             f"Tensors, but got {torch.typename(params)}")
         try:
             all_params = list(params)
-        except TypeError:
+        except TypeError as e:
             raise TypeError("`params` argument should be an iterable of Tensors"
-                            f" or dicts, but got {torch.typename(params)}")
+                            f" or dicts, but got {torch.typename(params)}") from e
         if len(all_params) == 0:
             raise ValueError("ZeroRedundancyOptimizer got an empty parameter "
                              "list")
@@ -198,8 +198,8 @@ class Skippable(nn.Module):
         for ns, name in self.poppable():
             try:
                 poppable_tensors[name] = skip_tracker.load(batch, ns, name)
-            except KeyError:
-                raise RuntimeError(f"'{name}' has not been stashed")
+            except KeyError as e:
+                raise RuntimeError(f"'{name}' has not been stashed") from e
         input = batch.values

         # Handle skip commands.
@@ -1,9 +1,9 @@
 try:
     from urllib.parse import urlparse, urlunparse
-except ImportError:
+except ImportError as e:
     raise ImportError(
         "urllib cannot be found, urlparse from python2 is no longer supported."
-    )
+    ) from e

 import numbers
 import os
@@ -228,7 +228,7 @@ def _handle_exception(result):
             raise RuntimeError(  # noqa: B904
                 f"Failed to create original exception type. Error msg was {str(e)}"
                 f" Original exception on remote side was {exception_msg}"
-            )
+            ) from e

     if exc is not None:
         raise exc
@@ -605,13 +605,13 @@ def determine_local_world_size(nproc_per_node: str):
     try:
         logging.info(f"Using nproc_per_node={nproc_per_node}.")
         return int(nproc_per_node)
-    except ValueError:
+    except ValueError as e:
         if nproc_per_node == "cpu":
             num_proc = os.cpu_count()
             device_type = "cpu"
         elif nproc_per_node == "gpu":
            if not torch.cuda.is_available():
-                raise ValueError("Cuda is not available.")
+                raise ValueError("Cuda is not available.") from e
             device_type = "gpu"
             num_proc = torch.cuda.device_count()
         elif nproc_per_node == "auto":
@@ -622,7 +622,7 @@ def determine_local_world_size(nproc_per_node: str):
             num_proc = os.cpu_count()
             device_type = "cpu"
         else:
-            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
+            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e

         log.info(
             f"Using nproc_per_node={nproc_per_node},"
@@ -20,13 +20,13 @@ def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
     # Parse the function declaration
     try:
         py_ast = ast.parse(decl)
-    except SyntaxError:
+    except SyntaxError as e:
         # This should only happen if there's some unforeseeable change
         # in the dataclasses module that makes our synthesized code fail
         raise RuntimeError(
             f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
             "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
-        )
+        ) from e
     fake_filename = _get_fake_filename(cls, name)
     # Parse the function
     return ParsedDef(
@@ -185,7 +185,7 @@ def infer_concrete_type_builder(nn_module, share_types=True):
         except RuntimeError as re:
             raise RuntimeError(
                 "Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
-            )
+            ) from re

         return attr_type, inferred

@@ -292,8 +292,8 @@ def tensorboard_trace_handler(dir_name: str, worker_name: Optional[str] = None,
     if not os.path.isdir(dir_name):
         try:
             os.makedirs(dir_name, exist_ok=True)
-        except Exception:
-            raise RuntimeError("Can't create directory: " + dir_name)
+        except Exception as e:
+            raise RuntimeError("Can't create directory: " + dir_name) from e
     if not worker_name:
         worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))
     file_name = "{}.{}.pt.trace.json".format(worker_name, int(time.time() * 1000))
@@ -677,8 +677,8 @@ class TensorLikePair(Pair):

         try:
             return torch.as_tensor(tensor_like)
-        except Exception:
-            raise UnsupportedInputs()
+        except Exception as e:
+            raise UnsupportedInputs() from e

     def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
         if tensor.layout not in {
@@ -949,7 +949,7 @@ class FSDPTest(MultiProcessTestCase):
                 **init_kwargs,
             )
         except Exception as e:
-            raise ValueError(f"Initializing {model_class} raised error {str(e)}")
+            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
         if not isinstance(fsdp_model, FSDP):
             # Enforce that we wrap with top-level FSDP since we are comparing
             # assuming a data parallel reference and some test models may not
@@ -3260,7 +3260,7 @@ def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE)):
                 if any(connect_error in str(error) for connect_error in connect_errors):
                     tries_remaining -= 1
                     if tries_remaining == 0:
-                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}")
+                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
                     time.sleep(random.random())
                     continue
                 raise
@@ -4001,8 +4001,8 @@ def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
     """
     try:
         return next(iter(samples))
-    except StopIteration:
-        raise unittest.SkipTest('Skipped! Need at least 1 sample input')
+    except StopIteration as e:
+        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e

 # this helper method is to recursively
 # clone the tensor-type input of operators tested by OpInfo
@@ -290,8 +290,8 @@ class DTensorConverter(object):
                 tree_unflatten(new_args, self.flatten_args_spec),
                 tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
             )
-        except StopIteration:
-            raise StopIteration
+        except StopIteration as e:
+            raise StopIteration from e

     def to_dist_tensor(
         self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
@@ -349,10 +349,10 @@ class _DataPipeSerializationWrapper:
     def __len__(self):
         try:
             return len(self._datapipe)
-        except Exception:
+        except Exception as e:
             raise TypeError(
                 "{} instance doesn't have valid length".format(type(self).__name__)
-            )
+            ) from e


 class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
@@ -154,8 +154,8 @@ def _collate_helper(conversion, item):
             try:
                 import torcharrow.pytorch as tap  # type: ignore[import]
                 collation_fn = tap.rec.Default()
-            except Exception:
-                raise Exception("unable to import default collation function from the TorchArrow")
+            except Exception as e:
+                raise Exception("unable to import default collation function from the TorchArrow") from e

         tuple_names.append(str(name))
         value = collation_fn(df[name])
@@ -93,8 +93,8 @@ class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
         for dp in self.datapipes:
             try:
                 res.append(dp[index])
-            except IndexError:
-                raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.")
+            except IndexError as e:
+                raise IndexError(f"Index {index} is out of range for one of the input MapDataPipes {dp}.") from e
         return tuple(res)

     def __len__(self) -> int:
@@ -53,11 +53,11 @@ class BatcherMapDataPipe(MapDataPipe[DataChunk]):
             for i in indices:
                 batch.append(self.datapipe[i])
             return self.wrapper_class(batch)
-        except IndexError:
+        except IndexError as e:
             if not self.drop_last and len(batch) > 0:
                 return self.wrapper_class(batch)
             else:
-                raise IndexError(f"Index {index} is out of bound.")
+                raise IndexError(f"Index {index} is out of bound.") from e

     def __len__(self) -> int:
         if self.length is not None:
@@ -148,13 +148,13 @@ class ImageHandler:
             import numpy as np
         except ImportError as e:
             raise ModuleNotFoundError("Package `numpy` is required to be installed for default image decoder."
-                                      "Please use `pip install numpy` to install the package")
+                                      "Please use `pip install numpy` to install the package") from e

         try:
             import PIL.Image
         except ImportError as e:
             raise ModuleNotFoundError("Package `PIL` is required to be installed for default image decoder."
-                                      "Please use `pip install Pillow` to install the package")
+                                      "Please use `pip install Pillow` to install the package") from e

         imagespec = self.imagespec
         atype, etype, mode = imagespecs[imagespec]
@@ -200,7 +200,7 @@ def videohandler(extension, data):
     except ImportError as e:
         raise ModuleNotFoundError("Package `torchvision` is required to be installed for default video file loader."
                                   "Please use `pip install torchvision` or `conda install torchvision -c pytorch`"
-                                  "to install the package")
+                                  "to install the package") from e

     with tempfile.TemporaryDirectory() as dirname:
         fname = os.path.join(dirname, f"file.{extension}")
@@ -221,7 +221,7 @@ def audiohandler(extension, data):
     except ImportError as e:
         raise ModuleNotFoundError("Package `torchaudio` is required to be installed for default audio file loader."
                                   "Please use `pip install torchaudio` or `conda install torchaudio -c pytorch`"
-                                  "to install the package")
+                                  "to install the package") from e

     with tempfile.TemporaryDirectory() as dirname:
         fname = os.path.join(dirname, f"file.{extension}")
@@ -240,7 +240,7 @@ class MatHandler:
         except ImportError as e:
             raise ModuleNotFoundError("Package `scipy` is required to be installed for mat file."
                                       "Please use `pip install scipy` or `conda install scipy`"
-                                      "to install the package")
+                                      "to install the package") from e
         self.sio = sio
         self.loadmat_kwargs = loadmat_kwargs

@@ -47,9 +47,9 @@ def _simple_graph_snapshot_restoration(datapipe: IterDataPipe, n_iterations: int
         try:
             next(it)
             remainder -= 1
-        except StopIteration:
+        except StopIteration as e:
            raise RuntimeError(f"Fast-forward {datapipe} by {n_iterations} iterations "
-                               "exceeds the number of samples available.")
+                               "exceeds the number of samples available.") from e
     datapipe._fast_forward_iterator = it
     # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.
@@ -262,10 +262,10 @@ def error_on_missing_kernels(
     try:
         with open(kernel_defn_file_path, "r") as f:
             backend_defns = f.read()
-    except IOError:
+    except IOError as e:
         raise AssertionError(
             f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
-        )
+        ) from e

     if full_codegen is None:
         full_codegen = []
@@ -137,10 +137,10 @@ def validate_shape_inference_header(
         with open(shape_inference_hdr, "r") as f:
             shape_infr_decls = f.read()
             shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
-    except IOError:
+    except IOError as e:
         raise AssertionError(
             f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
-        )
+        ) from e

     shape_infr_regex = r"compute_shape_(\w+)"
     actual_shape_infr_name_counts = Counter(
@@ -1719,8 +1719,8 @@ class Type:
             return CustomClassType(m.group(1))
         try:
             return BaseType(BaseTy[t])
-        except KeyError:
-            raise RuntimeError(f"unrecognized type {t}")
+        except KeyError as e:
+            raise RuntimeError(f"unrecognized type {t}") from e

     def __str__(self) -> str:
         raise NotImplementedError