diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 80a7d35cd4c..f17348e401c 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -310,7 +310,7 @@ def escape(n): def is_cuda_tensor(obj): - return isinstance(obj, torch.Tensor) and obj.is_cuda + return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor) def cuda_allocation_context(): snapshot = torch.cuda.memory._snapshot()