fix tensor print behavior for MAIA (#155609)

This pull request fixes the tensor print behavior for `MAIA` to account for the absence of double-precision support in its backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155609
Approved by: https://github.com/soulitzer
This commit is contained in:
Jinhang Choi 2025-06-14 01:04:07 +00:00 committed by PyTorch MergeBot
parent dabb55baff
commit 04cf2c9d24

View File

@ -120,6 +120,7 @@ def tensor_totype(t):
if (
t.is_mps
or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64)
or t.is_maia
)
else torch.double
)
@ -167,8 +168,7 @@ class _Formatter:
# support for them is removed
nonzero_finite_vals = nonzero_finite_vals.float()
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
# Convert to double (or float) for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())