mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This was a Meta-internal change that seems to have deleted a bunch of types and is thus causing us to fail mypy type checking. Reverting that portion of the change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77327 Approved by: https://github.com/qihqi
554 lines
17 KiB
Python
Executable File
554 lines
17 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import asyncio
|
|
import collections
|
|
import csv
|
|
import hashlib
|
|
import itertools
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import shlex
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import Awaitable, DefaultDict, Dict, List, Match, Optional, Set, cast
|
|
|
|
from typing_extensions import TypedDict
|
|
|
|
help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...

Run the commands given by nvcc --dryrun, in parallel.

All flags for this script itself (see the "optional arguments" section
of --help) must be passed before the first "--". Everything after that
first "--" is passed directly to nvcc, with the --dryrun argument added.

This script only works with the "normal" execution path of nvcc, so for
instance passing --help (after "--") doesn't work since the --help
execution path doesn't compile anything, so adding --dryrun there gives
nothing in stderr.
"""
# The first positional parameter of ArgumentParser is ``prog``, so passing
# help_msg positionally would turn the whole text into the program name.
# Instead, the first paragraph of help_msg becomes the usage line and the
# rest becomes the description, printed verbatim by the raw formatter.
_usage_line, _, _description = help_msg.partition("\n\n")
parser = argparse.ArgumentParser(
    usage=_usage_line,
    description=_description,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "--faithful",
    action="store_true",
    help="don't modify the commands given by nvcc (slower)",
)
parser.add_argument(
    "--graph",
    metavar="FILE.gv",
    help="write Graphviz DOT file with execution graph",
)
parser.add_argument(
    "--nvcc",
    metavar="PATH",
    default="nvcc",
    help='path to nvcc (default is just "nvcc")',
)
parser.add_argument(
    "--save",
    metavar="DIR",
    help="copy intermediate files from each command into DIR",
)
parser.add_argument(
    "--sequential",
    action="store_true",
    help="sequence commands instead of using the graph (slower)",
)
parser.add_argument(
    "--table",
    metavar="FILE.csv",
    help="write CSV with times and intermediate file sizes",
)
parser.add_argument(
    "--verbose",
    metavar="FILE.txt",
    help="like nvcc --verbose, but expanded and into a file",
)
# configuration obtained when no flags at all are passed; used as the
# default ``config`` by the entry points further down
default_config = parser.parse_args([])
|
|
|
|
|
|
# NVCC documentation; the fragment below covers the temporary
# directories that NVCC uses
url_base = "https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html"
url_vars = url_base + "#keeping-intermediate-phase-files"

# regex for temporary file names; an optional leading /tmp/ is matched but
# kept outside the single capture group
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
|
|
|
|
|
|
def fast_nvcc_warn(warning: str) -> None:
    """
    Print warning on stderr, prefixed so it is attributable to fast_nvcc.
    """
    message = f"warning (fast_nvcc): {warning}"
    print(message, file=sys.stderr)
|
|
|
|
|
|
def warn_if_windows() -> None:
    """
    Emit a warning, followed by a docs URL, when running on Windows.
    """
    # os.name is checked rather than platform.system() because a
    # platform.py file lives in this directory, which makes it very
    # difficult to import the standard library's platform module
    if os.name != "nt":
        return
    fast_nvcc_warn("untested on Windows, might not work; see this URL:")
    fast_nvcc_warn(url_vars)
|
|
|
|
|
|
def warn_if_tmpdir_flag(args: List[str]) -> None:
    """
    Warn about every nvcc flag in args that is known to interact with TMPDIR.
    """
    file_path_specs = "file-and-path-specifications"
    guiding_driver = "options-for-guiding-compiler-driver"
    # flag -> fragment of the nvcc docs page that is linked in the warning
    scary_flags = {
        "--objdir-as-tempdir": file_path_specs,
        "-objtemp": file_path_specs,
        "--keep": guiding_driver,
        "-keep": guiding_driver,
        "--keep-dir": guiding_driver,
        "-keep-dir": guiding_driver,
        "--save-temps": guiding_driver,
        "-save-temps": guiding_driver,
    }
    for arg in args:
        for flag, frag in scary_flags.items():
            # a flag counts both bare and in its "flag=value" form
            if not re.match(rf"^{re.escape(flag)}(?:=.*)?$", arg):
                continue
            fast_nvcc_warn(f"{flag} not supported since it interacts with")
            fast_nvcc_warn("TMPDIR, so fast_nvcc may break; see this URL:")
            fast_nvcc_warn(f"{url_base}#{frag}")
|
|
|
|
|
|
class DryunData(TypedDict):
    """
    Parsed result of one nvcc --dryrun invocation.
    """

    # NAME=value lines reported by the dry run
    env: Dict[str, str]
    # every other dry-run line, in order, one command per entry
    commands: List[str]
    # return code of the nvcc --dryrun process itself
    exit_code: int
|
|
|
|
|
|
def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
    """
    Run binary with --dryrun and args, and parse what it reports.

    The captured stdout is echoed to stdout.  Each stderr line that starts
    with "#$ " is recorded either as an environment variable (if it looks
    like NAME=value) or as a command; every other stderr line is passed
    through to stderr unchanged.
    """
    completed = subprocess.run(  # type: ignore[call-overload]
        [binary, "--dryrun", *args],
        capture_output=True,
        encoding="ascii",  # this is just a guess
    )
    print(completed.stdout, end="")
    env = {}
    commands = []
    for line in completed.stderr.splitlines():
        marked = re.match(r"^#\$ (.*)$", line)
        if marked is None:
            print(line, file=sys.stderr)
            continue
        stripped = marked.group(1)
        assignment = re.match(r"^(\w+)=(.*)$", stripped)
        if assignment is None:
            commands.append(stripped)
        else:
            env[assignment.group(1)] = assignment.group(2)
    return {"env": env, "commands": commands, "exit_code": completed.returncode}
|
|
|
|
|
|
def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
    """
    Warn when TMPDIR is set, either for this process or in env.
    """
    set_for_process = bool(os.getenv("TMPDIR"))
    if not (set_for_process or "TMPDIR" in env):
        return
    fast_nvcc_warn("TMPDIR is set, might not work; see this URL:")
    fast_nvcc_warn(url_vars)
|
|
|
|
|
|
def contains_non_executable(commands: List[str]) -> bool:
    """
    Tell whether any entry of commands starts with "--".
    """
    # This is to deal with special command dry-run result from NVCC such as:
    # ```
    # #$ "/lib64/ccache"/c++ -std=c++11 -E -x c++ -D__CUDACC__ -D__NVCC__ -fPIC -fvisibility=hidden -O3 \
    #   -I ... -m64 "reduce_scatter.cu" > "/tmp/tmpxft_0037fae3_00000000-5_reduce_scatter.cpp4.ii
    # #$ -- Filter Dependencies -- > ... pytorch/build/nccl/obj/collectives/device/reduce_scatter.dep.tmp
    # ```
    return any(command.startswith("--") for command in commands)
|
|
|
|
|
|
def module_id_contents(command: List[str]) -> str:
    """
    Guess the contents of the .module_id file contained within command.

    command is a tokenized cicc or cudafe++ invocation; the tool is
    recognized by the basename of command[0], so both a bare name and a
    full path to the binary work.  The guess is derived from the name of
    the file found at a tool-specific position in the argument list.

    Raises ValueError if command runs any other tool.
    """
    tool = pathlib.PurePath(command[0]).name
    if tool == "cicc":
        path = command[-3]
    elif tool == "cudafe++":
        path = command[-1]
    else:
        # without this branch, path would be unbound below
        raise ValueError(
            f"cannot guess module ID for command {command[0]!r}: "
            "expected cicc or cudafe++"
        )
    middle = pathlib.PurePath(path).name.replace("-", "_").replace(".", "_")
    # this suffix is very wrong (the real one is far less likely to be
    # unique), but it seems difficult to find a rule that reproduces the
    # real suffixes, so here's one that, while inaccurate, is at least
    # hopefully as straightforward as possible
    suffix = hashlib.md5(str.encode(middle)).hexdigest()[:8]
    return f"_{len(middle)}_{middle}_{suffix}"
|
|
|
|
|
|
def unique_module_id_files(commands: List[str]) -> List[str]:
    """
    Give each command its own .module_id filename instead of sharing.

    In every command that mentions a tmp .module_id file, the filename is
    rewritten to include the command's index, and an echo command that
    writes the guessed module ID into the renamed file is inserted just
    before it.  The module ID is guessed once, from the first such
    command, and reused afterwards.  The --gen_module_id_file flag is
    stripped from every command.
    """
    module_id: Optional[str] = None
    uniqueified: List[str] = []
    for i, line in enumerate(commands):
        arr: List[str] = []

        def uniqueify(s: Match[str]) -> str:
            # insert this command's index after each "-<digits>" in the name
            filename = re.sub(r"\-(\d+)", r"-\1-" + str(i), s.group(0))
            arr.append(filename)
            return filename

        # the dot is escaped so that only a literal ".module_id" suffix
        # matches, rather than any character followed by "module_id"
        line = re.sub(re_tmp + r"\.module_id", uniqueify, line)
        line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
        if arr:
            # exactly one .module_id file is expected per command
            (filename,) = arr
            if not module_id:
                module_id = module_id_contents(shlex.split(line))
            uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
        uniqueified.append(line)
    return uniqueified
|
|
|
|
|
|
def make_rm_force(commands: List[str]) -> List[str]:
    """
    Return commands with --force appended to each one that starts with "rm ".
    """
    forced = []
    for command in commands:
        if command.startswith("rm "):
            command = f"{command} --force"
        forced.append(command)
    return forced
|
|
|
|
|
|
def print_verbose_output(
    *,
    env: Dict[str, str],
    commands: List[List[str]],
    filename: str,
) -> None:
    """
    Human-readably write nvcc --dryrun data to the file at filename.

    Environment variables come first, one "NAME=value" per line.  Each
    command follows, with its index and program on one line and every
    further argument on a line of its own, aligned under the program.
    The file is overwritten if it already exists.
    """
    # width of the largest command index; clamped so that an empty command
    # list pads like a single command instead of measuring "-1"
    padding = len(str(max(len(commands) - 1, 0)))
    with open(filename, "w") as f:
        for name, val in env.items():
            print(f'#{" "*padding}$ {name}={val}', file=f)
        for i, command in enumerate(commands):
            prefix = f"{str(i).rjust(padding)}$ "
            print(f"#{prefix}{command[0]}", file=f)
            for part in command[1:]:
                print(f'#{" "*len(prefix)}{part}', file=f)
|
|
|
|
|
|
# dependency graph over commands: entry i is the set of indices of the
# commands that command i depends on
Graph = List[Set[int]]
|
|
|
|
|
|
def straight_line_dependencies(commands: List[str]) -> Graph:
    """
    Build a graph in which every command depends only on the one before it.
    """
    chain: Graph = []
    for index, _ in enumerate(commands):
        # the first command has nothing to wait for
        chain.append({index - 1} if index > 0 else set())
    return chain
|
|
|
|
|
|
def files_mentioned(command: str) -> List[str]:
    """
    List every tmp file referenced by command, as a path under /tmp/.
    """
    # re_tmp has a single capture group, so findall yields the bare names
    names = re.findall(re_tmp, command)
    return [f"/tmp/{name}" for name in names]
|
|
|
|
|
|
def nvcc_data_dependencies(commands: List[str]) -> Graph:
    """
    Return a list of the set of dependencies for each command.

    Entry i holds the indices of the earlier commands that command i must
    wait for, inferred from the tmp files each command mentions.  A
    command never depends on itself, and an rm command with no inferred
    dependency waits for the command just before it, if there is one.
    """
    # fatbin needs to be treated specially because while the cicc steps
    # do refer to .fatbin.c files, they do so through the
    # --include_file_name option, since they're generating files that
    # refer to .fatbin.c file(s) that will later be created by the
    # fatbinary step; so for most files, we make a data dependency from
    # the later step to the earlier step, but for .fatbin.c files, the
    # data dependency is sort of flipped, because the steps that use the
    # files generated by cicc need to wait for the fatbinary step to
    # finish first
    tmp_files: Dict[str, int] = {}
    fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
    graph: Graph = []
    for i, line in enumerate(commands):
        deps: Set[int] = set()
        for tmp in files_mentioned(line):
            if tmp in tmp_files:
                dep = tmp_files[tmp]
                deps.add(dep)
                if dep in fatbins:
                    for filename in fatbins[dep]:
                        if filename in tmp_files:
                            deps.add(tmp_files[filename])
            if tmp.endswith(".fatbin.c") and not line.startswith("fatbinary"):
                fatbins[i].add(tmp)
            else:
                tmp_files[tmp] = i
        # a file mentioned twice by the same command is recorded under i on
        # its first mention, which would make the command wait for itself
        deps.discard(i)
        # the very first command has no predecessor to fall back on; adding
        # i - 1 there would create a dependency on index -1
        if line.startswith("rm ") and not deps and i > 0:
            deps.add(i - 1)
        graph.append(deps)
    return graph
|
|
|
|
|
|
def is_weakly_connected(graph: Graph) -> bool:
    """
    Tell whether graph is weakly connected (an empty graph counts as such).
    """
    if not graph:
        return True
    # undirected view of the graph: every edge is recorded on both ends
    neighbors: List[Set[int]] = [set() for _ in graph]
    for node, predecessors in enumerate(graph):
        for pred in predecessors:
            neighbors[pred].add(node)
            neighbors[node].add(pred)
    # flood-fill from node 0, which exists because the graph is nonempty
    reached = {0}
    pending = [0]
    while pending:
        unseen = neighbors[pending.pop()] - reached
        reached |= unseen
        pending.extend(unseen)
    return len(reached) == len(graph)
|
|
|
|
|
|
def warn_if_not_weakly_connected(graph: Graph) -> None:
    """
    Emit a warning when the execution graph is not weakly connected.
    """
    if is_weakly_connected(graph):
        return
    fast_nvcc_warn("execution graph is not (weakly) connected")
|
|
|
|
|
|
def print_dot_graph(
    *,
    commands: List[List[str]],
    graph: Graph,
    filename: str,
) -> None:
    """
    Write the graph to filename as a Graphviz DOT digraph.

    Each node is labelled with its command's index and the basename of
    the program that the command runs.
    """

    def name(k: int) -> str:
        program = os.path.basename(commands[k][0])
        return f'"{k} {program}"'

    with open(filename, "w") as out:
        out.write("digraph {\n")
        # declare every node up front, in case the graph is disconnected
        for node in range(len(graph)):
            out.write(f" {name(node)};\n")
        for node, deps in enumerate(graph):
            for dep in deps:
                out.write(f" {name(dep)} -> {name(node)};\n")
        out.write("}\n")
|
|
|
|
|
|
class Result(TypedDict, total=False):
    """
    Outcome of one command; every key is optional (total=False).
    """

    # return code of the command's process
    exit_code: int
    # captured standard output and standard error, as raw bytes
    stdout: bytes
    stderr: bytes
    # difference between two time.monotonic() readings around the command
    time: float
    # size in bytes of each tmp file the command mentions (0 if missing)
    files: Dict[str, int]
|
|
|
|
|
|
async def run_command(
    command: str,
    *,
    env: Dict[str, str],
    deps: Set[Awaitable[Result]],
    gather_data: bool,
    i: int,
    save: Optional[str],
) -> Result:
    """
    Run the command with the given env after waiting for deps.

    If any dependency failed, or was itself skipped, nothing is run and
    an empty result is returned.  Otherwise the result holds the exit
    code and the captured stdout/stderr.  With gather_data, it also holds
    the elapsed time and the size of every tmp file that command mentions
    (0 for files that do not exist).  With save set, the mentioned tmp
    files that exist after the command has run are copied into the
    directory save/i, which is created if needed.
    """
    for task in deps:
        dep_result = await task
        # abort if a previous step failed
        if "exit_code" not in dep_result or dep_result["exit_code"] != 0:
            return {}
    if gather_data:
        t1 = time.monotonic()
    proc = await asyncio.create_subprocess_shell(
        command,
        env=env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()
    code = cast(int, proc.returncode)
    results: Result = {"exit_code": code, "stdout": stdout, "stderr": stderr}
    if gather_data:
        t2 = time.monotonic()
        results["time"] = t2 - t1
        sizes = {}
        for tmp_file in files_mentioned(command):
            # ask for the size directly instead of checking for existence
            # first: other commands run concurrently and may remove the
            # file between the two calls
            try:
                sizes[tmp_file] = os.path.getsize(tmp_file)
            except OSError:
                sizes[tmp_file] = 0
        results["files"] = sizes
    if save:
        dest = pathlib.Path(save) / str(i)
        # tolerate a save directory that does not exist yet, as well as
        # one left over from an earlier run
        dest.mkdir(parents=True, exist_ok=True)
        for src in map(pathlib.Path, files_mentioned(command)):
            if src.exists():
                shutil.copy2(src, dest / (src.name))
    return results
|
|
|
|
|
|
async def run_graph(
    *,
    env: Dict[str, str],
    commands: List[str],
    graph: Graph,
    gather_data: bool = False,
    save: Optional[str] = None,
) -> List[Result]:
    """
    Run commands as tasks ordered by graph and collect their results.

    The returned list has one result per command, in command order; the
    time/file info is only present when gather_data is set.
    """
    tasks: List[Awaitable[Result]] = []
    for i, (command, indices) in enumerate(zip(commands, graph)):
        # dependencies refer to tasks that were created earlier in this loop
        deps = {tasks[j] for j in indices}
        coroutine = run_command(
            command,
            env=env,
            deps=deps,
            gather_data=gather_data,
            i=i,
            save=save,
        )
        tasks.append(asyncio.create_task(coroutine))  # type: ignore[attr-defined]
    results = []
    for task in tasks:
        results.append(await task)
    return results
|
|
|
|
|
|
def print_command_outputs(command_results: List[Result]) -> None:
    """
    Print captured stdout and stderr from commands.

    Results without captured output (e.g. from skipped commands) print
    nothing.  The bytes are decoded as UTF-8, which is identical to ASCII
    for pure-ASCII output; undecodable bytes are replaced rather than
    raising, so that a non-ASCII character in a compiler diagnostic does
    not abort the whole run.
    """
    for result in command_results:
        stdout_text = result.get("stdout", b"").decode("utf-8", errors="replace")
        stderr_text = result.get("stderr", b"").decode("utf-8", errors="replace")
        sys.stdout.write(stdout_text)
        sys.stderr.write(stderr_text)
|
|
|
|
|
|
def write_log_csv(
    command_parts: List[List[str]],
    command_results: List[Result],
    *,
    filename: str,
) -> None:
    """
    Write the time and /tmp file sizes of each command to a CSV file.

    There is one row per result, and one column per tmp file, in order of
    first appearance across the results.
    """
    # a dict serves as an ordered set of the file columns
    file_columns: Dict[str, None] = {}
    for result in command_results:
        file_columns.update(dict.fromkeys(result.get("files", {})))
    fieldnames = ["command", "seconds", *file_columns]
    with open(filename, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for index, result in enumerate(command_results):
            program = os.path.basename(command_parts[index][0])
            row = {
                "command": f"{index} {program}",
                "seconds": result.get("time", 0),
            }
            row.update(result.get("files", {}))
            writer.writerow(row)
|
|
|
|
|
|
def exit_code(results: List[Result]) -> int:
    """
    Return the first nonzero exit code among results, or 0 if there is none.

    Results without an exit code count as 0.
    """
    codes = (result.get("exit_code", 0) for result in results)
    return next((code for code in codes if code != 0), 0)
|
|
|
|
|
|
def wrap_nvcc(
    args: List[str],
    config: argparse.Namespace = default_config,
) -> int:
    """
    Run the configured nvcc binary directly on args and return its exit code.
    """
    invocation = [config.nvcc, *args]
    return subprocess.call(invocation)
|
|
|
|
|
|
def fast_nvcc(
    args: List[str],
    *,
    config: argparse.Namespace = default_config,
) -> int:
    """
    Behave like the configured nvcc binary called with args, only faster.

    The commands reported by nvcc --dryrun are run according to their
    dependency graph.  The returned exit code is the first nonzero one
    among the dry run and the commands, or 0 if there is none.
    """
    warn_if_windows()
    warn_if_tmpdir_flag(args)
    dryrun_data = nvcc_dryrun_data(config.nvcc, args)
    env = dryrun_data["env"]
    warn_if_tmpdir_set(env)
    commands = dryrun_data["commands"]
    if not config.faithful:
        commands = unique_module_id_files(commands)
        commands = make_rm_force(commands)

    # hand the whole job to nvcc itself when a dry-run line is not a command
    if contains_non_executable(commands):
        return wrap_nvcc(args, config)

    command_parts = [shlex.split(command) for command in commands]
    if config.verbose:
        print_verbose_output(
            env=env,
            commands=command_parts,
            filename=config.verbose,
        )
    data_graph = nvcc_data_dependencies(commands)
    warn_if_not_weakly_connected(data_graph)
    if config.graph:
        # the DOT file always shows the data dependencies, even when the
        # commands end up being run sequentially
        print_dot_graph(
            commands=command_parts,
            graph=data_graph,
            filename=config.graph,
        )
    if config.sequential:
        run_order = straight_line_dependencies(commands)
    else:
        run_order = data_graph
    pending = run_graph(
        env=env,
        commands=commands,
        graph=run_order,
        gather_data=bool(config.table),
        save=config.save,
    )
    results = asyncio.run(pending)  # type: ignore[attr-defined]
    print_command_outputs(results)
    if config.table:
        write_log_csv(command_parts, results, filename=config.table)
    return exit_code([dryrun_data] + results)  # type: ignore[arg-type, operator]
|
|
|
|
|
|
def our_arg(arg: str) -> bool:
    """
    Tell whether arg is anything other than the "--" separator.
    """
    is_separator = arg == "--"
    return not is_separator
|
|
|
|
|
|
if __name__ == "__main__":
    argv = sys.argv[1:]
    # everything before the first "--" is for this script; everything after
    # it (the separator itself excluded) goes to nvcc
    split_at = argv.index("--") if "--" in argv else len(argv)
    own_args = argv[:split_at]
    nvcc_args = argv[split_at + 1 :]
    config = parser.parse_args(own_args)
    sys.exit(fast_nvcc(nvcc_args, config=config))
|