mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR enables `PLW0127` in ruff, which checks self-assignment of variables with the form `var=var`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165851 Approved by: https://github.com/Lucaskabela
145 lines
3.6 KiB
Python
145 lines
3.6 KiB
Python
import re
|
|
|
|
import torch
|
|
|
|
|
|
"""
|
|
Instructions:
|
|
|
|
1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
|
|
2. python test/xfail_suggester.py
|
|
"""
|
|
|
|
with open("result.txt") as f:
|
|
lines = f.readlines()
|
|
|
|
failed = [line for line in lines if line.startswith("FAILED")]
|
|
p = re.compile("FAILED test/test_\w+.py::\w+::(\S+)") # noqa: W605
|
|
|
|
|
|
def get_failed_test(line):
    """Extract the failing test's name from one line of the pytest report.

    Args:
        line: A single line of the report.

    Returns:
        The name captured by the module-level pattern ``p``, or ``None`` when
        the start of ``line`` does not match ``p``.
    """
    match = p.match(line)
    return match.group(1) if match else None
|
|
|
|
|
|
# Prefixes of the parametrized test families whose failures are grouped
# together. Some prefixes extend others (e.g. "test_vmapvjp_has_batch_rule_"
# extends "test_vmapvjp_"); belongs_to_base assigns a test to the longest
# matching prefix. Grouping below is by name only and has no runtime effect.
base_names = {
    # grad / vjp-flavoured families
    "test_grad_",
    "test_vjp_",
    "test_vjpvjp_",
    "test_vjpvmap_",
    "test_vmapvjp_",
    "test_vmapvjp_has_batch_rule_",
    # jvp-flavoured families
    "test_jvp_",
    "test_jvpvjp_",
    "test_vmapjvp_",
    "test_vmapjvpall_",
    "test_vmapjvpall_has_batch_rule_",
    # vmap / batch-rule families
    "test_vmap_exhaustive_",
    "test_op_has_batch_rule_",
    "test_vmap_autograd_grad_",
    # decomposition / make_fx families
    "test_decomposition_",
    "test_make_fx_exhaustive_",
}
|
|
|
|
# Parse every report line, keep only the ones that yielded a test name, and
# order the names alphabetically.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)

suggested_xfails = {}
|
|
|
|
|
|
def remove_device_dtype(test):
    """Drop the last two underscore-separated fields of ``test``.

    Used to strip a ``_<device>_<dtype>`` suffix, e.g.
    ``"add_cpu_float32"`` -> ``"add"``. A name with fewer than three fields
    becomes ``""``.
    """
    parts = test.split("_")
    return "_".join(parts[: len(parts) - 2])
|
|
|
|
|
|
def belongs_to_base(test, base):
    """Decide whether ``test`` is in family ``base`` rather than a longer one.

    ``test`` belongs to ``base`` when it starts with ``base`` and does not also
    start with any entry of ``base_names`` that is longer than ``base``.
    """
    if not test.startswith(base):
        return False
    longer_bases = (other for other in base_names if len(other) > len(base))
    return not any(test.startswith(other) for other in longer_bases)
|
|
|
|
|
|
def parse_namespace(base):
    """Split a known torch-namespace prefix off the front of ``base``.

    Returns:
        ``(namespace, rest)``: ``namespace`` is the dotted namespace for the
        first matching prefix (e.g. ``"nn.functional"``), or ``None`` if no
        prefix matches; ``rest`` is ``base`` with that prefix removed.
    """
    prefixes = (
        ("nn_functional_", "nn.functional"),
        ("fft_", "fft"),
        ("linalg_", "linalg"),
        ("_masked_", "_masked"),
        ("sparse_", "sparse"),
        ("special_", "special"),
    )
    for prefix, namespace in prefixes:
        if base.startswith(prefix):
            return namespace, base[len(prefix) :]
    return None, base
