mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/55868. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57458 Reviewed By: mruberry Differential Revision: D28365745 Pulled By: walterddr fbshipit-source-id: 35cc3fa85f87b0ef98cf970f620ab909d240c7be
94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
from typing import Any, List, Union
|
|
|
|
try:
|
|
from junitparser import JUnitXml, TestSuite, TestCase, Error, Failure # type: ignore[import]
|
|
except ImportError:
|
|
raise ImportError(
|
|
"junitparser not found, please install with 'pip install junitparser'"
|
|
)
|
|
|
|
try:
    import rich  # type: ignore[import]
except ImportError:
    print("rich not found, for color output use 'pip install rich'")
    # rich is optional: fall back to a stand-in whose ``print`` is the builtin,
    # so later ``rich.print(...)`` calls emit their text (markup tags shown
    # verbatim) instead of failing with NameError.
    from types import SimpleNamespace

    rich = SimpleNamespace(print=print)  # type: ignore[assignment]
|
|
|
|
def parse_junit_reports(path_to_reports: str) -> List[TestCase]:
    """Collect the test cases from a JUnit XML report file or directory tree.

    Args:
        path_to_reports: Path to a single report file, or to a directory that
            is walked recursively for files with a ``.xml`` extension.

    Returns:
        A flat list of the test cases found. A report that cannot be read
        contributes nothing; a warning is printed for it instead.

    Raises:
        FileNotFoundError: If ``path_to_reports`` does not exist.
    """

    def parse_file(path: str) -> List[TestCase]:
        # Best effort: one unreadable report must not abort the whole run.
        try:
            return convert_junit_to_testcases(JUnitXml.fromfile(path))
        except Exception as err:
            rich.print(f":Warning: [yellow]Warning[/yellow]: Failed to read {path}: {err}")
            return []

    if not os.path.exists(path_to_reports):
        raise FileNotFoundError(f"Path '{path_to_reports}', not found")
    # Return early if the path provided is just a file
    if os.path.isfile(path_to_reports):
        return parse_file(path_to_reports)
    ret_xml: List[TestCase] = []
    if os.path.isdir(path_to_reports):
        for root, _, files in os.walk(path_to_reports):
            for fname in files:
                # Match the extension itself; a bare "xml" suffix would also
                # pick up unrelated names such as "notxml".
                if fname.endswith(".xml"):
                    ret_xml += parse_file(os.path.join(root, fname))
    return ret_xml
|
|
|
|
|
|
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]:
    """Flatten a parsed report into a list of its test cases.

    Iterates over ``xml``; every ``TestSuite`` item is expanded recursively in
    place, and every other item is kept as-is, preserving iteration order.
    """
    collected: List[TestCase] = []
    for node in xml:
        if not isinstance(node, TestSuite):
            collected.append(node)
            continue
        # Nested suite: splice its flattened contents in at this position.
        collected += convert_junit_to_testcases(node)
    return collected
|
|
|
|
def render_tests(testcases: List[TestCase]) -> None:
    """Print every failed or errored test case, then a summary of the counts.

    A test case with no results counts as passed. Otherwise each of its
    results is tallied on its own: ``Error`` and ``Failure`` results are
    printed together with their text and counted as failed; any other result
    is counted as skipped.

    Args:
        testcases: The test cases to report on.
    """
    num_passed = 0
    num_skipped = 0
    num_failed = 0
    for testcase in testcases:
        if not testcase.result:
            num_passed += 1
            continue
        for result in testcase.result:
            if isinstance(result, Error):
                icon = ":rotating_light: [white on red]ERROR[/white on red]:"
                num_failed += 1
            elif isinstance(result, Failure):
                icon = ":x: [white on red]Failure[/white on red]:"
                num_failed += 1
            else:
                num_skipped += 1
                continue
            rich.print(f"{icon} [bold red]{testcase.classname}.{testcase.name}[/bold red]")
            # Builtin print: the result text is written without markup processing.
            print(f"{result.text}")
    # Closing tags take the "[/name]" form; repeating "[name]" would open a
    # second span rather than end the first.
    rich.print(f":white_check_mark: {num_passed} [green]Passed[/green]")
    rich.print(f":dash: {num_skipped} [grey]Skipped[/grey]")
    rich.print(f":rotating_light: {num_failed} [grey]Failed[/grey]")
|
|
|
|
|
|
|
|
def parse_args() -> Any:
    """Parse the command line and return the resulting namespace.

    The only argument defined is the positional ``report_path``.
    """
    arg_parser = argparse.ArgumentParser(description="Render xunit output for failed tests")
    arg_parser.add_argument(
        "report_path", help="Base xunit reports (single file or directory) to compare to"
    )
    parsed = arg_parser.parse_args()
    return parsed
|
|
|
|
|
|
def main() -> None:
    """Script entry point: read the CLI arguments, load the reports, render them."""
    report_path = parse_args().report_path
    render_tests(parse_junit_reports(report_path))
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|