mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
There are three progress bars, each gated behind its own config and all off by default for now: 1. Dynamo: a macro-level bar covering Dynamo, AOT, and Inductor; 2. FX: a progress bar for each pass, showing the pass names; 3. Inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88384 Approved by: https://github.com/wconstab, https://github.com/mlazos
111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
import itertools
|
|
import logging
|
|
import os
|
|
|
|
from torch.hub import Faketqdm, tqdm
|
|
|
|
# Progress bars are off by default. This flag lives here rather than in the
# dynamo config module because putting it there would create a circular import.
disable_progress = True

# Custom log level named "CODE" (numeric value 15, i.e. between DEBUG and INFO),
# used for dynamo-generated graphs/bytecode/guards. It is exposed as
# ``logging.CODE`` so other modules can log at it.
logging.CODE = 15
logging.addLevelName(logging.CODE, "CODE")
|
|
|
|
def get_loggers():
    """Return the loggers that torchdynamo/torchinductor are responsible for.

    Returns:
        A new list containing the ``torch._dynamo`` logger followed by the
        ``torch._inductor`` logger.
    """
    owned_names = ("torch._dynamo", "torch._inductor")
    return [logging.getLogger(name) for name in owned_names]
|
|
|
|
|
|
def set_loggers_level(level):
    """Apply ``level`` to every logger torchdynamo is responsible for.

    The affected loggers are exactly those returned by get_loggers().

    Args:
        level: any value accepted by ``logging.Logger.setLevel``.
    """
    owned = get_loggers()
    for target in owned:
        target.setLevel(level)
|
|
|
|
|
|
# ``logging.config.dictConfig`` schema for the torchdynamo/torchinductor loggers:
# both log at DEBUG through one stderr StreamHandler with a timestamped format,
# and do not propagate to ancestor loggers. Existing loggers are left enabled.
LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "torchdynamo_format": {
            "format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
        },
    },
    "handlers": {
        "torchdynamo_console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "torchdynamo_format",
            "stream": "ext://sys.stderr",
        },
    },
    # Identical settings for each owned logger; the comprehension builds a fresh
    # dict (and handlers list) per logger name.
    "loggers": {
        logger_name: {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        }
        for logger_name in ("torch._dynamo", "torch._inductor")
    },
    "disable_existing_loggers": False,
}
|
|
|
|
|
|
def init_logging(log_level, log_file_name=None):
    """Initialize the torchdynamo/torchinductor loggers.

    When not running under pytest (``PYTEST_CURRENT_TEST`` is unset), applies
    LOGGING_CONFIG via ``logging.config.dictConfig``. If ``log_file_name`` is also
    given, it attaches one shared ``FileHandler`` at ``log_level`` to every logger
    from get_loggers(). In all cases, sets each of those loggers to ``log_level``.

    Args:
        log_level: level applied to the loggers (and to the file handler, if any).
        log_file_name: optional path of a file to additionally write logs to.
    """
    # ``import logging`` does not load the ``logging.config`` submodule, so import
    # it explicitly rather than relying on another module having done it first.
    # This must stay the first statement: it binds ``logging`` as a local name
    # (to the same package object) for the whole function.
    import logging.config

    if "PYTEST_CURRENT_TEST" not in os.environ:
        logging.config.dictConfig(LOGGING_CONFIG)
        if log_file_name is not None:
            log_file = logging.FileHandler(log_file_name)
            log_file.setLevel(log_level)
            for logger in get_loggers():
                logger.addHandler(log_file)

    set_loggers_level(log_level)
|
|
|
|
|
|
# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:

# @functools.lru_cache(None)
# def _step_logger():
#     return get_step_logger(logging.getLogger(...))

# def fn():
#     _step_logger()(logging.INFO, "msg")

_step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here


def _count_compile_steps():
    """Return how many torch.compile() phases the progress bar tracks.

    Returns 3 when ``triton`` is importable and 2 otherwise.
    """
    try:
        import triton  # noqa: F401
    except ImportError:
        return 2
    return 3


def _make_pbar(total):
    """Create the "torch.compile()" progress bar with ``total`` steps."""
    return tqdm(total=total, desc="torch.compile()", delay=0)


# ``pbar`` is always bound. When progress is enabled at import time, the bar is
# created eagerly (and ``num_steps`` is set). Otherwise ``pbar`` is None and
# get_step_logger creates the bar on first use. This way, setting
# ``disable_progress = False`` after import works instead of raising NameError.
pbar = None
if not disable_progress:
    num_steps = _count_compile_steps()
    pbar = _make_pbar(num_steps)


def get_step_logger(logger):
    """Return a ``log(level, msg)`` function that prefixes messages with a step number.

    Each call takes the next value of the module-wide step counter (starting at 1),
    so call this lazily at runtime rather than at module-load time. When progress
    reporting is enabled (``disable_progress`` is falsy), each call also advances
    the torch.compile() progress bar by one, creating the bar if needed. For a
    real tqdm bar, it also shows ``logger.name`` as the postfix.

    Args:
        logger: the ``logging.Logger`` the returned function writes to.

    Returns:
        A callable ``log(level, msg)`` that logs ``"Step {step}: {msg}"`` at ``level``.
    """
    global pbar
    if not disable_progress:
        if pbar is None:
            pbar = _make_pbar(_count_compile_steps())
        pbar.update(1)
        # Skipped for torch.hub's Faketqdm stand-in, presumably because it does
        # not implement set_postfix_str.
        if not isinstance(pbar, Faketqdm):
            pbar.set_postfix_str(f"{logger.name}")

    step = next(_step_counter)

    def log(level, msg):
        logger.log(level, f"Step {step}: {msg}")

    return log
|