mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Lazy import to avoid circular import issue for DebugMode (#163381)
as title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163381 Approved by: https://github.com/dolpm
This commit is contained in:
parent
bfe9e60ffb
commit
a1df0b42ce
|
|
@ -2,9 +2,7 @@
|
|||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.distributed.tensor as dt
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
from torch.utils._dtype_abbrs import dtype_abbrs
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map
|
||||
|
|
@ -29,7 +27,7 @@ def _stringify_placement(placement) -> str:
|
|||
|
||||
def _tensor_debug_string(tensor) -> str:
|
||||
"""Convert tensor to debug string representation."""
|
||||
if isinstance(tensor, dt.DTensor):
|
||||
if isinstance(tensor, torch.distributed.tensor.DTensor):
|
||||
# omitted device mesh
|
||||
return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
|
||||
elif isinstance(tensor, FakeTensor):
|
||||
|
|
@ -41,6 +39,8 @@ def _tensor_debug_string(tensor) -> str:
|
|||
|
||||
|
||||
def _arg_to_str(arg) -> str:
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
|
||||
def to_str(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return _tensor_debug_string(x)
|
||||
|
|
@ -86,6 +86,7 @@ class DebugMode(TorchDispatchMode):
|
|||
record_realtensor=True,
|
||||
):
|
||||
super().__init__()
|
||||
import torch.distributed.tensor # noqa: F401
|
||||
|
||||
self.record_torchfunction = record_torchfunction
|
||||
self.record_faketensor = record_faketensor
|
||||
|
|
@ -111,7 +112,7 @@ class DebugMode(TorchDispatchMode):
|
|||
kwargs = {}
|
||||
|
||||
# Record the operation with its call depth
|
||||
if dt.DTensor in types:
|
||||
if torch.distributed.tensor.DTensor in types:
|
||||
self.operators.append((func, args, kwargs, self.call_depth))
|
||||
return NotImplemented
|
||||
elif FakeTensor in types or isinstance(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user