|
|
|
|
|
|
def get_torch_module(namespace):
    """Resolve a namespace relative to ``torch``.

    Args:
        namespace: ``None`` for ``torch`` itself, or a dotted path relative to
            ``torch`` such as ``"fft"`` or ``"nn.functional"``.

    Returns:
        The object reached by following each dotted component from ``torch``.

    Raises:
        AttributeError: if a component of ``namespace`` does not exist.
    """
    module = torch
    if namespace is None:
        return module
    # Walk each dotted component so nested namespaces ("nn.functional",
    # "nn.init", ...) resolve without a special case per namespace.
    for part in namespace.split("."):
        module = getattr(module, part)
    return module
|
|
|
|
|
|
def parse_base(base):
    """Split a test-name stem into ``(namespace, api, variant)``.

    ``base`` has the form ``[<namespace prefix>]<api>[_<variant>]``. After the
    namespace prefix is removed (see ``parse_namespace``), the api is the
    longest attribute name of the matching torch module that either equals
    the remainder or is followed by ``"_"`` in it; everything after that
    underscore is the variant. If no attribute matches, the whole remainder is
    the api and the variant is ``""``.

    Also prints ``base, namespace, api, variant`` as a debugging aid.
    """
    namespace, rest = parse_namespace(base)

    # Longest names first so e.g. "addmm" is preferred over "add".
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)

    api = rest
    variant = ""
    for candidate in apis:
        # Require a whole-name match: the candidate must be all of ``rest`` or
        # be followed by "_". A bare prefix test (then skipping one character)
        # could split e.g. "new_zeros" into api "ne" and a garbage variant.
        if rest == candidate or rest.startswith(candidate + "_"):
            api = candidate
            variant = rest[len(candidate) + 1 :]
            break
    print(base, namespace, api, variant)
    return namespace, api, variant
|
|
|
|
|
|
def any_starts_with(strs, thing):
    """Return True if at least one string in ``strs`` starts with ``thing``."""
    return any(s.startswith(thing) for s in strs)
|
|
|
|
|
|
def get_suggested_xfails(base, tests):
    """Build ``xfail``/``skip`` suggestion lines for one test family.

    Args:
        base: Test-family prefix from ``base_names`` (e.g. ``"test_vjp_"``).
        tests: Iterable of failed test names.

    Returns:
        A list of source lines, one per failing op stem, in sorted stem order:

        * ``xfail('<api>', '<variant>'),`` if both the ``_cpu_float32`` and
          ``_cuda_float32`` variants failed;
        * the same with ``device_type='cpu'`` / ``device_type='cuda'`` if only
          one of them failed;
        * ``skip('<api>', '<variant>'),`` otherwise.
    """
    # Keep only tests in this family (not a longer family sharing the prefix)
    # and strip the family prefix from each name.
    suffixes = {test[len(base) :] for test in tests if belongs_to_base(test, base)}

    result = []
    # Sorted so the suggestions come out in the same order on every run;
    # iterating a set of strings does not guarantee that.
    for stem in sorted({remove_device_dtype(name) for name in suffixes}):
        cpu_failed = (stem + "_cpu_float32") in suffixes
        cuda_failed = (stem + "_cuda_float32") in suffixes

        namespace, api, variant = parse_base(stem)
        if namespace is not None:
            api = f"{namespace}.{api}"

        if cpu_failed and cuda_failed:
            result.append(f"xfail('{api}', '{variant}'),")
        elif cpu_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif cuda_failed:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # Neither float32 variant failed, only other device/dtype
            # combinations did. Emit a complete, balanced skip(...) entry.
            result.append(f"skip('{api}', '{variant}'),")
    return result
|
|
|
|
|
|
# Emit one section per test family. Families are visited in sorted order so
# the report layout is identical on every run; iterating the ``base_names``
# set directly would not guarantee that.
result = {
    family: get_suggested_xfails(family, failed_tests)
    for family in sorted(base_names)
}

for family, suggestions in result.items():
    print("=" * 50)
    print(family)
    print("=" * 50)
    print("\n".join(suggestions))
|