mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40513 This PR makes the following changes: 1. Complex Printing now uses print formatting for its real and imaginary values and they are joined at the end. 2. Adding 1. naturally fixes the printing of complex tensors in sci_mode=True ``` >>> torch.tensor(float('inf')+float('inf')*1j) tensor(nan+infj) >>> torch.randn(2000, dtype=torch.cfloat) tensor([ 0.3015-0.2502j, -1.1102+1.2218j, -0.6324+0.0640j, ..., -1.0200-0.2302j, 0.6511-0.1889j, -0.1069+0.1702j]) >>> torch.tensor([1e-3, 3+4j, 1e-5j, 1e-2+3j, 5+1e-6j]) tensor([1.0000e-03+0.0000e+00j, 3.0000e+00+4.0000e+00j, 0.0000e+00+1.0000e-05j, 1.0000e-02+3.0000e+00j, 5.0000e+00+1.0000e-06j]) >>> torch.randn(3, dtype=torch.cfloat) tensor([ 1.0992-0.4459j, 1.1073+0.1202j, -0.2177-0.6342j]) >>> x = torch.tensor([1e2, 1e-2]) >>> torch.set_printoptions(sci_mode=False) >>> x tensor([ 100.0000, 0.0100]) >>> x = torch.tensor([1e2, 1e-2j]) >>> x tensor([100.+0.0000j, 0.+0.0100j]) ``` Test Plan: Imported from OSS Differential Revision: D22309294 Pulled By: anjali411 fbshipit-source-id: 20edf9e28063725aeff39f3a246a2d7f348ff1e8
372 lines
16 KiB
Python
372 lines
16 KiB
Python
import math
|
|
import torch
|
|
from torch._six import inf
|
|
|
|
|
|
class __PrinterOptions:
    """Holder for the tensor-printing settings read by the formatting code."""

    # Digits printed after the decimal point for floating point values.
    precision = 4
    # Element count above which a tensor is printed in summarized form.
    threshold = 1000
    # Items kept at the start and end of each dimension when summarizing.
    edgeitems = 3
    # Target number of characters per printed line.
    linewidth = 80
    # True/False forces scientific notation on/off; None leaves the choice
    # to the formatter.
    sci_mode = None


