Add separate logging target for cudagraphs (#118329)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118329
Approved by: https://github.com/mlazos
commit e33e88e5bc (parent e180218949)
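This change gives inductor's cudagraph-related messages their own "cudagraphs" logging artifact, so they can be enabled on their own instead of raising the log level of the whole inductor component. A minimal usage sketch, not taken from the PR (it assumes a CUDA build, and the function and tensor shape are illustrative):

    import torch

    # Enable only the new "cudagraphs" artifact; everything else keeps its default level.
    torch._logging.set_logs(cudagraphs=True)

    def f(x):
        return (x * x).sum()

    # mode="reduce-overhead" sends the compiled code through inductor's cudagraph
    # wrapping, which is where the artifact's messages are emitted.
    opt_f = torch.compile(mode="reduce-overhead")(f)
    opt_f(torch.ones(1000, 1000, device="cuda"))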
@@ -103,6 +103,14 @@ class LoggingTests(LoggingTestCase):
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @requires_cuda
    @make_logging_test(cudagraphs=True)
    def test_cudagraphs(self, records):
        fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @make_logging_test(recompiles=True)
    def test_recompiles(self, records):
        def fn(x, y):
@@ -684,6 +692,7 @@ fn(torch.randn(5))
# single record tests
exclusions = {
    "bytecode",
    "cudagraphs",
    "output_code",
    "schedule",
    "fusion",
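The new test relies on the suite's make_logging_test harness to collect the emitted records. Outside the test suite, a plain standard-library handler attached to the module this PR registers under the "inductor" log can capture the same records; a rough sketch, assuming the artifact has been enabled as above:

    import logging
    import torch

    torch._logging.set_logs(cudagraphs=True)

    captured = []

    class _Collector(logging.Handler):
        def emit(self, record):
            # Store the formatted message of each record that reaches this logger.
            captured.append(self.format(record))

    # "torch._inductor.cudagraph_trees" is the module name this PR adds to the
    # "inductor" log registration; records from loggers in that module propagate
    # up to this handler.
    logging.getLogger("torch._inductor.cudagraph_trees").addHandler(_Collector())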
@@ -41,7 +41,6 @@ import dataclasses
import functools
import gc
import itertools
import logging
import operator
import sys
import threading
@@ -102,7 +101,8 @@ else:
    pass


log = logging.getLogger(__name__)
log = torch._logging.getArtifactLogger(__name__, "cudagraphs")


from . import config
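The pattern in the hunk above: the module swaps its plain logging.getLogger logger for an artifact logger, so its messages are gated by the "cudagraphs" artifact toggle rather than only by the module's log level. A hypothetical module emitting to the same artifact would follow the same shape (the helper name below is made up):

    import torch._logging

    # Messages on this logger show up only when the "cudagraphs" artifact is on,
    # e.g. after torch._logging.set_logs(cudagraphs=True).
    log = torch._logging.getArtifactLogger(__name__, "cudagraphs")

    def note_capture(graph_id):
        # Hypothetical helper, for illustration only.
        log.debug("recorded cudagraph %s", graph_id)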
@@ -200,6 +200,7 @@ def set_logs(
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[Dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log
@@ -284,6 +285,9 @@ def set_logs(
        aot_joint_graph (:class:`bool`):
            Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

        inductor (:class:`Optional[int]`):
            Whether to log information from inductor cudagraphs. Default: ``logging.WARN``

        ddp_graphs (:class:`bool`):
            Whether to emit graphs generated by DDPOptimizer. Default: ``False``
@@ -445,6 +449,7 @@ def set_logs(
        fusion=fusion,
        overlap=overlap,
        export=export,
        cudagraphs=cudagraphs,
    )
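Note the split that the signature and docstring changes above encode: component settings such as inductor take a logging level (an int), while artifact settings such as the new cudagraphs flag are booleans that default to off. A combined call, as a sketch rather than something from the PR:

    import logging
    import torch

    torch._logging.set_logs(
        inductor=logging.DEBUG,  # component log level
        cudagraphs=True,         # new boolean artifact from this PR
        recompiles=True,         # existing boolean artifact, used in the test above
    )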
@@ -11,7 +11,13 @@ DISTRIBUTED = [
register_log("dynamo", ["torch._dynamo", *DYNAMIC])
register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
register_log("autograd", "torch.autograd")
register_log("inductor", "torch._inductor")
register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"])

register_artifact(
    "cudagraphs",
    "Logs information from wrapping inductor generated code with cudagraphs.",
)

register_log("dynamic", DYNAMIC)
register_log("torch", "torch")
register_log("distributed", DISTRIBUTED)
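The registration hunk follows the file's existing pattern: register_log maps a public log name to the module(s) that feed it, while register_artifact declares a named artifact that can be toggled independently. For illustration only, a hypothetical additional artifact would be declared the same way (the name and description below are invented, not part of this PR):

    register_artifact(
        "cudagraph_recording",
        "Logs each graph capture performed by inductor's cudagraph trees.",
    )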