removing some redundant str splits (#106089)

drop some redundant string splits, no factual changes, just cleaning the codebase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106089
Approved by: https://github.com/albanD, https://github.com/malfet
This commit is contained in:
Jirka Borovec 2023-09-01 00:22:55 +00:00 committed by PyTorch MergeBot
parent cc220e45a8
commit 9178deedff
18 changed files with 25 additions and 29 deletions

View File

@ -19,7 +19,7 @@ def parse_args():
"--local_rank", "--local_rank",
type=int, type=int,
required=True, required=True,
help="The rank of the node for multi-node distributed " "training", help="The rank of the node for multi-node distributed training",
) )
return parser.parse_args() return parser.parse_args()

View File

@ -32,7 +32,7 @@ class SqueezeNet(nn.Module):
super().__init__() super().__init__()
if version not in [1.0, 1.1]: if version not in [1.0, 1.1]:
raise ValueError( raise ValueError(
f"Unsupported SqueezeNet version {version}:" "1.0 or 1.1 expected" f"Unsupported SqueezeNet version {version}:1.0 or 1.1 expected"
) )
self.num_classes = num_classes self.num_classes = num_classes
if version == 1.0: if version == 1.0:

View File

@ -360,19 +360,17 @@ class TestNestedTensor(TestCase):
@torch.inference_mode() @torch.inference_mode()
def test_repr_string(self): def test_repr_string(self):
a = torch.nested.nested_tensor([]) a = torch.nested.nested_tensor([])
expected = "nested_tensor([" "\n\n])" expected = "nested_tensor([\n\n])"
self.assertEqual(str(a), expected) self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected) self.assertEqual(repr(a), expected)
a = torch.nested.nested_tensor([torch.tensor(1.0)]) a = torch.nested.nested_tensor([torch.tensor(1.0)])
expected = "nested_tensor([" "\n tensor(1.)" "\n])" expected = "nested_tensor([\n tensor(1.)\n])"
self.assertEqual(str(a), expected) self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected) self.assertEqual(repr(a), expected)
a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])]) a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
expected = ( expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])"
"nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
)
self.assertEqual(str(a), expected) self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected) self.assertEqual(repr(a), expected)

View File

@ -57,7 +57,7 @@ class WeakTest(TestCase):
self.assertIsNot( self.assertIsNot(
value1, value1,
value2, value2,
"invalid test" " -- value parameters must be distinct objects", "invalid test -- value parameters must be distinct objects",
) )
weakdict = klass() weakdict = klass()
o = weakdict.setdefault(key, value1) o = weakdict.setdefault(key, value1)

View File

@ -700,7 +700,7 @@ class TestMultiIndexingAutomated:
in_indices[i] = indx in_indices[i] = indx
elif indx.dtype.kind != "b" and indx.dtype.kind != "i": elif indx.dtype.kind != "b" and indx.dtype.kind != "i":
raise IndexError( raise IndexError(
"arrays used as indices must be of " "integer (or boolean) type" "arrays used as indices must be of integer (or boolean) type"
) )
if indx.ndim != 0: if indx.ndim != 0:
no_copy = False no_copy = False

View File

@ -390,7 +390,7 @@ class AxisConcatenator:
newobj = newobj.swapaxes(-1, trans1d) newobj = newobj.swapaxes(-1, trans1d)
elif isinstance(item, str): elif isinstance(item, str):
if k != 0: if k != 0:
raise ValueError("special directives must be the " "first entry.") raise ValueError("special directives must be the first entry.")
if item in ("r", "c"): if item in ("r", "c"):
matrix = True matrix = True
col = item == "c" col = item == "c"

View File

@ -2920,7 +2920,7 @@ class TestPercentile:
assert_equal(c1.shape, r1.shape) assert_equal(c1.shape, r1.shape)
@pytest.mark.xfail( @pytest.mark.xfail(
reason="numpy: x.dtype is int, out is int; " "torch: result is float" reason="numpy: x.dtype is int, out is int; torch: result is float"
) )
def test_scalar_q_2(self): def test_scalar_q_2(self):
x = np.arange(12).reshape(3, 4) x = np.arange(12).reshape(3, 4)

View File

@ -566,7 +566,7 @@ class TestHistogramOptimBinNums:
assert_equal( assert_equal(
len(a), len(a),
numbins, numbins,
err_msg=f"{estimator} estimator, " "No Variance test", err_msg=f"{estimator} estimator, No Variance test",
) )
def test_limited_variance(self): def test_limited_variance(self):

View File

@ -784,7 +784,7 @@ class TestCond(CondCases):
linalg.cond(A, p) linalg.cond(A, p)
@pytest.mark.xfail( @pytest.mark.xfail(
True, run=False, reason="Platform/LAPACK-dependent failure, " "see gh-18914" True, run=False, reason="Platform/LAPACK-dependent failure, see gh-18914"
) )
def test_nan(self): def test_nan(self):
# nans should be passed through, not converted to infs # nans should be passed through, not converted to infs

