mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
* Speed up printing of large tensors. Instead of deciding on the format based on all of the elements of the tensor, decide based on the elements that will actually be printed. * Fix flake8. * Add else case.
242 lines
8.3 KiB
Python
242 lines
8.3 KiB
Python
import math
|
|
import torch
|
|
from functools import reduce
|
|
from sys import float_info
|
|
|
|
|
|
class __PrinterOptions(object):
|
|
precision = 4
|
|
threshold = 1000
|
|
edgeitems = 3
|
|
linewidth = 80
|
|
|
|
|
|
PRINT_OPTS = __PrinterOptions()
|
|
SCALE_FORMAT = '{:.5e} *\n'
|
|
|
|
|
|
# We could use **kwargs, but this will give better docs
|
|
def set_printoptions(
|
|
precision=None,
|
|
threshold=None,
|
|
edgeitems=None,
|
|
linewidth=None,
|
|
profile=None,
|
|
):
|
|
r"""Set options for printing. Items shamelessly taken from NumPy
|
|
|
|
Args:
|
|
precision: Number of digits of precision for floating point output
|
|
(default = 8).
|
|
threshold: Total number of array elements which trigger summarization
|
|
rather than full `repr` (default = 1000).
|
|
edgeitems: Number of array items in summary at beginning and end of
|
|
each dimension (default = 3).
|
|
linewidth: The number of characters per line for the purpose of
|
|
inserting line breaks (default = 80). Thresholded matrices will
|
|
ignore this parameter.
|
|
profile: Sane defaults for pretty printing. Can override with any of
|
|
the above options. (any one of `default`, `short`, `full`)
|
|
"""
|
|
if profile is not None:
|
|
if profile == "default":
|
|
PRINT_OPTS.precision = 4
|
|
PRINT_OPTS.threshold = 1000
|
|
PRINT_OPTS.edgeitems = 3
|
|
PRINT_OPTS.linewidth = 80
|
|
elif profile == "short":
|
|
PRINT_OPTS.precision = 2
|
|
PRINT_OPTS.threshold = 1000
|
|
PRINT_OPTS.edgeitems = 2
|
|
PRINT_OPTS.linewidth = 80
|
|
elif profile == "full":
|
|
PRINT_OPTS.precision = 4
|
|
PRINT_OPTS.threshold = float('inf')
|
|
PRINT_OPTS.edgeitems = 3
|
|
PRINT_OPTS.linewidth = 80
|
|
|
|
if precision is not None:
|
|
PRINT_OPTS.precision = precision
|
|
if threshold is not None:
|
|
PRINT_OPTS.threshold = threshold
|
|
if edgeitems is not None:
|
|
PRINT_OPTS.edgeitems = edgeitems
|
|
if linewidth is not None:
|
|
PRINT_OPTS.linewidth = linewidth
|
|
|
|
|
|
def _get_min_log_scale():
|
|
min_positive = float_info.min * float_info.epsilon # get smallest denormal
|
|
if min_positive == 0: # use smallest normal if DAZ/FTZ is set
|
|
min_positive = float_info.min
|
|
return math.ceil(math.log(min_positive, 10))
|
|
|
|
|
|
def _number_format(tensor, min_sz=-1):
|
|
floating_dtype = tensor.dtype.is_floating_point # save this because we cast later
|
|
_min_log_scale = _get_min_log_scale()
|
|
min_sz = max(min_sz, 2)
|
|
tensor = torch.DoubleTensor(tensor.size()).copy_(tensor).abs_().view(tensor.nelement())
|
|
|
|
pos_inf_mask = tensor.eq(float('inf'))
|
|
neg_inf_mask = tensor.eq(float('-inf'))
|
|
nan_mask = tensor.ne(tensor)
|
|
invalid_value_mask = pos_inf_mask + neg_inf_mask + nan_mask
|
|
if invalid_value_mask.all():
|
|
example_value = 0
|
|
else:
|
|
example_value = tensor[invalid_value_mask.eq(0)][0]
|
|
tensor[invalid_value_mask] = example_value
|
|
if invalid_value_mask.any():
|
|
min_sz = max(min_sz, 3)
|
|
|
|
int_mode = True
|
|
# TODO: use fmod?
|
|
for value in tensor.tolist():
|
|
if value != math.ceil(value):
|
|
int_mode = False
|
|
break
|
|
|
|
exp_min = tensor.min()
|
|
if exp_min != 0:
|
|
exp_min = math.floor(math.log10(exp_min)) + 1
|
|
else:
|
|
exp_min = 1
|
|
exp_max = tensor.max()
|
|
if exp_max != 0:
|
|
exp_max = math.floor(math.log10(exp_max)) + 1
|
|
else:
|
|
exp_max = 1
|
|
include_decimal_int_mode = floating_dtype and int_mode
|
|
|
|
scale = 1
|
|
exp_max = int(exp_max)
|
|
prec = PRINT_OPTS.precision
|
|
if int_mode:
|
|
if exp_max > prec + 1:
|
|
format = '{{:11.{}e}}'.format(prec)
|
|
sz = max(min_sz, 7 + prec)
|
|
else:
|
|
sz = max(min_sz, exp_max + 1)
|
|
format = '{:' + str(sz) + '.0f}'
|
|
if include_decimal_int_mode:
|
|
format += '.'
|
|
sz += 1
|
|
else:
|
|
if exp_max - exp_min > prec:
|
|
sz = 7 + prec
|
|
if abs(exp_max) > 99 or abs(exp_min) > 99:
|
|
sz = sz + 1
|
|
sz = max(min_sz, sz)
|
|
format = '{{:{}.{}e}}'.format(sz, prec)
|
|
else:
|
|
if exp_max > prec + 1 or exp_max < 0:
|
|
sz = max(min_sz, 7)
|
|
scale = math.pow(10, max(exp_max - 1, _min_log_scale))
|
|
else:
|
|
if exp_max == 0:
|
|
sz = 7
|
|
else:
|
|
sz = exp_max + 6
|
|
sz = max(min_sz, sz)
|
|
format = '{{:{}.{}f}}'.format(sz, prec)
|
|
return format, scale, sz
|
|
|
|
|
|
def _scalar_str(self, fmt, scale):
|
|
scalar_str = fmt.format(self.item() / scale)
|
|
# The leading space for positives is ugly on scalars, so we strip it
|
|
return scalar_str.lstrip()
|
|
|
|
|
|
def _vector_str(self, indent, fmt, scale, sz, summarize):
|
|
element_length = sz + 3
|
|
elements_per_line = int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
|
|
char_per_line = element_length * elements_per_line
|
|
|
|
if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
|
|
data = ([fmt.format(val / scale) for val in self[:PRINT_OPTS.edgeitems].tolist()] +
|
|
[' ...'] +
|
|
[fmt.format(val / scale) for val in self[-PRINT_OPTS.edgeitems:].tolist()])
|
|
else:
|
|
data = [fmt.format(val) for val in self.tolist()]
|
|
|
|
data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
|
|
lines = [', '.join(line) for line in data_lines]
|
|
return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']'
|
|
|
|
|
|
def _tensor_str(self, indent, fmt, scale, sz, summarize):
|
|
dim = self.dim()
|
|
|
|
if dim == 0:
|
|
return _scalar_str(self, fmt, scale)
|
|
if dim == 1:
|
|
return _vector_str(self, indent, fmt, scale, sz, summarize)
|
|
|
|
if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
|
|
slices = ([_tensor_str(self[i], indent + 1, fmt, scale, sz, summarize)
|
|
for i in range(0, PRINT_OPTS.edgeitems)] +
|
|
['...'] +
|
|
[_tensor_str(self[i], indent + 1, fmt, scale, sz, summarize)
|
|
for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
|
|
else:
|
|
slices = [_tensor_str(self[i], indent + 1, fmt, scale, sz, summarize) for i in range(0, self.size(0))]
|
|
|
|
tensor_str = (',' + '\n' * (dim - 1) + ' ' * (indent + 1)).join(slices)
|
|
return '[' + tensor_str + ']'
|
|
|
|
|
|
def get_summarized_data(self):
|
|
dim = self.dim()
|
|
if dim == 0:
|
|
return self
|
|
if dim == 1:
|
|
if self.size(0) > 2 * PRINT_OPTS.edgeitems:
|
|
return torch.cat((self[:PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems:]))
|
|
else:
|
|
return self
|
|
if self.size(0) > 2 * PRINT_OPTS.edgeitems:
|
|
start = [get_summarized_data(self[i]).view(-1) for i in range(0, PRINT_OPTS.edgeitems)]
|
|
end = ([get_summarized_data(self[i]).view(-1)
|
|
for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
|
|
return torch.cat((start + end))
|
|
else:
|
|
return self
|
|
|
|
|
|
def _str(self):
|
|
if self.is_sparse:
|
|
size_str = str(tuple(self.shape)).replace(' ', '')
|
|
return '{} of size {} with indices:\n{}\nand values:\n{}'.format(
|
|
self.type(), size_str, self._indices(), self._values())
|
|
|
|
prefix = 'tensor('
|
|
indent = len(prefix)
|
|
summarize = self.numel() > PRINT_OPTS.threshold
|
|
|
|
suffix = ')'
|
|
if not torch._C._is_default_type_cuda():
|
|
if self.device.type == 'cuda':
|
|
suffix = ', device=\'' + str(self.device) + '\'' + suffix
|
|
else:
|
|
if self.device.type == 'cpu' or torch.cuda.current_device() != self.device.index:
|
|
suffix = ', device=\'' + str(self.device) + '\'' + suffix
|
|
|
|
if self.numel() == 0:
|
|
# In an empty tensor, there are no elements to infer if the dtype should be int64,
|
|
# so it must be shown explicitly.
|
|
if self.dtype != torch.get_default_dtype():
|
|
suffix = ', dtype=' + str(self.dtype) + suffix
|
|
tensor_str = '[]'
|
|
else:
|
|
if self.dtype != torch.get_default_dtype() and self.dtype != torch.int64:
|
|
suffix = ', dtype=' + str(self.dtype) + suffix
|
|
|
|
fmt, scale, sz = _number_format(get_summarized_data(self) if summarize else self)
|
|
if scale != 1:
|
|
prefix = prefix + SCALE_FORMAT.format(scale) + ' ' * indent
|
|
tensor_str = _tensor_str(self, indent, fmt, scale, sz, summarize)
|
|
|
|
return prefix + tensor_str + suffix
|