pytorch/scripts/compile_tests/common.py
rzou c30346db0e Check in some torch.compile helper scripts (#117400)
- passrate.py: compute the pass rate
- update_failures.py: update `dynamo_test_failures.py`

Both of these scripts require you to download the test results from CI
locally. Maybe we can automate this more in the future. Checking these
in for now, with no tests :P.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117400
Approved by: https://github.com/voznesenskym
ghstack dependencies: #117391
2024-01-16 17:14:43 +00:00

113 lines
2.8 KiB
Python

import os
import xml.etree.ElementTree as ET
def open_test_results(directory):
    """Parse every ``.xml`` file found anywhere beneath ``directory``.

    Args:
        directory: Root of the directory tree to search. It is traversed
            recursively with ``os.walk``.

    Returns:
        A list of parsed ``ElementTree`` objects, one for each file whose
        name ends in ``.xml``, in the order ``os.walk`` reports them.
        Files with any other suffix are ignored.

    Note:
        Parse errors are not caught here; a malformed ``.xml`` file makes
        ``ET.parse`` raise and aborts the whole scan.
    """
    return [
        ET.parse(f"{dirpath}/{filename}")
        for dirpath, _, filenames in os.walk(directory)
        for filename in filenames
        if filename.endswith(".xml")
    ]
def get_testcases(xmls):
    """Collect every ``<testcase>`` element from a set of parsed XML trees.

    Args:
        xmls: Iterable of parsed trees (objects exposing ``getroot()``),
            e.g. the list returned by ``open_test_results``.

    Returns:
        A flat list of all elements tagged ``testcase``, at any depth,
        keeping the order of ``xmls`` and document order within each tree.
    """
    return [
        testcase
        for tree in xmls
        for testcase in tree.getroot().iter("testcase")
    ]
def find(testcase, condition):
    """Evaluate ``condition`` against everything nested inside ``testcase``.

    Args:
        testcase: An XML element. Its ``iter()`` is expected to yield the
            element itself first, followed by its descendants.
        condition: Callable that receives a list of the elements nested
            under ``testcase`` (all descendants in document order, not only
            direct children; ``testcase`` itself is excluded).

    Returns:
        Whatever ``condition`` returns for that list.
    """
    # The first item produced by iter() is the element itself; peel it off
    # so the callback only ever sees what is nested underneath.
    head, *descendants = testcase.iter()
    assert head is testcase
    return condition(descendants)
def skipped_test(testcase):
    """Return whether ``testcase`` was reported as skipped.

    Args:
        testcase: A ``<testcase>`` XML element.

    Returns:
        True if any element nested under ``testcase`` has the tag
        ``skipped``; False otherwise.
    """

    def has_skipped_tag(descendants):
        return any(node.tag == "skipped" for node in descendants)

    return find(testcase, has_skipped_tag)
def passed_test(testcase):
    """Return whether ``testcase`` ran and passed.

    A testcase counts as passed when nothing nested under it marks it as
    skipped, failed, or errored. A testcase with no nested elements at all
    is a pass.

    Args:
        testcase: A ``<testcase>`` XML element.

    Returns:
        False if any nested element is tagged ``skipped``, ``failure``,
        ``error`` or ``failed``; True otherwise.
    """
    # JUnit-style reports mark a failing test with <failure> (the tag that
    # is_failure in this file looks for) or <error>. The previous check only
    # looked for "failed", so failing tests were counted as passes. "failed"
    # is kept so any report that does use that tag is still handled.
    not_passed_tags = {"skipped", "failure", "error", "failed"}

    def condition(children):
        return not_passed_tags.isdisjoint(child.tag for child in children)

    return find(testcase, condition)
def key(testcase):
    """Build the ``file::classname::name`` identifier for a testcase.

    Args:
        testcase: A ``<testcase>`` XML element. It must carry ``classname``
            and ``name`` attributes; ``file`` is optional.

    Returns:
        The three attribute values joined with ``"::"``. When the ``file``
        attribute is absent, the literal ``"UNKNOWN"`` is used in its place.

    Raises:
        KeyError: If ``classname`` or ``name`` is missing.
    """
    attrs = testcase.attrib
    parts = (attrs.get("file", "UNKNOWN"), attrs["classname"], attrs["name"])
    return "::".join(parts)
def get_passed_testcases(xmls):
    """Return the testcases from ``xmls`` that passed.

    Args:
        xmls: Iterable of parsed XML trees, as accepted by
            ``get_testcases``.

    Returns:
        A list of the ``<testcase>`` elements for which ``passed_test``
        is truthy, in the order ``get_testcases`` produced them.
    """
    return list(filter(passed_test, get_testcases(xmls)))
def get_excluded_testcases(xmls):
    """Return the testcases from ``xmls`` that were excluded by policy.

    Args:
        xmls: Iterable of parsed XML trees, as accepted by
            ``get_testcases``.

    Returns:
        A list of the ``<testcase>`` elements for which
        ``excluded_testcase`` is truthy, in the order ``get_testcases``
        produced them.
    """
    excluded = []
    for candidate in get_testcases(xmls):
        if excluded_testcase(candidate):
            excluded.append(candidate)
    return excluded
def excluded_testcase(testcase):
    """Return whether ``testcase`` was skipped because of a run policy.

    Args:
        testcase: A ``<testcase>`` XML element.

    Returns:
        True if some nested ``<skipped>`` element has a ``message``
        attribute containing ``"Policy: we don't run"``; False otherwise.
    """

    def condition(children):
        # A <skipped> element without a message attribute cannot match the
        # policy marker; .get() treats it as a non-match instead of raising
        # KeyError.
        return any(
            child.tag == "skipped"
            and "Policy: we don't run" in child.attrib.get("message", "")
            for child in children
        )

    return find(testcase, condition)
def is_unexpected_success(testcase):
    """Return whether ``testcase`` was reported as an unexpected success.

    Args:
        testcase: A ``<testcase>`` XML element.

    Returns:
        True if some nested ``<failure>`` element has a ``message``
        attribute containing ``"unexpected success"`` (case-insensitive);
        False otherwise.
    """

    def condition(children):
        # A <failure> with no message attribute carries no "unexpected
        # success" marker; .get() treats it as a non-match instead of
        # raising KeyError.
        return any(
            child.tag == "failure"
            and "unexpected success" in child.attrib.get("message", "").lower()
            for child in children
        )

    return find(testcase, condition)
# NB: a <failure> whose message says "unexpected success" is deliberately
# NOT counted as a failure here.
def is_failure(testcase):
    """Return whether ``testcase`` has a failure other than an unexpected success.

    Args:
        testcase: A ``<testcase>`` XML element.

    Returns:
        True if some nested ``<failure>`` element's ``message`` attribute
        does not contain ``"unexpected success"`` (case-insensitive);
        False otherwise. Only the ``failure`` tag is inspected.
    """

    def condition(children):
        # A <failure> with no message attribute cannot be an unexpected
        # success, so it counts as a failure; .get() avoids the KeyError
        # that indexing the attribute would raise.
        return any(
            child.tag == "failure"
            and "unexpected success" not in child.attrib.get("message", "").lower()
            for child in children
        )

    return find(testcase, condition)