mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix max_width computation in _tensor_str._Formatter (#126859)
Previous version of `torch._tensor_str._Formatter` was not using `PRINT_OPTS.sci_mode` for the `max_width` computation but was using it for the formatting of values leading to a weird discrepancy.
Now, the code first checks if it should be in sci_mode, then compute `max_width`
Here is an example to test the behavior:
```python
A = torch.tensor([10, 1e-1, 1e-2])
B = torch.tensor([10, 1e-1, 1e-1])
print("================= Default =================")
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
print("================= sci_mode=False =================")
with torch._tensor_str.printoptions(sci_mode=False):
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
print("================= sci_mode=True =================")
with torch._tensor_str.printoptions(sci_mode=True):
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
```
In the current version this prints:
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([ 10.0000, 0.1000, 0.0100]) Formatter max_width: 10
tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 7
```
On can see that in `sci_mode=False`, the values of A are prefixed with unneeded 0 and does not have the same `max_width` as B (It keeps the `max_width` from `sci_mode = None`)
Also in `sci_mode = True`, for B, the `max_width` is 7 but each value takes 10 chars... (But it is fine as the code that uses `max_width` do not rely much on it, but still, this is missleading)
After this commit, this will print
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([10.0000, 0.1000, 0.0100]) Formatter max_width: 7
tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 10
```
This also allows to align A with B for `sci_mode=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126859
Approved by: https://github.com/malfet
This commit is contained in:
parent
b0b3e6e48b
commit
ee2649219c
|
|
@ -8337,7 +8337,7 @@ class TestTorch(TestCase):
|
|||
self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
self.assertEqual(x.__repr__(), str(x))
|
||||
self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''')
|
||||
self.assertExpectedInline(str(x), '''tensor([100.0000, 0.0100])''')
|
||||
torch.set_printoptions(sci_mode=None) # reset to the default value
|
||||
|
||||
# test no leading space if all elements positive
|
||||
|
|
|
|||
|
|
@ -178,14 +178,18 @@ class _Formatter:
|
|||
self.int_mode = False
|
||||
break
|
||||
|
||||
self.sci_mode = (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
or nonzero_finite_min < 1.0e-4
|
||||
if PRINT_OPTS.sci_mode is None
|
||||
else PRINT_OPTS.sci_mode
|
||||
)
|
||||
|
||||
if self.int_mode:
|
||||
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
|
||||
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
|
||||
if (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
):
|
||||
self.sci_mode = True
|
||||
if self.sci_mode:
|
||||
for value in nonzero_finite_vals:
|
||||
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
|
|
@ -195,12 +199,7 @@ class _Formatter:
|
|||
self.max_width = max(self.max_width, len(value_str) + 1)
|
||||
else:
|
||||
# Check if scientific representation should be used.
|
||||
if (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
or nonzero_finite_min < 1.0e-4
|
||||
):
|
||||
self.sci_mode = True
|
||||
if self.sci_mode:
|
||||
for value in nonzero_finite_vals:
|
||||
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
|
|
@ -209,9 +208,6 @@ class _Formatter:
|
|||
value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
|
||||
if PRINT_OPTS.sci_mode is not None:
|
||||
self.sci_mode = PRINT_OPTS.sci_mode
|
||||
|
||||
def width(self):
|
||||
return self.max_width
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user