mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
As Python-3.6 have reached EOL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71494
Reviewed By: atalman
Differential Revision: D33667509
Pulled By: malfet
fbshipit-source-id: ab1f03085cfb9161df77ba5ce373b81f5e7ef3ae
(cherry picked from commit 60343166d9)
548 lines
17 KiB
Python
Executable File
548 lines
17 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import asyncio
|
|
import collections
|
|
import csv
|
|
import hashlib
|
|
import itertools
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import shlex
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import (Awaitable, DefaultDict, Dict, List, Match, Optional, Set,
|
|
cast)
|
|
|
|
from typing_extensions import TypedDict
|
|
|
|
# usage text shown by argparse; everything after the first "--" on the
# command line is forwarded to nvcc, not parsed here
help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]...

Run the commands given by nvcc --dryrun, in parallel.

All flags for this script itself (see the "optional arguments" section
of --help) must be passed before the first "--". Everything after that
first "--" is passed directly to nvcc, with the --dryrun argument added.

This script only works with the "normal" execution path of nvcc, so for
instance passing --help (after "--") doesn't work since the --help
execution path doesn't compile anything, so adding --dryrun there gives
nothing in stderr.
'''
parser = argparse.ArgumentParser(help_msg)
parser.add_argument(
    '--faithful',
    action='store_true',
    help="don't modify the commands given by nvcc (slower)",
)
parser.add_argument(
    '--graph',
    metavar='FILE.gv',
    help='write Graphviz DOT file with execution graph',
)
parser.add_argument(
    '--nvcc',
    metavar='PATH',
    default='nvcc',
    help='path to nvcc (default is just "nvcc")',
)
parser.add_argument(
    '--save',
    metavar='DIR',
    help='copy intermediate files from each command into DIR',
)
parser.add_argument(
    '--sequential',
    action='store_true',
    help='sequence commands instead of using the graph (slower)',
)
parser.add_argument(
    '--table',
    metavar='FILE.csv',
    help='write CSV with times and intermediate file sizes',
)
parser.add_argument(
    '--verbose',
    metavar='FILE.txt',
    help='like nvcc --verbose, but expanded and into a file',
)
# defaults used when fast_nvcc() is called programmatically with no config
default_config = parser.parse_args([])
|
|
|
|
|
|
# docs about temporary directories used by NVCC
url_base = 'https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html'
url_vars = f'{url_base}#keeping-intermediate-phase-files'


# regex for temporary file names
# Matches a "tmp..."-named file as it appears in nvcc --dryrun output,
# optionally prefixed with /tmp/; group 1 is the bare filename. The
# lookbehind prevents matching inside a longer word or path component.
re_tmp = r'(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)'
|
|
|
|
|
|
def fast_nvcc_warn(warning: str) -> None:
    """
    Print a fast_nvcc-specific warning message on stderr.
    """
    sys.stderr.write(f'warning (fast_nvcc): {warning}\n')
|
|
|
|
|
|
def warn_if_windows() -> None:
    """
    Emit a warning when running on Windows, where fast_nvcc is untested.
    """
    # use os.name instead of platform.system() because there is a
    # platform.py file in this directory, making it very difficult to
    # import the platform module from the Python standard library
    if os.name != 'nt':
        return
    fast_nvcc_warn("untested on Windows, might not work; see this URL:")
    fast_nvcc_warn(url_vars)
|
|
|
|
|
|
def warn_if_tmpdir_flag(args: List[str]) -> None:
    """
    Warn about nvcc flags that interact with TMPDIR and may break fast_nvcc.
    """
    file_path_specs = 'file-and-path-specifications'
    guiding_driver = 'options-for-guiding-compiler-driver'
    # flag -> anchor of the relevant section in the nvcc docs
    scary_flags = {
        '--objdir-as-tempdir': file_path_specs,
        '-objtemp': file_path_specs,
        '--keep': guiding_driver,
        '-keep': guiding_driver,
        '--keep-dir': guiding_driver,
        '-keep-dir': guiding_driver,
        '--save-temps': guiding_driver,
        '-save-temps': guiding_driver,
    }
    for arg, (flag, frag) in itertools.product(args, scary_flags.items()):
        # match the flag itself or its "=value" form
        if re.match(fr'^{re.escape(flag)}(?:=.*)?$', arg):
            fast_nvcc_warn(f'{flag} not supported since it interacts with')
            fast_nvcc_warn('TMPDIR, so fast_nvcc may break; see this URL:')
            fast_nvcc_warn(f'{url_base}#{frag}')
|
|
|
|
|
|
class DryunData(TypedDict):
    """Parsed ``nvcc --dryrun`` output (NB: "Dryun" typo kept — public name)."""
    # environment variable assignments printed by --dryrun (name -> value)
    env: Dict[str, str]
    # shell commands nvcc would have executed, in order
    commands: List[str]
    # exit code of the nvcc --dryrun invocation itself
    exit_code: int
|
|
|
|
|
|
def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
    """
    Run ``binary --dryrun`` with args; parse env vars and commands from stderr.
    """
    result = subprocess.run(  # type: ignore[call-overload]
        [binary, '--dryrun'] + args,
        capture_output=True,
        encoding='ascii',  # this is just a guess
    )
    print(result.stdout, end='')
    env: Dict[str, str] = {}
    commands: List[str] = []
    # --dryrun lines all start with "#$ "; anything else is passed through
    for line in result.stderr.splitlines():
        match = re.match(r'^#\$ (.*)$', line)
        if not match:
            print(line, file=sys.stderr)
            continue
        stripped = match.group(1)
        # "NAME=value" lines are environment settings, the rest are commands
        mapping = re.match(r'^(\w+)=(.*)$', stripped)
        if mapping:
            name, val = mapping.groups()
            env[name] = val
        else:
            commands.append(stripped)
    return {'env': env, 'commands': commands, 'exit_code': result.returncode}
|
|
|
|
|
|
def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
    """
    Warn when TMPDIR is set (here or in nvcc's env), which may break fast_nvcc.
    """
    tmpdir_present = 'TMPDIR' in env or bool(os.getenv('TMPDIR'))
    if tmpdir_present:
        fast_nvcc_warn("TMPDIR is set, might not work; see this URL:")
        fast_nvcc_warn(url_vars)
|
|
|
|
|
|
def contains_non_executable(commands: List[str]) -> bool:
    """
    Return True if any dry-run line is not a runnable shell command.
    """
    # NVCC --dryrun output can include pseudo-command lines such as
    #   -- Filter Dependencies -- > .../reduce_scatter.dep.tmp
    # which start with "--" and cannot be executed by a shell.
    return any(command.startswith("--") for command in commands)
|
|
|
|
|
|
def module_id_contents(command: List[str]) -> str:
    """
    Guess the contents of the .module_id file contained within command.
    """
    # the path whose basename seeds the module id sits at a different
    # argument position depending on the nvcc subcommand
    if command[0] == 'cicc':
        path = command[-3]
    elif command[0] == 'cudafe++':
        path = command[-1]
    middle = re.sub(r'[-.]', '_', pathlib.PurePath(path).name)
    # this suffix is very wrong (the real one is far less likely to be
    # unique), but it seems difficult to find a rule that reproduces the
    # real suffixes, so here's one that, while inaccurate, is at least
    # hopefully as straightforward as possible
    suffix = hashlib.md5(middle.encode()).hexdigest()[:8]
    return f'_{len(middle)}_{middle}_{suffix}'
|
|
|
|
|
|
def unique_module_id_files(commands: List[str]) -> List[str]:
    """
    Give each command its own .module_id filename instead of sharing.

    nvcc's dry-run commands all share one .module_id temp file, which is
    unsafe when running them in parallel. Each command i gets a filename
    suffixed with "-i", the --gen_module_id_file flag (which would have
    created the file) is stripped, and an explicit `echo` command writing
    the guessed module id into that file is inserted before the command.

    Args:
        commands: shell command lines from nvcc --dryrun.
    Returns:
        The rewritten command list (possibly longer due to inserted echos).
    """
    module_id = None
    uniqueified = []
    for i, line in enumerate(commands):
        arr: List[str] = []

        def uniqueify(s: Match[str]) -> str:
            # append "-i" to the numeric part of the tmp filename so each
            # command's .module_id file is distinct
            filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0))
            arr.append(filename)
            return filename

        line = re.sub(re_tmp + r'.module_id', uniqueify, line)
        line = re.sub(r'\s*\-\-gen\_module\_id\_file\s*', ' ', line)
        if arr:
            filename, = arr
            if not module_id:
                module_id = module_id_contents(shlex.split(line))
            # BUG FIX: write the module id into the per-command file that
            # the command will read (this previously wrote to the literal
            # path '(unknown)', so the real .module_id file stayed empty)
            uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
        uniqueified.append(line)
    return uniqueified
|
|
|
|
|
|
def make_rm_force(commands: List[str]) -> List[str]:
    """
    Append --force to every rm command; leave other commands untouched.
    """
    forced = []
    for command in commands:
        if command.startswith('rm '):
            command = f'{command} --force'
        forced.append(command)
    return forced
|
|
|
|
|
|
def print_verbose_output(
    *,
    env: Dict[str, str],
    commands: List[List[str]],
    filename: str,
) -> None:
    """
    Write nvcc --dryrun data to filename in a human-readable layout.
    """
    # width of the widest command index, for aligned numbering
    padding = len(str(len(commands) - 1))
    lines = []
    for name, val in env.items():
        lines.append(f'#{" "*padding}$ {name}={val}')
    for i, command in enumerate(commands):
        prefix = f'{str(i).rjust(padding)}$ '
        lines.append(f'#{prefix}{command[0]}')
        # continuation arguments are indented under their command
        lines.extend(f'#{" "*len(prefix)}{part}' for part in command[1:])
    with open(filename, 'w') as f:
        for text in lines:
            print(text, file=f)
|
|
|
|
|
|
# a dependency graph: element i is the set of command indices that command
# i must wait for
Graph = List[Set[int]]


def straight_line_dependencies(commands: List[str]) -> Graph:
    """
    Return a graph chaining each command to the one immediately before it.
    """
    graph: Graph = []
    for i in range(len(commands)):
        graph.append({i - 1} if i > 0 else set())
    return graph
|
|
|
|
|
|
def files_mentioned(command: str) -> List[str]:
    """
    Return fully-qualified names of all tmp files referenced by command.
    """
    matches = re.finditer(re_tmp, command)
    return ['/tmp/' + m.group(1) for m in matches]
|
|
|
|
|
|
def nvcc_data_dependencies(commands: List[str]) -> Graph:
    """
    Return a list of the set of dependencies for each command.
    """
    # fatbin needs to be treated specially because while the cicc steps
    # do refer to .fatbin.c files, they do so through the
    # --include_file_name option, since they're generating files that
    # refer to .fatbin.c file(s) that will later be created by the
    # fatbinary step; so for most files, we make a data dependency from
    # the later step to the earlier step, but for .fatbin.c files, the
    # data dependency is sort of flipped, because the steps that use the
    # files generated by cicc need to wait for the fatbinary step to
    # finish first
    # tmp filename -> index of the most recent command that mentioned it
    tmp_files: Dict[str, int] = {}
    # command index -> set of .fatbin.c files that command refers to but
    # does not produce (the "flipped" dependencies described above)
    fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
    graph = []
    for i, line in enumerate(commands):
        deps = set()
        for tmp in files_mentioned(line):
            if tmp in tmp_files:
                # depend on the last command that touched this file
                dep = tmp_files[tmp]
                deps.add(dep)
                # if that command merely referenced .fatbin.c files, also
                # depend on whichever command actually produces them
                if dep in fatbins:
                    for filename in fatbins[dep]:
                        if filename in tmp_files:
                            deps.add(tmp_files[filename])
            if tmp.endswith('.fatbin.c') and not line.startswith('fatbinary'):
                # a non-fatbinary step referencing a .fatbin.c file: record
                # the flipped dependency instead of claiming ownership
                fatbins[i].add(tmp)
            else:
                tmp_files[tmp] = i
        if line.startswith('rm ') and not deps:
            # an rm with no data deps still must run after the previous step
            deps.add(i - 1)
        graph.append(deps)
    return graph
|
|
|
|
|
|
def is_weakly_connected(graph: Graph) -> bool:
    """
    Return true iff graph is weakly connected.
    """
    if not graph:
        return True
    # build an undirected adjacency list from the directed edges
    neighbors: List[Set[int]] = [set() for _ in graph]
    for node, predecessors in enumerate(graph):
        for pred in predecessors:
            neighbors[pred].add(node)
            neighbors[node].add(pred)
    # flood-fill from node 0 (graph is known nonempty here)
    found = {0}
    stack = [0]
    while stack:
        for neighbor in neighbors[stack.pop()]:
            if neighbor not in found:
                found.add(neighbor)
                stack.append(neighbor)
    return len(found) == len(graph)
|
|
|
|
|
|
def warn_if_not_weakly_connected(graph: Graph) -> None:
    """
    Warn the user if the execution graph is not weakly connected.
    """
    if is_weakly_connected(graph):
        return
    fast_nvcc_warn('execution graph is not (weakly) connected')
|
|
|
|
|
|
def print_dot_graph(
    *,
    commands: List[List[str]],
    graph: Graph,
    filename: str,
) -> None:
    """
    Write a Graphviz DOT file showing short names of the commands in graph.
    """
    def name(k: int) -> str:
        # label each node with its index plus the command's basename
        return f'"{k} {os.path.basename(commands[k][0])}"'
    with open(filename, 'w') as f:
        f.write('digraph {\n')
        # emit every node up front, in case the graph is disconnected
        for i in range(len(graph)):
            f.write(f' {name(i)};\n')
        for dst, deps in enumerate(graph):
            for src in deps:
                f.write(f' {name(src)} -> {name(dst)};\n')
        f.write('}\n')
|
|
|
|
|
|
|
|
class Result(TypedDict, total=False):
    """Outcome of one command; empty (no keys) when a dependency failed."""
    # shell exit code of the command
    exit_code: int
    # captured output streams from the subprocess
    stdout: bytes
    stderr: bytes
    # wall-clock seconds taken (present only when gather_data is set)
    time: float
    # byte sizes of the tmp files the command mentions (gather_data only)
    files: Dict[str, int]
|
|
|
|
|
|
async def run_command(
    command: str,
    *,
    env: Dict[str, str],
    deps: Set[Awaitable[Result]],
    gather_data: bool,
    i: int,
    save: Optional[str],
) -> Result:
    """
    Run command in a shell with env, after all deps have completed.

    Returns an empty Result (no 'exit_code' key) when any dependency
    failed, which tells downstream commands to abort as well.
    """
    for task in deps:
        dep_result = await task
        # abort if a previous step failed (missing key counts as failure)
        if dep_result.get('exit_code') != 0:
            return {}
    start = time.monotonic() if gather_data else 0.0
    proc = await asyncio.create_subprocess_shell(
        command,
        env=env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()
    results: Result = {
        'exit_code': cast(int, proc.returncode),
        'stdout': stdout,
        'stderr': stderr,
    }
    if gather_data:
        results['time'] = time.monotonic() - start
        # record sizes of all mentioned tmp files (0 if never created)
        results['files'] = {
            tmp: os.path.getsize(tmp) if os.path.exists(tmp) else 0
            for tmp in files_mentioned(command)
        }
    if save:
        # snapshot this command's intermediate files into save/i/
        dest = pathlib.Path(save) / str(i)
        dest.mkdir()
        for src in map(pathlib.Path, files_mentioned(command)):
            if src.exists():
                shutil.copy2(src, dest / src.name)
    return results
|
|
|
|
|
|
async def run_graph(
    *,
    env: Dict[str, str],
    commands: List[str],
    graph: Graph,
    gather_data: bool = False,
    save: Optional[str] = None,
) -> List[Result]:
    """
    Schedule every command (respecting graph edges) and return their Results.
    """
    tasks: List[Awaitable[Result]] = []
    for i, (command, indices) in enumerate(zip(commands, graph)):
        # each command awaits the tasks of all its dependencies
        coro = run_command(
            command,
            env=env,
            deps={tasks[j] for j in indices},
            gather_data=gather_data,
            i=i,
            save=save,
        )
        tasks.append(asyncio.create_task(coro))  # type: ignore[attr-defined]
    return [await task for task in tasks]
|
|
|
|
|
|
def print_command_outputs(command_results: List[Result]) -> None:
    """
    Replay each command's captured stdout/stderr onto our own streams.
    """
    for result in command_results:
        out = result.get('stdout', b'')
        err = result.get('stderr', b'')
        sys.stdout.write(out.decode('ascii'))
        sys.stderr.write(err.decode('ascii'))
|
|
|
|
|
|
def write_log_csv(
    command_parts: List[List[str]],
    command_results: List[Result],
    *,
    filename: str,
) -> None:
    """
    Write a CSV of per-command times and /tmp file sizes to filename.
    """
    # gather every tmp filename mentioned, keeping first-seen order
    tmp_files: List[str] = []
    for result in command_results:
        tmp_files.extend(result.get('files', {}).keys())
    # dict.fromkeys dedupes while preserving order
    fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files))
    with open(filename, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for i, result in enumerate(command_results):
            row = {
                'command': f'{i} {os.path.basename(command_parts[i][0])}',
                'seconds': result.get('time', 0),
            }
            writer.writerow({**row, **result.get('files', {})})
|
|
|
|
|
|
def exit_code(results: List[Result]) -> int:
    """
    Return the first nonzero exit code among results, or 0 if all succeeded.
    """
    codes = (result.get('exit_code', 0) for result in results)
    return next((code for code in codes if code != 0), 0)
|
|
|
|
|
|
def wrap_nvcc(
    args: List[str],
    config: argparse.Namespace = default_config,
) -> int:
    """
    Fall back to invoking the real nvcc binary directly with args.
    """
    invocation = [config.nvcc] + args
    return subprocess.call(invocation)
|
|
|
|
|
|
def fast_nvcc(
    args: List[str],
    *,
    config: argparse.Namespace = default_config,
) -> int:
    """
    Emulate the result of calling the given nvcc binary with args.

    Should run faster than plain nvcc.

    Args:
        args: arguments to forward to nvcc (everything after the first --).
        config: parsed fast_nvcc options; defaults to all-off.
    Returns:
        An exit code suitable for sys.exit: nvcc --dryrun's own code, or
        the first nonzero code among the executed commands.
    """
    warn_if_windows()
    warn_if_tmpdir_flag(args)
    # ask nvcc what it *would* run, without running it
    dryrun_data = nvcc_dryrun_data(config.nvcc, args)
    env = dryrun_data['env']
    warn_if_tmpdir_set(env)
    commands = dryrun_data['commands']
    if not config.faithful:
        # rewrite commands so they can safely run in parallel
        commands = make_rm_force(unique_module_id_files(commands))

    # some dry-run lines are pseudo-commands we can't execute; defer to nvcc
    if contains_non_executable(commands):
        return wrap_nvcc(args, config)

    command_parts = list(map(shlex.split, commands))
    if config.verbose:
        print_verbose_output(
            env=env,
            commands=command_parts,
            filename=config.verbose,
        )
    # infer which commands depend on which via shared tmp files
    graph = nvcc_data_dependencies(commands)
    warn_if_not_weakly_connected(graph)
    if config.graph:
        print_dot_graph(
            commands=command_parts,
            graph=graph,
            filename=config.graph,
        )
    if config.sequential:
        # debugging aid: force one-at-a-time execution
        graph = straight_line_dependencies(commands)
    results = asyncio.run(run_graph(  # type: ignore[attr-defined]
        env=env,
        commands=commands,
        graph=graph,
        gather_data=bool(config.table),
        save=config.save,
    ))
    print_command_outputs(results)
    if config.table:
        write_log_csv(command_parts, results, filename=config.table)
    # dryrun_data contributes nvcc's own --dryrun exit code
    return exit_code([dryrun_data] + results)  # type: ignore[arg-type, operator]
|
|
|
|
|
|
def our_arg(arg: str) -> bool:
    """
    Return whether arg is for fast_nvcc itself (precedes the first '--').
    """
    if arg == '--':
        return False
    return True
|
|
|
|
|
|
if __name__ == '__main__':
    argv = sys.argv[1:]
    # arguments before the first '--' configure fast_nvcc itself...
    us = list(itertools.takewhile(our_arg, argv))
    # ...everything from the '--' on is for nvcc (them[0] is the '--')
    them = list(itertools.dropwhile(our_arg, argv))
    sys.exit(fast_nvcc(them[1:], config=parser.parse_args(us)))
|