Apply ruff fixes to tests (#146140)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146140
Approved by: https://github.com/albanD
This commit is contained in:
parent 71e3575525
commit 1c16cf70c3
@@ -135,8 +135,7 @@ class _NodeReporterReruns(_NodeReporter):
         else:
             assert isinstance(report.longrepr, tuple)
             filename, lineno, skipreason = report.longrepr
-            if skipreason.startswith("Skipped: "):
-                skipreason = skipreason[9:]
+            skipreason = skipreason.removeprefix("Skipped: ")
             details = f"{filename}:{lineno}: {skipreason}"
 
             skipped = ET.Element(
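The change above swaps a startswith guard plus a hard-coded slice (9 == len("Skipped: ")) for str.removeprefix, available since Python 3.9. A minimal sketch of the equivalence, not part of the diff, using a made-up skip reason rather than anything from the test suite:

# str.removeprefix is a no-op when the prefix is absent, so no guard is needed.
skipreason = "Skipped: flaky on this platform"

old = skipreason[9:] if skipreason.startswith("Skipped: ") else skipreason
new = skipreason.removeprefix("Skipped: ")

assert old == new == "flaky on this platform"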
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-from typing import Callable, List, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._dynamo
@@ -81,7 +81,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
 
     # look up dL_dentries. If a variable is never used to compute the loss,
     # we consider its gradient None, see the note below about zeros for more information.
-    def gather_grad(entries: List[str]):
+    def gather_grad(entries: list[str]):
         return [dL_d[entry] if entry in dL_d else None for entry in entries]
 
     # propagate the gradient information backward
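The List[str] → list[str] rewrite here (and the List[Variable] rewrites in the hunks below) applies PEP 585: on Python 3.9+ the builtin containers are subscriptable in annotations, which is also why the earlier hunk could drop List from its typing import. A minimal, self-contained sketch, not part of the diff; gather is a hypothetical stand-in for gather_grad:

from typing import Optional

# PEP 585 (Python 3.9+): the builtin list works as a generic annotation,
# so typing.List is no longer required for hints like these.
def gather(entries: list[str]) -> list[Optional[str]]:
    # Illustrative only: keep entries that look like gradient names.
    return [e if e.startswith("dL_") else None for e in entries]

print(gather(["dL_dx", "y"]))  # ['dL_dx', None]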
@@ -127,7 +127,7 @@ def operator_mul(self: Variable, rhs: Variable) -> Variable:
     outputs = [r.name]
 
     # define backprop
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
 
         dr_dself = rhs  # partial derivative of r = self*rhs
@@ -150,7 +150,7 @@ def operator_add(self: Variable, rhs: Variable) -> Variable:
     r = Variable(self.value + rhs.value)
     # print(f'{r.name} = {self.name} + {rhs.name}')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         dr_dself = 1.0
         dr_drhs = 1.0
@@ -168,7 +168,7 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
     r = Variable(torch.sum(self.value), name=name)
     # print(f'{r.name} = {self.name}.sum()')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         size = self.value.size()
         return [dL_dr.expand(*size)]
@@ -179,12 +179,12 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
     return r
 
 
-def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
+def operator_expand(self: Variable, sizes: list[int]) -> "Variable":
     assert self.value.dim() == 0  # only works for scalars
     r = Variable(self.value.expand(sizes))
     # print(f'{r.name} = {self.name}.expand({sizes})')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         return [dL_dr.sum()]
 
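The propagate closures in the hunks above encode the usual adjoint rules: the backward of sum is expand, and the backward of expand is sum. As a sanity check, a short sketch (not part of the diff) reproduces both rules with stock torch.autograd; the shapes are made up for illustration:

import torch

# Backward of sum: the incoming scalar gradient is expanded to the input's shape.
x = torch.ones(3, requires_grad=True)
x.sum().backward()
assert torch.equal(x.grad, torch.ones(3))

# Backward of expand: the incoming gradient is summed back onto the scalar input.
s = torch.tensor(2.0, requires_grad=True)
s.expand(3).sum().backward()
assert s.grad.item() == 3.0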
@@ -44,7 +44,7 @@ def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"):
         if line.startswith(".. autofunction::")
     ]
     lines = api_lines1 + api_lines2
-    lines = [line[7:] if line.startswith("Tensor.") else line for line in lines]
+    lines = [line.removeprefix("Tensor.") for line in lines]
     lines = [line for line in lines if hasattr(module, line)]
     for line in lines:
         api = getattr(module, line)
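In the hunk above, the 7 in line[7:] is simply len("Tensor."); removeprefix names the prefix directly, so there is no magic number to keep in sync. A small sketch, not part of the diff, over made-up autosummary entries:

lines = ["Tensor.add", "Tensor.mul", "cat"]

# Old: slice off len("Tensor.") == 7 characters when the prefix matches.
old = [line[7:] if line.startswith("Tensor.") else line for line in lines]

# New: removeprefix states the prefix and leaves other entries untouched.
new = [line.removeprefix("Tensor.") for line in lines]

assert old == new == ["add", "mul", "cat"]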
@@ -6,7 +6,7 @@ import dataclasses
 import io
 import logging
 import typing
-from typing import AbstractSet, Protocol
+from typing import Protocol
 
 import torch
 from torch.onnx import errors
@@ -27,7 +27,7 @@ class _SarifLogBuilder(Protocol):
 
 def _assert_has_diagnostics(
     sarif_log_builder: _SarifLogBuilder,
-    rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
+    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
 ):
     sarif_log = sarif_log_builder.sarif_log()
     unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
@@ -62,7 +62,7 @@ class _RuleCollectionForTest(infra.RuleCollection):
 def assert_all_diagnostics(
     test_suite: unittest.TestCase,
     sarif_log_builder: _SarifLogBuilder,
-    rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
+    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
 ):
     """Context manager to assert that all diagnostics are emitted.
 
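typing.AbstractSet is a deprecated alias for collections.abc.Set, and ruff's modernization rules steer code away from it. In the two hunks above the annotation is narrowed to the builtin set, which is what these helpers receive in practice; annotations are not enforced at runtime, so behaviour is unchanged. A minimal sketch, not part of the diff, with a made-up pair type, showing both the builtin spelling and the non-deprecated abstract spelling for a wider contract:

from collections.abc import Set as AbstractSet

# Builtin `set` as a generic annotation (PEP 585): callers pass a concrete set.
def count_pairs(rule_level_pairs: set[tuple[str, str]]) -> int:
    return len(rule_level_pairs)

# collections.abc.Set is the modern spelling if any set-like object
# (e.g. a frozenset) should also be accepted.
def count_setlike(rule_level_pairs: AbstractSet[tuple[str, str]]) -> int:
    return len(rule_level_pairs)

pairs = {("rule-1", "warning"), ("rule-2", "error")}
assert count_pairs(pairs) == 2
assert count_setlike(frozenset(pairs)) == 2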
@@ -12,7 +12,7 @@ import tempfile
 import typing
 import unittest
 from types import BuiltinFunctionType
-from typing import Callable, NamedTuple, Optional, Union, List
+from typing import Callable, List, NamedTuple, Optional, Union
 
 import torch
 import torch.fx.experimental.meta_tracer
@@ -706,8 +706,7 @@ def generate_tensor_like_override_tests(cls):
         for arg in annotated_args[func]:
             # Guess valid input to aten function based on type of argument
             t = arg["simple_type"]
-            if t.endswith("?"):
-                t = t[:-1]
+            t = t.removesuffix("?")
             if t == "Tensor" and is_method and arg["name"] == "self":
                 # See "Note: properties and __get__"
                 func = func.__get__(instance_gen())
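removesuffix mirrors removeprefix for the trailing "?": the endswith check plus t[:-1] collapses into a single call that only strips the suffix when it is present. A short sketch, not part of the diff, over made-up simple_type strings:

for t in ("Tensor?", "Tensor", "int[]?"):
    old = t[:-1] if t.endswith("?") else t
    new = t.removesuffix("?")  # no-op when there is no trailing "?"
    assert old == new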