[BE][CI] bump ruff to 0.9.0: string quote styles (#144569)

Reference: https://docs.astral.sh/ruff/formatter/#f-string-formatting

- Change the outer quotes to double quotes for nested f-strings

```diff
- f'{", ".join(args)}'
+ f"{', '.join(args)}"
```
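
For context, the rewritten form keeps the inner and outer quote characters distinct, so it parses on every Python version PyTorch supports, including interpreters older than 3.12 where a replacement field may not reuse the f-string's own quote character. A quick illustrative check (not part of the PR):

```python
args = ["a", "b", "c"]

# Outer double quotes, single quotes inside the replacement field.
joined = f"{', '.join(args)}"
assert joined == "a, b, c"
```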

- Change the inner quotes to double quotes for triple-quoted f-strings

```diff
  string = """
-     {', '.join(args)}
+     {", ".join(args)}
  """
```
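
Only the source quoting changes here; the rendered text is identical, as this small illustrative check (not from the PR) shows:

```python
args = ["a", "b", "c"]

old_style = f"""
    {', '.join(args)}
"""
new_style = f"""
    {", ".join(args)}
"""
# Both render the same string; only the quotes in the source differ.
assert old_style == new_style
```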

- Join implicitly concatenated strings

```diff
- string = "short string " "short string " f"{var}"
+ string = f"short string short string {var}"
```
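
Merging adjacent literals does not change the resulting value either; a minimal illustrative check (not from the PR):

```python
var = 42
before = "short string " "short string " f"{var}"
after = f"short string short string {var}"
assert before == after == "short string short string 42"
```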

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144569
Approved by: https://github.com/Skylion007
ghstack dependencies: #146509
Author: Xuehai Pan, 2025-02-24 23:54:38 +08:00 (committed by PyTorch MergeBot)
Commit: 754fb834db (parent: 52f6d4aa30)
52 changed files with 135 additions and 135 deletions

Excerpts from the changed files:

```diff
@@ -204,7 +204,7 @@ if __name__ == "__main__":
 else:
 build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 "
 elif branch.startswith(("v1.", "v2.")):
-build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
+build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "
 if enable_mkldnn:
 build_ArmComputeLibrary()
```

```diff
@@ -761,7 +761,7 @@ def start_build(
 version = host.check_output("cat pytorch/version.txt").strip()[:-2]
 build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1"
 if branch.startswith(("v1.", "v2.")):
-build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
+build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1 : branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
 if host.using_docker():
 build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
 if enable_mkldnn:
```

```diff
@@ -46,7 +46,9 @@ def train(args, model, device, train_loader, optimizer, epoch):
 optimizer.step()
 if batch_idx % args.log_interval == 0:
 print(
-f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}" # noqa: B950
+f"Train Epoch: {epoch} "
+f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
+f"({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
 )
 if args.dry_run:
 break
@@ -71,7 +73,9 @@ def test(model, device, test_loader):
 test_loss /= len(test_loader.dataset)
 print(
-f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n" # noqa: B950
+f"\nTest set: Average loss: {test_loss:.4f}, "
+f"Accuracy: {correct}/{len(test_loader.dataset)} "
+f"({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
 )
```

```diff
@@ -57,10 +57,10 @@ def gh_fetch_url_and_headers(
 print(
 f"""{url}
 Rate limit exceeded:
-Used: {err.headers['X-RateLimit-Used']}
-Limit: {err.headers['X-RateLimit-Limit']}
-Remaining: {err.headers['X-RateLimit-Remaining']}
-Resets at: {err.headers['x-RateLimit-Reset']}"""
+Used: {err.headers["X-RateLimit-Used"]}
+Limit: {err.headers["X-RateLimit-Limit"]}
+Remaining: {err.headers["X-RateLimit-Remaining"]}
+Resets at: {err.headers["x-RateLimit-Reset"]}"""
 )
 else:
 print(f"Error fetching {url} {err}")
```

```diff
@@ -485,7 +485,7 @@ def get_check_run_name_prefix(workflow_run: Any) -> str:
 if workflow_run is None:
 return ""
 else:
-return f'{workflow_run["workflow"]["name"]} / '
+return f"{workflow_run['workflow']['name']} / "
 def is_passing_status(status: Optional[str]) -> bool:
@@ -545,7 +545,7 @@ def add_workflow_conclusions(
 if not isinstance(checkrun_node, dict):
 warn(f"Expected dictionary, but got {type(checkrun_node)}")
 continue
-checkrun_name = f'{get_check_run_name_prefix(workflow_run)}{checkrun_node["name"]}'
+checkrun_name = f"{get_check_run_name_prefix(workflow_run)}{checkrun_node['name']}"
 existing_checkrun = workflow_obj.jobs.get(checkrun_name)
 if existing_checkrun is None or not is_passing_status(
 existing_checkrun.status
```

```diff
@@ -79,7 +79,7 @@ class TryMergeExplainer:
 (
 "<details><summary>Advanced Debugging</summary>",
 "Check the merge workflow status ",
-f"<a href=\"{os.getenv('GH_RUN_URL')}\">here</a>",
+f'<a href="{os.getenv("GH_RUN_URL")}">here</a>',
 "</details>",
 )
 )
```

```diff
@@ -103,7 +103,7 @@ if __name__ == "__main__":
 "-o",
 "--output_dir",
 required=False,
-help="Where to generate the kernels " " will default to the current directory ",
+help="Where to generate the kernels will default to the current directory",
 )
 args = parser.parse_args()
 main(args.output_dir)
```

```diff
@@ -102,7 +102,7 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
 msg += textwrap.dedent(
 f"""
 Error: {len(failed)} models have accuracy status regressed:
-{' '.join(failed)}
+{" ".join(failed)}
 """
 )
@@ -110,7 +110,7 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
 msg += textwrap.dedent(
 f"""
 Improvement: {len(improved)} models have accuracy status improved:
-{' '.join(improved)}
+{" ".join(improved)}
 """
 )
```

```diff
@@ -26,7 +26,7 @@ def check_csv(filename):
 textwrap.dedent(
 f"""
 Error {len(failed)} models failed
-{' '.join(failed)}
+{" ".join(failed)}
 """
 )
 )
```

```diff
@@ -91,7 +91,7 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
 msg += textwrap.dedent(
 f"""
 Error: {len(failed)} models have new dynamo graph breaks:
-{' '.join(failed)}
+{" ".join(failed)}
 """
 )
@@ -99,7 +99,7 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
 msg += textwrap.dedent(
 f"""
 Improvement: {len(improved)} models have fixed dynamo graph breaks:
-{' '.join(improved)}
+{" ".join(improved)}
 """
 )
```

```diff
@@ -40,7 +40,7 @@ def main(args):
 textwrap.dedent(
 f"""
 Error: {len(failed)} models below expected memory compression ratio:
-{' '.join(failed)}
+{" ".join(failed)}
 If this drop is expected, you can update `{args.expected}`.
 """
 )
```

```diff
@@ -26,7 +26,7 @@ def check_perf_csv(filename, threshold, threshold_scale):
 textwrap.dedent(
 f"""
 Error {len(failed)} models performance regressed
-{' '.join(failed)}
+{" ".join(failed)}
 """
 )
 )
```

```diff
@@ -368,7 +368,7 @@ class GroupedBenchmark:
 return textwrap.dedent(
 f"""\
-def model({', '.join(signature_args)}):
+def model({", ".join(signature_args)}):
 {{stmt_str}}
 return {signature_output}
 """
@@ -397,7 +397,7 @@ class GroupedBenchmark:
 cpp_invocation = textwrap.dedent(
 f"""\
 std::vector<torch::jit::IValue> ivalue_inputs({{
-{', '.join([f'torch::jit::IValue({a})' for a in signature_args])}
+{", ".join([f"torch::jit::IValue({a})" for a in signature_args])}
 }});
 {cpp_prefix}{model_name}.forward(ivalue_inputs);
 """
```

```diff
@@ -49,7 +49,7 @@ def generate_example_rst(example_case: ExportCase):
 # Generate contents of the .rst file
 title = f"{example_case.name}"
 doc_contents = f"""{title}
-{'^' * (len(title))}
+{"^" * (len(title))}
 .. note::
@@ -117,7 +117,7 @@ def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
 module_contents = "\n\n".join(v)
 support_contents += f"""
 {support_level}
-{'-' * (len(support_level))}
+{"-" * (len(support_level))}
 {module_contents}
 """
```

```diff
@@ -119,7 +119,7 @@ def test(args, model, test_loader, device):
 top1_avg = np.mean(top1_acc)
-print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg :.6f} ")
+print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
 return np.mean(top1_acc)
```

```diff
@@ -185,7 +185,7 @@ def test(args, model, test_loader, device):
 top1_avg = np.mean(top1_acc)
-print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg :.6f} ")
+print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
 return np.mean(top1_acc)
```

```diff
@@ -108,7 +108,7 @@ def failures_histogram(eager_dir, dynamo_dir, verbose=False, format_issues=False
 def as_issue(count, msg, repro, tests):
 tests = "\n".join(tests)
 result = f"""
-{'-' * 50}
+{"-" * 50}
 {count} Dynamo test are failing with \"{msg}\".
 ## Repro
```

```diff
@@ -145,7 +145,7 @@ Labels: {features.labels}
 Current category: {commit.category}
-Select from: {', '.join(common.categories)}
+Select from: {", ".join(common.categories)}
 """
 )
@@ -165,7 +165,7 @@ Select from: {', '.join(common.categories)}
 cat_choice = choices[0]
 print(f"\nSelected: {cat_choice}")
 print(f"\nCurrent topic: {commit.topic}")
-print(f"""Select from: {', '.join(topics)}""")
+print(f"""Select from: {", ".join(topics)}""")
 topic_choice = None
 while topic_choice is None:
 value = input("topic> ").strip()
```

```diff
@@ -1454,8 +1454,7 @@ def main():
 name=package_name,
 version=version,
 description=(
-"Tensors and Dynamic neural networks in "
-"Python with strong GPU acceleration"
+"Tensors and Dynamic neural networks in Python with strong GPU acceleration"
 ),
 long_description=long_description,
 long_description_content_type="text/markdown",
```

```diff
@@ -759,7 +759,7 @@ def run_test(
 stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}"
 if options.verbose:
-unittest_args.append(f'-{"v" * options.verbose}') # in case of pytest
+unittest_args.append(f"-{'v' * options.verbose}") # in case of pytest
 if test_file in RUN_PARALLEL_BLOCKLIST:
 unittest_args = [
@@ -1895,8 +1895,7 @@ def get_selected_tests(options) -> list[str]:
 selected_tests = exclude_tests(
 TESTS_NOT_USING_GRADCHECK,
 selected_tests,
-"Running in slow gradcheck mode, skipping tests "
-"that don't use gradcheck.",
+"Running in slow gradcheck mode, skipping tests that don't use gradcheck.",
 exact_match=True,
 )
```

```diff
@@ -151,7 +151,7 @@ class TestDispatch(TestCase):
 active_ops.add(op_ix)
 try:
 ops[op_ix](refs[op_ix])
-check_invariants(f"running ctors {ctor_order[:i + 1]}")
+check_invariants(f"running ctors {ctor_order[: i + 1]}")
 except RuntimeError as e:
 if not expect_raises:
 raise
@@ -160,7 +160,7 @@ class TestDispatch(TestCase):
 expected, _, expected_provenance = results.setdefault(
 frozenset(active_ops),
 Result(
-actual, "", f"error after running ctors {ctor_order[:i + 1]}"
+actual, "", f"error after running ctors {ctor_order[: i + 1]}"
 ),
 )
 self.assertMultiLineEqual(expected, actual, expected_provenance)
@@ -195,7 +195,7 @@ class TestDispatch(TestCase):
 else:
 active_ops.remove(op_ix)
 check_invariants(
-f"running ctors {ctor_order[:last_ctor + 1]}, then running dtors {dtor_order[:i + 1]}"
+f"running ctors {ctor_order[: last_ctor + 1]}, then running dtors {dtor_order[: i + 1]}"
 )
 return results[set_to_report][0]
```

```diff
@@ -2878,8 +2878,8 @@ class TestNNCOpInfo(TestNNCOpInfoParent):
 fx_args.append(f"{k} = {repr(v)}")
 code = f"""
-def f({', '.join(param_names)}):
-return op.op({', '.join(fx_args)})"""
+def f({", ".join(param_names)}):
+return op.op({", ".join(fx_args)})"""
 g = {"torch": torch, "inf": math.inf, "op": op}
 exec(code, g)
 f = g["f"]
```

```diff
@@ -575,7 +575,7 @@ def gen_formals(f: NativeFunction) -> str:
 # See Note [Plumbing Keys Through The Dispatcher] for details.
 ["c10::DispatchKeySet ks"]
 + [
-f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
+f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}"
 for a in f.func.schema_order_arguments()
 ]
 )
```

```diff
@@ -723,7 +723,7 @@ def emit_structseq_call(
 tn_key = gen_structseq_typename_key(overload.function)
 typename = typenames.get(tn_key)
 if typename is None:
-typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
+typename = f"NamedTuple{'' if not typedefs else len(typedefs)}"
 typenames[tn_key] = typename
 typedefs.append(
 f"""\
@@ -759,7 +759,7 @@ def generate_return_type_definition_and_registrations(
 typename = typenames.get(tn_key)
 if typename is None:
-typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
+typename = f"{name}NamedTuple{'' if not definitions else len(definitions)}"
 typenames[tn_key] = typename
 definitions.append(
 f"""\
@@ -807,7 +807,7 @@ def generate_return_type_declarations(
 if typename is None:
 typename = (
-f'{name}NamedTuple{"" if not declarations else len(declarations)}'
+f"{name}NamedTuple{'' if not declarations else len(declarations)}"
 )
 typenames[tn_key] = typename
 declarations.append(f"PyTypeObject* get_{name}_structseq();")
@@ -1351,7 +1351,7 @@ def emit_single_dispatch(
 or (ps.method and ("requires_grad" in parser_outputs))
 )
 set_requires_grad = (
-f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
+f".set_requires_grad({parser_outputs['requires_grad'].expr})"
 if need_set_requires_grad
 else ""
 )
```

```diff
@@ -381,9 +381,9 @@ def format_postrecord_trace(f: NativeFunction) -> str:
 def tie_return_values(f: NativeFunction) -> str:
 if len(f.func.returns) == 1:
-return f'auto {f.func.returns[0].name or "result"}'
+return f"auto {f.func.returns[0].name or 'result'}"
 names = cpp.return_names(f)
-return f'auto [{", ".join(names)}]'
+return f"auto [{', '.join(names)}]"
 def get_return_value(f: NativeFunction) -> str:
@@ -391,7 +391,7 @@ def get_return_value(f: NativeFunction) -> str:
 if len(f.func.returns) == 1:
 return names[0]
 if f.func.kind() == SchemaKind.out:
-return f'std::forward_as_tuple({", ".join(names)})'
+return f"std::forward_as_tuple({', '.join(names)})"
 else:
 moved = ", ".join(f"std::move({name})" for name in names)
 return f"std::make_tuple({moved})"
@@ -474,7 +474,7 @@ def method_definition(f: NativeFunction) -> str:
 # See Note [Plumbing Keys Through The Dispatcher] for details.
 ["c10::DispatchKeySet ks"]
 + [
-f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}'
+f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}"
 for a in f.func.schema_order_arguments()
 ]
 )
```

```diff
@@ -108,9 +108,9 @@ def process_function(f: NativeFunction) -> str | None:
 exprs.append(arg.name)
 r += f"""\
-inline at::Tensor {sig.name()}({', '.join(formals)}) {{
+inline at::Tensor {sig.name()}({", ".join(formals)}) {{
 at::AutoDispatchBelowADInplaceOrView guard;
-return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
+return autograd::make_variable(at::{sig.name()}({", ".join(exprs)}), /*requires_grad=*/{requires_grad});
 }}
 """
 return r
```

```diff
@@ -1410,7 +1410,7 @@ def emit_body(
 if all_forward_grad_cond:
 if not is_inplace_foreach:
-body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{')
+body.append(f"if ({' || '.join(all_forward_grad_cond)}) {{")
 body.append(" original_self = self.clone();")
 body.append("}")
 else:
@@ -1801,7 +1801,7 @@ def emit_body(
 if len(var_names) == 1:
 return f"_any_has_forward_grad_{var_names[0]}"
 else:
-return f'_any_has_forward_grad_{"_".join(var_names)}'
+return f"_any_has_forward_grad_{'_'.join(var_names)}"
 def emit_any_has_forward_grad() -> list[str]:
 content: list[str] = []
@@ -2089,7 +2089,7 @@ def emit_body(
 raise RuntimeError(
 f'Unsupported input type for "{name}" when forbidding forward AD usage.'
 )
-return f'({" || ".join(to_check)})'
+return f"({' || '.join(to_check)})"
 else:
 # (2) If derivative is provided, use that information to determine which inputs
 # to check fw_grad for
```

```diff
@@ -33,10 +33,10 @@ def gh_fetch_url_and_headers(
 ):
 print(
 f"""Rate limit exceeded:
-Used: {err.headers['X-RateLimit-Used']}
-Limit: {err.headers['X-RateLimit-Limit']}
-Remaining: {err.headers['X-RateLimit-Remaining']}
-Resets at: {err.headers['x-RateLimit-Reset']}"""
+Used: {err.headers["X-RateLimit-Used"]}
+Limit: {err.headers["X-RateLimit-Limit"]}
+Remaining: {err.headers["X-RateLimit-Remaining"]}
+Resets at: {err.headers["x-RateLimit-Reset"]}"""
 )
 raise
```

```diff
@@ -41,7 +41,7 @@ def main() -> None:
 # Convert all quoted includes to angle brackets
 match = QUOTE_INCLUDE_RE.match(line)
 if match is not None:
-print(f"#include <{match.group(1)}>{line[match.end(0):]}", end="")
+print(f"#include <{match.group(1)}>{line[match.end(0) :]}", end="")
 continue
 match = ANGLE_INCLUDE_RE.match(line)
```

```diff
@@ -161,8 +161,7 @@ def main() -> None:
 "--output_file_path",
 type=str,
 required=True,
-help="Path to destination"
-"folder where selected_mobile_ops.h will be written.",
+help="Path to destinationfolder where selected_mobile_ops.h will be written.",
 )
 parsed_args = parser.parse_args()
 model_file_name = parsed_args.yaml_file_path
```

```diff
@@ -105,7 +105,7 @@ def parse_args() -> argparse.Namespace:
 "--destination",
 default="dist/",
 type=str,
-help=("Destination to put the compailed binaries"),
+help="Destination to put the compiled binaries",
 )
 return parser.parse_args()
```

```diff
@@ -1482,7 +1482,7 @@ def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
 _C._set_deterministic_algorithms(True)
 else:
 raise RuntimeError(
-"invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}"
+f"invalid value of debug_mode, expected 0, 1, or 2, but got {debug_mode}"
 )
```

```diff
@@ -852,8 +852,7 @@ def ignore(drop=False, **kwargs):
 if not isinstance(drop, bool):
 raise RuntimeError(
-"Argument to @torch.jit.ignore must be a bool or "
-f"a function but got {drop}"
+f"Argument to @torch.jit.ignore must be a bool or a function but got {drop}"
 )
 # for backwards compat
@@ -1541,7 +1540,7 @@ def _get_model_id(obj) -> Optional[str]:
 # In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
 # that were previously dropped. To preserve the behavior, explicitly drop them there
-if sys.version_info > (3, 10):
+if sys.version_info >= (3, 11):
 _drop(enum.Enum.__new__)
 _drop(enum.Enum.__format__)
 _drop(enum.Enum.__repr__)
```

```diff
@@ -694,9 +694,7 @@ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
 bdim = torch._C._functorch.maybe_get_bdim(tensor)
 assert bdim != -1
 return (
-f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
-f"{indented_value_repr}\n"
-f")"
+f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n{indented_value_repr}\n)"
 )
 if torch._C._functorch.is_gradtrackingtensor(tensor):
 return f"GradTrackingTensor(lvl={level}, value=\n{indented_value_repr}\n)"
```

```diff
@@ -143,7 +143,7 @@ class Diagnostic:
 """
 if self.logger.isEnabledFor(level):
 indented_format_message = (
-f"##{'#' * self._current_log_section_depth } {message}"
+f"##{'#' * self._current_log_section_depth} {message}"
 )
 self.log(
 level,
```

```diff
@@ -81,10 +81,10 @@ class ONNXProgram:
 return f"""\
 ONNXProgram(
 model=
-{textwrap.indent(str(self.model), ' ' * 8)}
+{textwrap.indent(str(self.model), " " * 8)}
 ,
 exported_program=
-{textwrap.indent(str(self.exported_program), ' ' * 8)}
+{textwrap.indent(str(self.exported_program), " " * 8)}
 )
 """
```

```diff
@@ -160,8 +160,7 @@ def _unpack_list(list_value: _C.Value) -> list[_C.Value]:
 list_node = list_value.node()
 if list_node.kind() != "prim::ListConstruct":
 raise errors.SymbolicValueError(
-f"ONNX symbolic expected node type prim::ListConstruct, "
-f"got '{list_node}'.",
+f"ONNX symbolic expected node type prim::ListConstruct, got '{list_node}'.",
 list_value,
 )
 return list(list_node.inputs())
```

```diff
@@ -405,7 +405,7 @@ class PythonSignature:
 if len(schema_formals) > positional_argc:
 schema_formals.insert(positional_argc, "*")
-return f'{self.name}({", ".join(schema_formals)})'
+return f"{self.name}({', '.join(schema_formals)})"
 def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
 args = self.arguments(skip_outputs=skip_outputs)
@@ -421,7 +421,7 @@ class PythonSignature:
 # pyi also includes self (with no typing/defaults) for methods
 if self.method:
 schema_formals.insert(0, "self")
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."
 def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
 # only pyi uses vararg signatures
@@ -457,7 +457,7 @@ class PythonSignature:
 # pyi also includes self (with no typing/defaults) for methods
 if self.method:
 schema_formals.insert(0, "self")
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."
 # The deprecated python signature involves some special logic, so create a
@@ -498,7 +498,7 @@ class PythonSignatureDeprecated(PythonSignature):
 schema_formals.insert(positional_argc, "*")
 returns_str = returns_str_pyi(self)
-return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+return f"def {self.name}({', '.join(schema_formals)}) -> {returns_str}: ..."
 def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
 # the codegen doesn't include vararg variants for deprecated signatures
@@ -1474,11 +1474,11 @@ def dispatch_lambda_exprs(
 inits.append(
 f"""\
 const auto options = TensorOptions()
-.dtype({arg_parser_outputs['dtype'].expr})
-.device({arg_parser_outputs['device'].expr})
-.layout({arg_parser_outputs['layout'].expr})
-.requires_grad({arg_parser_outputs['requires_grad'].expr})
-.pinned_memory({arg_parser_outputs['pin_memory'].expr});
+.dtype({arg_parser_outputs["dtype"].expr})
+.device({arg_parser_outputs["device"].expr})
+.layout({arg_parser_outputs["layout"].expr})
+.requires_grad({arg_parser_outputs["requires_grad"].expr})
+.pinned_memory({arg_parser_outputs["pin_memory"].expr});
 torch::utils::maybe_initialize_device(options);
 """
 )
@@ -1500,9 +1500,9 @@ torch::utils::maybe_initialize_device(options);
 inits.append(
 f"""\
-check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
-{arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
-{arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
+check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
+{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
+{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
 """
 )
 # we'll set requires_grad on outgoing tensor
```

```diff
@@ -366,9 +366,9 @@ class FunctionalizationLambda:
 e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
 ]
 if not self.is_reverse and maybe_index is not None:
-return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
+return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
 else:
-return f'{inner_call_name}({", ".join(call_exprs)});'
+return f"{inner_call_name}({', '.join(call_exprs)});"
 @staticmethod
 def from_func(
```

```diff
@@ -131,7 +131,7 @@ class TupleCType(CType):
 def cpp_type(self, *, strip_ref: bool = False) -> str:
 # Do not pass `strip_ref` recursively.
-return f'::std::tuple<{",".join([e.cpp_type() for e in self.elems])}>'
+return f"::std::tuple<{','.join([e.cpp_type() for e in self.elems])}>"
 def remove_const_ref(self) -> CType:
 return TupleCType([e.remove_const_ref() for e in self.elems])
```

```diff
@@ -543,7 +543,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
 aten_name += "_symint"
 shape_str = f"""\
 {meta_conversion_str}
-auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
+auto out_meta = at::{dispatch_ns}::{aten_name}({", ".join(meta_call_args)});
 {meta_out}"""
 else:
 shape_sig = ComputeShapeSignature(
@@ -559,7 +559,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
 func_schema_str = "aten::" + str(func.func)
 shape_str += f"""
 if(torch::lazy::symbolicShapeEnabled()){{
-std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
+std::vector<torch::jit::IValue> inputs = {{ {", ".join(str(a.name) for a in all_args)} }};
 const char* schema_str = "{func_schema_str}";
 applySymbolicShapesOnLT(schema_str, inputs, shapes);
 }}
```

```diff
@@ -53,7 +53,7 @@ def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list
 return [
 f"""\
 struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
-void impl({', '.join(a.decl() for a in out_args)});
+void impl({", ".join(a.decl() for a in out_args)});
 }};
 """
 ]
```

```diff
@@ -332,7 +332,7 @@ class RegisterDispatchKey:
 f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
 for i, ret_name in enumerate(return_names)
 )
-returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
+returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
 elif len(return_names) == 1:
 ret_name = return_names[0]
 updates = f"{copy_op}({func_res}, {ret_name});"
@@ -448,7 +448,7 @@ class RegisterDispatchKey:
 def generate_defn(cpp_sig: CppSignature) -> str:
 return f"""
 {cpp_sig.defn()} {{
-return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
 }}
 """
@@ -802,7 +802,7 @@ resize_out(out, sizes, strides, options);
 def generate_defn(cpp_sig: CppSignature) -> str:
 return f"""
 {cpp_sig.defn()} {{
-return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
 }}
 """
@@ -986,12 +986,15 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
 # For an overview of what this template code looks like, see
 # https://github.com/pytorch/rfcs/pull/9
 return f"""\
-{self.gen_class(
-f, k,
-class_name=class_name,
-parent_class=parent_class,
-generate_super=self.g.out.structured_inherits is not None
-)}
+{
+self.gen_class(
+f,
+k,
+class_name=class_name,
+parent_class=parent_class,
+generate_super=self.g.out.structured_inherits is not None,
+)
+}
 {sig.defn()} {{
 {sig_body_str}
```

```diff
@@ -477,15 +477,15 @@ def compute_ufunc_cpu_dtype_body(
 return f"""
 {body_str}
 cpu_kernel_vec(iter,
-[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
-[=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
+[=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
+[=]({", ".join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
 );
 """
 else:
 return f"""
 {body_str}
 cpu_kernel(iter,
-[=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
+[=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
 );
 """
```

```diff
@@ -499,7 +499,7 @@ def generate_static_dispatch_fallback_call(
 return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
 else:
 return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
-{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
+{", ".join([str(index.dispatch_key) for index in backend_indices])} ");"""
 def static_dispatch(
@@ -552,7 +552,7 @@ def static_dispatch(
 )
 if tensor_args != "":
 subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
-stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+stmts.append(f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""")
 stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")
 dispatch_code = []
@@ -1016,7 +1016,7 @@ C10_ALWAYS_INLINE
 {sig.defn(name)} {{
 {compute_dk}
 return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
-_dk, {', '.join(a.expr for a in dispatcher_exprs)});
+_dk, {", ".join(a.expr for a in dispatcher_exprs)});
 }}
 """
 elif self.target is Target.REGISTRATION:
```

```diff
@@ -299,7 +299,7 @@ def gen_declaration_and_definition(
 {declaration} {{
 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
 {tmp_result}{backend_call}(
-{textwrap.indent(', '.join(callsite_exprs), " ")}
+{textwrap.indent(", ".join(callsite_exprs), " ")}
 );{textwrap.indent(ret_assignments_str, " ")}
 }});
 }}
```

```diff
@@ -119,10 +119,10 @@ def parse_backend_yaml(
 # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
 yaml_values.pop("ir_gen", {})
-assert (
-len(yaml_values.keys()) == 0
-), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
-Only the following keys are supported: {", ".join(valid_keys)}'
+assert len(yaml_values.keys()) == 0, (
+f"{backend_yaml_path} contains unexpected keys: {', '.join(yaml_values.keys())}. "
+f"Only the following keys are supported: {', '.join(valid_keys)}"
+)
 def create_backend_index(
 backend_ops: list[str],
```

```diff
@@ -280,7 +280,7 @@ class ComputeCodegenUnboxedKernels:
 [
 f"""
 Kernel(
-"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''}
+"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
 []({contextArg.defn()}, EValue** stack) {{
 {code_connector.join(code_list)}
```

```diff
@@ -407,7 +407,7 @@ def emit_view_functionalization_body(
 // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
 {unwrap_tensor_args_str}
 at::AutoDispatchSkipFunctionalize guard;
-return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }}
 auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
 auto inverse_return_mode = (
@@ -436,7 +436,7 @@ def emit_view_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
+reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
 }}
 // This function adds the above view meta to the current tensor and replays them off the base,
 // mutating the size/stride info of the current FunctionalTensorWrapper.
@@ -462,7 +462,7 @@ def emit_view_functionalization_body(
 if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
 // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
 at::AutoDispatchSkipFunctionalize guard;
-return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }}
 auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
 auto inverse_return_mode = (
@@ -477,15 +477,15 @@ def emit_view_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
+reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
 }}
 {return_type} tmp_output;
 {{
 at::AutoDispatchSkipFunctionalize guard;
 if (reapply_views) {{
-tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
+tmp_output = at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
 }} else {{
-tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
+tmp_output = at::_ops::{api_name}::call({", ".join(view_redispatch_args)});
 }}
 }}
 {symbolic_inputs_check}
@@ -502,7 +502,7 @@ def emit_view_functionalization_body(
 }},
 /*has_symbolic_inputs=*/{symbolic_inputs_varname},
 /*is_multi_output=*/{str(is_multi_output_view).lower()},
-/*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
+/*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
 );
 auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
 // See Note [Propagating strides in the functionalization pass]
@@ -686,7 +686,7 @@ def emit_inplace_functionalization_body(
 [
 f"""
 at::functionalization::impl::replace_(
-{a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
+{a.name}, {"std::get<" + str(i) + ">(tmp_output)" if len(f.func.returns) > 1 else "tmp_output"});
 at::functionalization::impl::commit_update({a.name});"""
 for (i, a) in enumerate(f.func.arguments.out)
 if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
@@ -722,7 +722,7 @@ def emit_inplace_functionalization_body(
 {meta_conversion_str}
 at::AutoDispatchSkipFunctionalize func_guard;
 c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
-at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
+at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(a.name for a in meta_call_ctx)});
 }}
 {unwrap_tensor_args_str}
 if (!({check_all_mutated_args_are_functional})) {{
@@ -736,16 +736,16 @@ def emit_inplace_functionalization_body(
 }} else {{
 // case 2: arguments are not functional tensors, so we no-op and redispatch.
 at::AutoDispatchSkipFunctionalize guard;
-{maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
-{return_from_mutable_noop_redispatch(f, 'tmp_output')}
+{maybe_create_output(f, "tmp_output")}at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(inplace_exprs)});
+{return_from_mutable_noop_redispatch(f, "tmp_output")}
 }}
 }} else {{
 {return_type} tmp_output;
 {{
 at::AutoDispatchSkipFunctionalize guard;
-tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
+tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({", ".join(functional_exprs)});
 }}
-{wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
+{wrap_propagate_mutations_and_return(f, g.functional, "tmp_output")}
 }}
 }}"""
```

```diff
@@ -97,7 +97,7 @@ def gen_case_where_all_bdims_are_none(
 e.expr for e in translate(outer_sig.arguments(), sig.arguments())
 )
 return f"""\
-if ({' && '.join(conditions)}) {{
+if ({" && ".join(conditions)}) {{
 return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
 }}"""
@@ -124,7 +124,7 @@ def gen_returns(
 if len(wrapped_returns) == 1:
 result = f"return {wrapped_returns[0]};"
 else:
-result = f'return std::make_tuple({", ".join(wrapped_returns)});'
+result = f"return std::make_tuple({', '.join(wrapped_returns)});"
 return result
@@ -168,14 +168,14 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-batch_rule({', '.join(unwrapped_arg_list)});
+batch_rule({", ".join(unwrapped_arg_list)});
 return {schema.arguments.flat_all[0].name};
 }}"""
@@ -190,14 +190,14 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-batch_rule({', '.join(unwrapped_arg_list)});
+batch_rule({", ".join(unwrapped_arg_list)});
 }}"""
@@ -240,14 +240,14 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
 wrapped_returns = gen_returns(returns, cur_level_var, results_var)
 return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
-{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
 int64_t {cur_level_var} = maybe_layer->layerId();
 {textwrap.indent(bdims_all_none_case, " ")}
 {textwrap.indent(unwraps, " ")}
-auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
+auto {results_var} = batch_rule({", ".join(unwrapped_arg_list)});
 {wrapped_returns}
 }}"""
```

```diff
@@ -1822,7 +1822,7 @@ class Annotation:
 alias_set = f"{alias_set}!"
 alias_set_after = "|".join(self.alias_set_after)
 if alias_set_after:
-alias_set = f'{alias_set}{" -> "}{alias_set_after}'
+alias_set = f"{alias_set} -> {alias_set_after}"
 return alias_set
```

```diff
@@ -534,7 +534,7 @@ def generate_non_out_variant_call(
 kernel_name = get_kernel_name(g, backend_index)
 arg_names = (arg.name for arg in schema.schema_order_arguments())
 namespace_name = "cpu" if g.structured else "native"
-return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
+return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
 def generate_call_to_view_ops(
@@ -547,7 +547,7 @@ def generate_call_to_view_ops(
 kernel_name = kernel.kernel
 arg_names = (arg.name for arg in schema.schema_order_arguments())
 namespace_name = "native"
-return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
+return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
 def generate_out_variant_call(
```