mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Continuation after https://github.com/pytorch/pytorch/pull/90163. Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators): _Edit:_ I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script. ``` python import ast import os import docstring_parser for root, dirs, files in os.walk('.'): for name in files: if root.startswith("./.git/") or root.startswith("./third_party/"): continue if name.endswith(".py"): full_name = os.path.join(root, name) with open(full_name, "r") as source: tree = ast.parse(source.read()) for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): all_node_args = node.args.args if node.args.vararg is not None: all_node_args.append(node.args.vararg) if node.args.kwarg is not None: all_node_args.append(node.args.kwarg) if node.args.posonlyargs is not None: all_node_args.extend(node.args.posonlyargs) if node.args.kwonlyargs is not None: all_node_args.extend(node.args.kwonlyargs) args = [a.arg for a in all_node_args] docstring = docstring_parser.parse(ast.get_docstring(node)) doc_args = [a.arg_name for a in docstring.params] clean_doc_args = [] for a in doc_args: clean_a = "" for c in a.split()[0]: if c.isalnum() or c == '_': clean_a += c if clean_a: clean_doc_args.append(clean_a) doc_args = clean_doc_args for a in doc_args: if a not in args: print(full_name, node.lineno, args, doc_args) break ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505 Approved by: https://github.com/malfet, https://github.com/ZainRizvi
128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
import numpy as np
|
|
|
|
|
|
# Functions for converting matplotlib figures, videos and image arrays into numpy image layouts
|
|
def figure_to_image(figures, close=True):
    """Render matplotlib figure(s) to numpy format.

    Note that this requires the ``matplotlib`` package.

    Args:
        figures (matplotlib.pyplot.figure or list of figures): figure or a list
            (or tuple) of figures
        close (bool): Flag to automatically close the figure(s) after rendering;
            the figure is closed even if rendering fails

    Returns:
        numpy.array: ``uint8`` RGB image in [CHW] order for a single figure, or
        the stacked images in [NCHW] order for a list/tuple of figures
    """
    import matplotlib.pyplot as plt
    import matplotlib.backends.backend_agg as plt_backend_agg

    def render_to_rgb(figure):
        try:
            canvas = plt_backend_agg.FigureCanvasAgg(figure)
            canvas.draw()
            # buffer_rgba() exposes the rendered RGBA pixels; np.asarray yields an
            # (H, W, 4) uint8 array whose shape comes from the buffer itself, so
            # no separate width/height query is needed to reshape it.
            image_hwc = np.asarray(canvas.buffer_rgba())[:, :, 0:3]
            return np.moveaxis(image_hwc, source=2, destination=0)
        finally:
            # Release the figure even when drawing raises, so a failed render
            # does not leave an open pyplot figure behind.
            if close:
                plt.close(figure)

    if isinstance(figures, (list, tuple)):
        return np.stack([render_to_rgb(figure) for figure in figures])
    return render_to_rgb(figures)