# Single shared options instance consulted by the printing functions.
PRINT_OPTS = __PrinterOptions()
|
|
|
|
|
|
# Explicit keyword parameters (rather than **kwargs) keep every option
# visible in the generated documentation.
def set_printoptions(
        precision=None,
        threshold=None,
        edgeitems=None,
        linewidth=None,
        profile=None,
        sci_mode=None
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.
    """
    option_names = ("precision", "threshold", "edgeitems", "linewidth")

    # Preset values for each named profile, in the order of `option_names`.
    presets = (
        ("default", (4, 1000, 3, 80)),
        ("short", (2, 1000, 2, 80)),
        ("full", (4, inf, 3, 80)),
    )

    # A profile is applied first; an unrecognized profile name changes nothing.
    if profile is not None:
        for preset_name, preset_values in presets:
            if profile == preset_name:
                for option, preset_value in zip(option_names, preset_values):
                    setattr(PRINT_OPTS, option, preset_value)
                break

    # Explicitly passed options take precedence over the profile's values.
    requested = (precision, threshold, edgeitems, linewidth)
    for option, requested_value in zip(option_names, requested):
        if requested_value is not None:
            setattr(PRINT_OPTS, option, requested_value)

    # sci_mode is always assigned, so omitting it resets it to None.
    PRINT_OPTS.sci_mode = sci_mode
|
|
|
|
|
|
class _Formatter(object):
    """Chooses a notation and a common width for a tensor's elements.

    After construction:
      * ``floating_dtype`` tells whether the tensor dtype is floating point,
      * ``int_mode`` is True while every finite non-zero value is integral,
      * ``sci_mode`` tells whether scientific notation is used,
      * ``max_width`` is the widest rendered element seen (at least 1).
    """

    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            flat = tensor.reshape(-1)

        if self.floating_dtype:
            finite_nonzero = torch.masked_select(flat, torch.isfinite(flat) & flat.ne(0))

            if finite_nonzero.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            magnitudes = finite_nonzero.abs().double()
            smallest = magnitudes.min().double()
            largest = magnitudes.max().double()

            # int_mode survives only if no value has a fractional part.
            self.int_mode = not any(entry != torch.ceil(entry) for entry in finite_nonzero)

            # Check if scientific representation should be used.
            use_sci = bool(largest / smallest > 1000.) or bool(largest > 1.e8)
            if not self.int_mode:
                use_sci = use_sci or bool(smallest < 1.e-4)

            padding = 0
            if use_sci:
                self.sci_mode = True
                template = '{{:.{}e}}'.format(PRINT_OPTS.precision)
            elif self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal
                # point to indicate that the tensor is of floating type. add 1 to the len
                # to account for this.
                template = '{:.0f}'
                padding = 1
            else:
                template = '{{:.{}f}}'.format(PRINT_OPTS.precision)

            for entry in finite_nonzero:
                self.max_width = max(self.max_width, len(template.format(entry)) + padding)
        else:
            for entry in flat:
                self.max_width = max(self.max_width, len('{}'.format(entry)))

        # A user-selected sci_mode overrides the automatically chosen one.
        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        """Return the widest element width recorded at construction."""
        return self.max_width

    def format(self, value):
        """Render ``value`` in the chosen notation, left-padded to ``max_width``."""
        if not self.floating_dtype:
            text = '{}'.format(value)
        elif self.sci_mode:
            text = '{{:{}.{}e}}'.format(self.max_width, PRINT_OPTS.precision).format(value)
        elif self.int_mode:
            text = '{:.0f}'.format(value)
            # inf/nan are printed without the trailing decimal point.
            if not (math.isinf(value) or math.isnan(value)):
                text += '.'
        else:
            text = '{{:.{}f}}'.format(PRINT_OPTS.precision).format(value)
        return text.rjust(self.max_width)
|
|
|
|
|
|
def _scalar_str(self, formatter1, formatter2=None):
    """Render a zero-dimensional tensor as a string.

    Args:
        self: the scalar tensor to render.
        formatter1: object whose ``format(value)`` renders the value, or the
            real part when ``formatter2`` is given.
        formatter2: optional object whose ``format(value)`` renders the
            imaginary part; when given, ``self`` is treated as complex.

    Returns:
        The formatted value. For complex input the result is the real part
        followed by the signed imaginary part and a trailing ``j``.
    """
    if formatter2 is None:
        return formatter1.format(self.item())

    real_str = _scalar_str(self.real, formatter1)
    imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
    # Decide from the rendered text rather than `imag < 0`: -0.0 is not
    # less than zero yet formats with a leading '-', which would otherwise
    # produce "+-0.j".
    if imag_str[0] in "+-":
        return real_str + imag_str
    return real_str + "+" + imag_str
|
|
|
|
def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    """Render a one-dimensional tensor as a bracketed, line-wrapped list.

    Args:
        self: the 1-D tensor to render.
        indent: column at which the opening bracket sits; continuation lines
            are indented by ``indent + 1`` spaces.
        summarize: when True and the tensor has more than
            ``2 * PRINT_OPTS.edgeitems`` elements, only the leading and
            trailing ``edgeitems`` elements are shown around an ellipsis.
        formatter1: formatter for the values (real parts for complex input).
        formatter2: optional formatter for imaginary parts; when given the
            elements are treated as complex.

    Returns:
        The string representation, including the surrounding brackets.
    """
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))))

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is None:
            return formatter1.format(val)
        real_str = formatter1.format(val.real)
        imag_str = (formatter2.format(val.imag) + "j").lstrip()
        # Use the rendered sign instead of `val.imag < 0` so that -0.0
        # (which is not < 0 but prints with '-') does not become "+-0.j".
        if imag_str[0] in "+-":
            return real_str + imag_str
        return real_str + "+" + imag_str

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = ([_val_formatter(val) for val in self[:PRINT_OPTS.edgeitems].tolist()] +
                [' ...'] +
                [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()])
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
    lines = [', '.join(line) for line in data_lines]
    return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']'
|
|
|
|
# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    """Recursively render a tensor of any rank with the given formatters.

    Rank-0 and rank-1 tensors are delegated to `_scalar_str` and
    `_vector_str`. Higher ranks render each slice along the first dimension
    one column deeper and join them, eliding the middle slices when
    `summarize` is set and the first dimension is longer than
    ``2 * PRINT_OPTS.edgeitems``.
    """
    rank = self.dim()

    if rank == 0:
        return _scalar_str(self, formatter1, formatter2)

    if rank == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    def render_slice(index):
        return _tensor_str_with_formatter(self[index], indent + 1, summarize, formatter1, formatter2)

    edge = PRINT_OPTS.edgeitems
    if summarize and self.size(0) > 2 * edge:
        leading = [render_slice(i) for i in range(0, edge)]
        trailing = [render_slice(i) for i in range(len(self) - edge, len(self))]
        slices = leading + ['...'] + trailing
    else:
        slices = [render_slice(i) for i in range(0, self.size(0))]

    # Slices are separated by (rank - 1) newlines and aligned under the
    # first slice.
    separator = ',' + '\n' * (rank - 1) + ' ' * (indent + 1)
    return '[' + separator.join(slices) + ']'
|
|
|
|
def _tensor_str(self, indent):
    """Return the bracketed data portion of a tensor's printed form.

    Builds the formatter(s) from the data that will actually be shown
    (the summarized data when the element count exceeds
    ``PRINT_OPTS.threshold``) and delegates to `_tensor_str_with_formatter`.
    """
    if self.numel() == 0:
        return '[]'

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold
    if self.dtype in (torch.float16, torch.bfloat16):
        self = self.float()

    def build_formatter(data):
        return _Formatter(get_summarized_data(data) if summarize else data)

    if not self.dtype.is_complex:
        return _tensor_str_with_formatter(self, indent, summarize, build_formatter(self))

    # Complex tensors get separate formatters for real and imaginary parts.
    real_formatter = build_formatter(self.real)
    imag_formatter = build_formatter(self.imag)
    return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
|
|
|
|
def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    """Append comma-separated suffixes and the closing parenthesis.

    Args:
        tensor_str: the text produced so far.
        suffixes: strings such as ``dtype=...`` to append in order.
        indent: number of spaces placed before a suffix that starts a new line.
        force_newline: when True, the first suffix always starts a new line.

    Returns:
        ``tensor_str`` followed by the suffixes and a closing ``)``. A suffix
        goes onto a new line when it would exceed ``PRINT_OPTS.linewidth``.
    """
    pieces = [tensor_str]
    # NOTE(review): this evaluates as (len - rfind) + 1, which is two more than
    # the actual length of the last line; kept as-is to preserve wrapping.
    current_len = len(tensor_str) - tensor_str.rfind('\n') + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        wrap = force_newline or current_len + suffix_len + 2 > PRINT_OPTS.linewidth
        if wrap:
            pieces.extend((',\n', ' ' * indent, suffix))
            current_len = indent + suffix_len
            # Only the first suffix is forced onto a new line.
            force_newline = False
        else:
            pieces.extend((', ', suffix))
            current_len += suffix_len + 2
    pieces.append(')')
    return ''.join(pieces)
|
|
|
|
|
|
def get_summarized_data(self):
    """Return only the elements that a summarized printout would display.

    Along every dimension longer than ``2 * PRINT_OPTS.edgeitems`` only the
    first and last ``edgeitems`` entries are kept; shorter dimensions are
    kept whole. A zero-dimensional tensor is returned unchanged.
    """
    rank = self.dim()
    if rank == 0:
        return self

    edge = PRINT_OPTS.edgeitems
    too_long = self.size(0) > 2 * edge

    if rank == 1:
        if not too_long:
            return self
        return torch.cat((self[:edge], self[-edge:]))

    if too_long:
        length = len(self)
        rows = [self[i] for i in range(0, edge)]
        rows += [self[i] for i in range(length - edge, length)]
    else:
        rows = self

    # Each kept slice is summarized recursively along its own dimensions.
    return torch.stack([get_summarized_data(row) for row in rows])
|
|
|
|
def _str_intern(self):
    """Build the full ``tensor(...)`` string for a tensor.

    The data portion depends on the tensor kind (sparse, quantized, meta,
    empty, or regular). Suffixes such as device, size, dtype, quantization
    parameters, layout, grad information and names are collected in order
    and attached by `_add_suffixes`.
    """
    prefix = 'tensor('
    indent = len(prefix)
    suffixes = []

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (self.device.type != torch._C._get_default_device()
            or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index)):
        suffixes.append('device=\'' + str(self.device) + '\'')

    # TODO: add an API to map real -> complex dtypes
    default_dtype = torch.get_default_dtype()
    default_complex_dtype = torch.cdouble if default_dtype == torch.double else torch.cfloat
    has_default_dtype = self.dtype in (default_dtype, default_complex_dtype, torch.int64, torch.bool)

    size_suffix = 'size=' + str(tuple(self.shape))
    dtype_suffix = 'dtype=' + str(self.dtype)

    if self.is_sparse:
        suffixes.append(size_suffix)
        suffixes.append('nnz=' + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append(dtype_suffix)

        def render_component(label, component):
            # An empty component additionally reports its shape.
            text = _tensor_str(component, indent + len(label))
            if component.numel() == 0:
                text += ', size=' + str(tuple(component.shape))
            return label + text

        indices_part = render_component('indices=tensor(', self._indices().detach())
        values_part = render_component('values=tensor(', self._values().detach())
        tensor_str = indices_part + '),\n' + ' ' * indent + values_part + ')'
    elif self.is_quantized:
        suffixes.append(size_suffix)
        if not has_default_dtype:
            suffixes.append(dtype_suffix)
        scheme = self.qscheme()
        suffixes.append('quantization_scheme=' + str(scheme))
        if scheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
            suffixes.append('scale=' + str(self.q_scale()))
            suffixes.append('zero_point=' + str(self.q_zero_point()))
        elif scheme in (torch.per_channel_affine, torch.per_channel_symmetric):
            suffixes.append('scale=' + str(self.q_per_channel_scales()))
            suffixes.append('zero_point=' + str(self.q_per_channel_zero_points()))
            suffixes.append('axis=' + str(self.q_per_channel_axis()))
        tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_meta:
        suffixes.append(size_suffix)
        if self.dtype != default_dtype:
            suffixes.append(dtype_suffix)
        # TODO: This implies that ellipses is valid syntax for allocating
        # a meta tensor, which it could be, but it isn't right now
        tensor_str = '...'
    elif self.numel() == 0 and not self.is_sparse:
        # Explicitly print the shape if it is not (0,), to match NumPy behavior
        if self.dim() != 1:
            suffixes.append(size_suffix)

        # In an empty tensor, there are no elements to infer if the dtype
        # should be int64, so it must be shown explicitly.
        if self.dtype != default_dtype:
            suffixes.append(dtype_suffix)
        tensor_str = '[]'
    else:
        if not has_default_dtype:
            suffixes.append(dtype_suffix)

        # Non-strided layouts are printed through their dense form.
        printable = self.to_dense() if self.layout != torch.strided else self
        tensor_str = _tensor_str(printable, indent)

    if self.layout != torch.strided:
        suffixes.append('layout=' + str(self.layout))

    if self.grad_fn is not None:
        grad_fn_name = type(self.grad_fn).__name__
        if grad_fn_name == 'CppFunction':
            grad_fn_name = self.grad_fn.name().rsplit('::', 1)[-1]
        suffixes.append('grad_fn=<{}>'.format(grad_fn_name))
    elif self.requires_grad:
        suffixes.append('requires_grad=True')

    if self.has_names():
        suffixes.append('names={}'.format(self.names))

    return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
|
|
|
|
def _str(self):
    """Return the printable representation of a tensor.

    The string is built by `_str_intern` with gradient tracking disabled.
    """
    with torch.no_grad():
        rendered = _str_intern(self)
    return rendered
|