Apply ruff fixes to tests (#146140)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146140
Approved by: https://github.com/albanD
This commit is contained in:
cyy 2025-02-04 05:40:59 +00:00 committed by PyTorch MergeBot
parent 71e3575525
commit 1c16cf70c3
6 changed files with 14 additions and 16 deletions

View File

@ -135,8 +135,7 @@ class _NodeReporterReruns(_NodeReporter):
else:
assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
skipreason = skipreason.removeprefix("Skipped: ")
details = f"{filename}:{lineno}: {skipreason}"
skipped = ET.Element(

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, List, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@ -81,7 +81,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
# look up dL_dentries. If a variable is never used to compute the loss,
# we consider its gradient None, see the note below about zeros for more information.
def gather_grad(entries: List[str]):
def gather_grad(entries: list[str]):
return [dL_d[entry] if entry in dL_d else None for entry in entries]
# propagate the gradient information backward
@ -127,7 +127,7 @@ def operator_mul(self: Variable, rhs: Variable) -> Variable:
outputs = [r.name]
# define backprop
def propagate(dL_doutputs: List[Variable]):
def propagate(dL_doutputs: list[Variable]):
(dL_dr,) = dL_doutputs
dr_dself = rhs # partial derivative of r = self*rhs
@ -150,7 +150,7 @@ def operator_add(self: Variable, rhs: Variable) -> Variable:
r = Variable(self.value + rhs.value)
# print(f'{r.name} = {self.name} + {rhs.name}')
def propagate(dL_doutputs: List[Variable]):
def propagate(dL_doutputs: list[Variable]):
(dL_dr,) = dL_doutputs
dr_dself = 1.0
dr_drhs = 1.0
@ -168,7 +168,7 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
r = Variable(torch.sum(self.value), name=name)
# print(f'{r.name} = {self.name}.sum()')
def propagate(dL_doutputs: List[Variable]):
def propagate(dL_doutputs: list[Variable]):
(dL_dr,) = dL_doutputs
size = self.value.size()
return [dL_dr.expand(*size)]
@ -179,12 +179,12 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
return r
def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
def operator_expand(self: Variable, sizes: list[int]) -> "Variable":
assert self.value.dim() == 0 # only works for scalars
r = Variable(self.value.expand(sizes))
# print(f'{r.name} = {self.name}.expand({sizes})')
def propagate(dL_doutputs: List[Variable]):
def propagate(dL_doutputs: list[Variable]):
(dL_dr,) = dL_doutputs
return [dL_dr.sum()]

View File

@ -44,7 +44,7 @@ def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"):
if line.startswith(".. autofunction::")
]
lines = api_lines1 + api_lines2
lines = [line[7:] if line.startswith("Tensor.") else line for line in lines]
lines = [line.removeprefix("Tensor.") for line in lines]
lines = [line for line in lines if hasattr(module, line)]
for line in lines:
api = getattr(module, line)

View File

@ -6,7 +6,7 @@ import dataclasses
import io
import logging
import typing
from typing import AbstractSet, Protocol
from typing import Protocol
import torch
from torch.onnx import errors
@ -27,7 +27,7 @@ class _SarifLogBuilder(Protocol):
def _assert_has_diagnostics(
sarif_log_builder: _SarifLogBuilder,
rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
):
sarif_log = sarif_log_builder.sarif_log()
unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
@ -62,7 +62,7 @@ class _RuleCollectionForTest(infra.RuleCollection):
def assert_all_diagnostics(
test_suite: unittest.TestCase,
sarif_log_builder: _SarifLogBuilder,
rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
):
"""Context manager to assert that all diagnostics are emitted.

View File

@ -12,7 +12,7 @@ import tempfile
import typing
import unittest
from types import BuiltinFunctionType
from typing import Callable, NamedTuple, Optional, Union, List
from typing import Callable, List, NamedTuple, Optional, Union
import torch
import torch.fx.experimental.meta_tracer

View File

@ -706,8 +706,7 @@ def generate_tensor_like_override_tests(cls):
for arg in annotated_args[func]:
# Guess valid input to aten function based on type of argument
t = arg["simple_type"]
if t.endswith("?"):
t = t[:-1]
t = t.removesuffix("?")
if t == "Tensor" and is_method and arg["name"] == "self":
# See "Note: properties and __get__"
func = func.__get__(instance_gen())