|
|
|
|
|
|
def _prepare_video(V):
    """
    Convert a 5D array into a 4D array of video frames.

    The input layout is [batchsize, time(frame), channel(color), height, width];
    the output layout is [time(frame), new_height, new_width, channel].

    The images of a batch are tiled into a grid, which forms one frame.
    A batch whose size is not a power of two is first zero-padded up to the
    next power of two. e.g. Video with batchsize 16 will have a 4x4 grid.

    ``uint8`` input is rescaled to ``float32`` in [0, 1]; for any other dtype
    the input dtype is preserved (the padding uses the same dtype as the data).
    """
    b, t, c, h, w = V.shape

    if V.dtype == np.uint8:
        V = V.astype(np.float32) / 255.0

    def is_power2(num):
        return num != 0 and ((num & (num - 1)) == 0)

    # pad to nearest power of 2, all at once; the padding matches V's dtype so
    # the result dtype does not depend on whether the batch needed padding
    if not is_power2(b):
        len_addition = 2 ** b.bit_length() - b
        padding = np.zeros(shape=(len_addition, t, c, h, w), dtype=V.dtype)
        V = np.concatenate((V, padding), axis=0)

    # The row count is derived from the original (unpadded) batch size; the
    # columns absorb the remaining (padded) images.
    n_rows = 2 ** ((b.bit_length() - 1) // 2)
    n_cols = V.shape[0] // n_rows

    # (rows, cols, t, c, h, w) -> (t, rows, h, cols, w, c) -> (t, rows*h, cols*w, c)
    V = np.reshape(V, (n_rows, n_cols, t, c, h, w))
    V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
    V = np.reshape(V, (t, n_rows * h, n_cols * w, c))

    return V
|
|
|
|
|
|
def make_grid(I, ncols=8):
    """Tile a batch of images into a single CHW image grid.

    Args:
        I (numpy.ndarray): images in N1HW or N3HW layout; single-channel images
            are replicated to three channels.
        ncols (int): maximum number of images per grid row.

    Returns:
        numpy.ndarray: a ``(3, H * nrows, W * ncols)`` canvas with the same dtype
        as ``I``. Images fill the grid row by row; unused cells stay zero. An
        empty batch yields a ``(3, 0, 0)`` canvas.
    """
    assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here"
    # Check the rank before reading I.shape[1], so malformed input fails this
    # assertion rather than raising an IndexError.
    assert I.ndim == 4 and I.shape[1] in (1, 3)
    if I.shape[1] == 1:
        I = np.concatenate([I, I, I], 1)

    nimg, _, H, W = I.shape
    ncols = min(nimg, ncols)
    if nimg == 0:
        # Nothing to tile; avoid the division by zero below.
        return np.zeros((3, 0, 0), dtype=I.dtype)
    nrows = (nimg + ncols - 1) // ncols  # integer ceiling division

    canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype)
    for i in range(nimg):
        y, x = divmod(i, ncols)
        canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i]
    return canvas
|
|
|
|
# if modality == 'IMG':
|
|
# if x.dtype == np.uint8:
|
|
# x = x.astype(np.float32) / 255.0
|
|
|
|
|
|
def convert_to_HWC(tensor, input_format):  # tensor: numpy array
    """Rearrange a numpy image array into HWC layout.

    Args:
        tensor (numpy.ndarray): image data whose axes are named by
            ``input_format``.
        input_format (str): one letter per axis of ``tensor`` (case-insensitive).
            Supported layouts are permutations of ``NCHW`` (4D), ``HWC`` (3D)
            and ``HW`` (2D).

    Returns:
        numpy.ndarray: the image in HWC order. A 4D batch is tiled into one
        image with ``make_grid``; a 3D input with a single channel and any 2D
        input are replicated to three channels.

    Raises:
        AssertionError: if ``input_format`` repeats a letter, does not match the
            rank of ``tensor``, or uses letters outside the supported layout.
        ValueError: if ``input_format`` has a length other than 2, 3 or 4.
    """
    # Normalise case first so that e.g. "Hh" is caught as a duplicate.
    fmt = input_format.upper()
    assert len(set(fmt)) == len(fmt), (
        "You can not use the same dimension shorthand twice. "
        f"input_format: {input_format}"
    )
    assert len(tensor.shape) == len(fmt), (
        "size of input tensor and input format are different. "
        f"tensor shape: {tensor.shape}, input_format: {input_format}"
    )

    target = {4: "NCHW", 3: "HWC", 2: "HW"}.get(len(fmt))
    if target is None:
        # Previously this fell through and returned None silently.
        raise ValueError(
            f"Unsupported input_format {input_format!r}: expected a permutation "
            "of 'NCHW', 'HWC' or 'HW'"
        )
    # Without this check str.find would return -1 for a missing letter, which
    # silently aliases the last axis.
    assert set(fmt) == set(target), (
        f"input_format {input_format!r} must be a permutation of {target!r}"
    )
    tensor = tensor.transpose([fmt.index(c) for c in target])

    if len(fmt) == 4:
        tensor_CHW = make_grid(tensor)
        return tensor_CHW.transpose(1, 2, 0)

    if len(fmt) == 3:
        if tensor.shape[2] == 1:
            tensor = np.concatenate([tensor, tensor, tensor], 2)
        return tensor

    return np.stack([tensor, tensor, tensor], 2)
|