View File

@ -149,7 +149,7 @@ def main() -> None:
"--yaml_file_path", "--yaml_file_path",
type=str, type=str,
required=True, required=True,
help="Path to the yaml" " file with a list of operators used by the model.", help="Path to the yaml file with a list of operators used by the model.",
) )
parser.add_argument( parser.add_argument(
"-o", "-o",

View File

@ -343,7 +343,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No
@timed("Installing pytorch nightly binaries") @timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]": def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
""" "Install pytorch into a temporary directory""" """Install pytorch into a temporary directory"""
pytdir = tempfile.TemporaryDirectory() pytdir = tempfile.TemporaryDirectory()
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url] cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
p = subprocess.run(cmd, check=True) p = subprocess.run(cmd, check=True)

View File

@ -405,12 +405,12 @@ def cond_func(pred, true_fn, false_fn, inputs):
for branch in [true_fn, false_fn]: for branch in [true_fn, false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs): if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be modifying the input!" "One of torch.cond branch might be modifying the input!"
) )
if _has_potential_branch_input_alias(branch, unwrapped_inputs): if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be aliasing the input!" "One of torch.cond branch might be aliasing the input!"
) )
cond_return = cond_op( cond_return = cond_op(
@ -443,12 +443,12 @@ def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
for branch in [functional_true_fn, functional_false_fn]: for branch in [functional_true_fn, functional_false_fn]:
if _has_potential_branch_input_mutation(branch, unwrapped_inputs): if _has_potential_branch_input_mutation(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be modifying the input!" "One of torch.cond branch might be modifying the input!"
) )
for branch in [true_fn, false_fn]: for branch in [true_fn, false_fn]:
if _has_potential_branch_input_alias(branch, unwrapped_inputs): if _has_potential_branch_input_alias(branch, unwrapped_inputs):
raise UnsupportedAliasMutationException( raise UnsupportedAliasMutationException(
"One of torch.cond branch " "might be aliasing the input!" "One of torch.cond branch might be aliasing the input!"
) )
cond_return = cond_op( cond_return = cond_op(

View File

@ -1226,7 +1226,7 @@ def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None):
# Move working axis to the end of the shape # Move working axis to the end of the shape
a = torch.moveaxis(a, axisa, -1) a = torch.moveaxis(a, axisa, -1)
b = torch.moveaxis(b, axisb, -1) b = torch.moveaxis(b, axisb, -1)
msg = "incompatible dimensions for cross product\n" "(dimension must be 2 or 3)" msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)"
if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError(msg) raise ValueError(msg)

View File

@ -321,7 +321,7 @@ def average(
if a.shape != weights.shape: if a.shape != weights.shape:
if axis is None: if axis is None:
raise TypeError( raise TypeError(
"Axis must be specified when shapes of a and weights " "differ." "Axis must be specified when shapes of a and weights differ."
) )
if weights.ndim != 1: if weights.ndim != 1:
raise TypeError( raise TypeError(

View File

@ -358,7 +358,7 @@ class _SingleLevelFunction(
if they are intended to be used for in ``jvp``. if they are intended to be used for in ``jvp``.
""" """
raise NotImplementedError( raise NotImplementedError(
"You must implement the forward function for custom" " autograd.Function." "You must implement the forward function for custom autograd.Function."
) )
@staticmethod @staticmethod

View File

@ -2043,9 +2043,9 @@ def _get_param_to_fqn(
""" """
param_to_param_names = _get_param_to_fqns(model) param_to_param_names = _get_param_to_fqns(model)
for param_names in param_to_param_names.values(): for param_names in param_to_param_names.values():
assert len(param_names) > 0, ( assert (
"`_get_param_to_fqns()` " "should not construct empty lists" len(param_names) > 0
) ), "`_get_param_to_fqns()` should not construct empty lists"
if len(param_names) > 1: if len(param_names) > 1:
raise RuntimeError( raise RuntimeError(
"Each parameter should only map to one parameter name but got " "Each parameter should only map to one parameter name but got "

View File

@ -1060,7 +1060,7 @@ class ExprBuilder(Builder):
if isinstance(index_expr.value, ast.Tuple): if isinstance(index_expr.value, ast.Tuple):
raise NotSupportedError( raise NotSupportedError(
base.range(), base.range(),
"slicing multiple dimensions with " "tuples not supported yet", "slicing multiple dimensions with tuples not supported yet",
) )
return build_expr(ctx, index_expr.value) return build_expr(ctx, index_expr.value)

View File

@ -44,9 +44,7 @@ def _sanity_check(name, package, level):
if not isinstance(package, str): if not isinstance(package, str):
raise TypeError("__package__ not set to a string") raise TypeError("__package__ not set to a string")
elif not package: elif not package:
raise ImportError( raise ImportError("attempted relative import with no known parent package")
"attempted relative import with no known parent " "package"
)
if not name and level == 0: if not name and level == 0:
raise ValueError("Empty module name") raise ValueError("Empty module name")