Lazy import to avoid circular import issue for DebugMode (#163381)

as title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163381
Approved by: https://github.com/dolpm
This commit is contained in:
Sherlock Huang 2025-09-20 01:54:54 +00:00 committed by PyTorch MergeBot
parent bfe9e60ffb
commit a1df0b42ce

View File

@ -2,9 +2,7 @@
import contextlib
import torch
import torch.distributed.tensor as dt
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
from torch.utils._pytree import tree_map
@ -29,7 +27,7 @@ def _stringify_placement(placement) -> str:
def _tensor_debug_string(tensor) -> str:
"""Convert tensor to debug string representation."""
if isinstance(tensor, dt.DTensor):
if isinstance(tensor, torch.distributed.tensor.DTensor):
# omitted device mesh
return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
elif isinstance(tensor, FakeTensor):
@ -41,6 +39,8 @@ def _tensor_debug_string(tensor) -> str:
def _arg_to_str(arg) -> str:
from torch.distributed.tensor._dtensor_spec import DTensorSpec
def to_str(x):
if isinstance(x, torch.Tensor):
return _tensor_debug_string(x)
@ -86,6 +86,7 @@ class DebugMode(TorchDispatchMode):
record_realtensor=True,
):
super().__init__()
import torch.distributed.tensor # noqa: F401
self.record_torchfunction = record_torchfunction
self.record_faketensor = record_faketensor
@ -111,7 +112,7 @@ class DebugMode(TorchDispatchMode):
kwargs = {}
# Record the operation with its call depth
if dt.DTensor in types:
if torch.distributed.tensor.DTensor in types:
self.operators.append((func, args, kwargs, self.call_depth))
return NotImplemented
elif FakeTensor in types or isinstance(