mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This PR adds TensorBoard logging support natively within PyTorch. It is based on the tensorboardX code developed by lanpa and relies on changes inside the tensorflow/tensorboard repo landing at https://github.com/tensorflow/tensorboard/pull/2065. With these changes users can simply `pip install tensorboard; pip install torch` and then log PyTorch data directly to the TensorBoard protobuf format using

```
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
s1 = torch.rand(1)
writer.add_scalar('data/scalar1', s1[0], 0)
writer.close()
```

Design:
- `EventFileWriter` and `RecordWriter` from tensorboardX now live in tensorflow/tensorboard
- `SummaryWriter` and PyTorch-specific conversion from tensors, nn modules, etc. now live in pytorch/pytorch. We also support Caffe2 blobs and nets.

Action items:
- [x] `from torch.utils.tensorboard import SummaryWriter`
- [x] rename functions
- [x] unittests
- [x] move actual writing function to tensorflow/tensorboard in https://github.com/tensorflow/tensorboard/pull/2065

Review:
- Please review for PyTorch standard formatting, code usage, etc.
- Please verify unittest usage is correct and executing in CI

Any significant changes made here will likely be synced back to github.com/lanpa/tensorboardX/ in the future.

cc orionr, ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16196

Differential Revision: D15062901

Pulled By: orionr

fbshipit-source-id: 3812eb6aa07a2811979c5c7b70810261f9ea169e
63 lines
2.5 KiB
Python
63 lines
2.5 KiB
Python
import os
|
|
import math
|
|
import numpy as np
|
|
from ._convert_np import make_np
|
|
from ._utils import make_grid
|
|
from PIL import Image
|
|
from posixpath import join
|
|
|
|
|
|
def make_tsv(metadata, save_path, metadata_header=None):
    """Write embedding labels to ``metadata.tsv`` inside *save_path*.

    Args:
        metadata: Iterable of labels. Without a header, each item is written
            on its own line via ``str()``. With a header, each item must
            itself be a sequence of column values, which are tab-joined.
        save_path: Existing directory in which ``metadata.tsv`` is created;
            an existing file of that name is overwritten.
        metadata_header: Optional sequence of column names. When non-empty it
            is written as the first (tab-joined) row.

    Raises:
        AssertionError: If a header is given and its length differs from the
            number of columns in the first metadata row.
    """
    # Materialise as a list so tuples and numpy arrays can be concatenated
    # with the header row below (``list + tuple`` raises TypeError, and
    # ``list + ndarray`` would broadcast numerically).
    metadata = list(metadata)
    has_header = metadata_header is not None and len(metadata_header) > 0

    if not has_header:
        lines = [str(x) for x in metadata]
    else:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if metadata and len(metadata_header) != len(metadata[0]):
            raise AssertionError(
                'len of header must be equal to the number of columns in metadata')
        lines = ['\t'.join(str(e) for e in row)
                 for row in [list(metadata_header)] + metadata]

    # Encode explicitly as UTF-8 so non-ASCII labels do not depend on the
    # platform's locale encoding.
    with open(os.path.join(save_path, 'metadata.tsv'), 'w', encoding='utf-8') as f:
        f.write(''.join(line + '\n' for line in lines))
|
|
# The label-image sprite is laid out as a square grid; see https://github.com/tensorflow/tensorboard/issues/44
|
|
|
|
|
|
def make_sprite(label_img, save_path):
    """Tile *label_img* into a square sprite image saved as ``sprite.png``.

    Args:
        label_img: Batch of label images. ``label_img.size(0)`` is used as the
            image count and the batch is passed through ``make_np`` and
            ``make_grid``. Presumably a torch tensor of shape (N, C, H, W)
            with values in [0, 1] -- TODO confirm against callers.
        save_path: Existing directory in which ``sprite.png`` is written.
    """
    # Use ceil(sqrt(N)) columns so the sprite has the dimensions described in
    # https://www.tensorflow.org/get_started/embedding_viz
    nrow = int(math.ceil(label_img.size(0) ** 0.5))
    arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)

    # Pad the grid at the bottom to a square canvas whose side equals the
    # grid width, so the canvas can hold nrow * nrow cells. Presumably
    # make_grid returns a 3-channel image whose height can be smaller than its
    # width when N is not a perfect square -- NOTE(review): confirm.
    # np.zeros (rather than np.ndarray, which leaves memory uninitialised)
    # makes the padding black instead of arbitrary garbage pixels.
    side = arranged_img_CHW.shape[2]
    arranged_augment_square_HWC = np.zeros((side, side, 3))
    arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0)  # CHW -> HWC
    arranged_augment_square_HWC[:arranged_img_HWC.shape[0], :, :] = arranged_img_HWC

    # Scale to 8-bit; clip guards against values outside the expected range.
    im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
    im.save(os.path.join(save_path, 'sprite.png'))
|
|
|
|
|
|
def append_pbtxt(metadata, label_img, save_path, subdir, global_step, tag):
|
|
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
|
|
# step = os.path.split(save_path)[-1]
|
|
f.write('embeddings {\n')
|
|
f.write('tensor_name: "{}:{}"\n'.format(
|
|
tag, str(global_step).zfill(5)))
|
|
f.write('tensor_path: "{}"\n'.format(join(subdir, 'tensors.tsv')))
|
|
if metadata is not None:
|
|
f.write('metadata_path: "{}"\n'.format(
|
|
join(subdir, 'metadata.tsv')))
|
|
if label_img is not None:
|
|
f.write('sprite {\n')
|
|
f.write('image_path: "{}"\n'.format(join(subdir, 'sprite.png')))
|
|
f.write('single_image_dim: {}\n'.format(label_img.size(3)))
|
|
f.write('single_image_dim: {}\n'.format(label_img.size(2)))
|
|
f.write('}\n')
|
|
f.write('}\n')
|
|
|
|
|
|
def make_mat(matlist, save_path):
    """Dump an embedding matrix to ``tensors.tsv`` inside *save_path*.

    Every row of *matlist* becomes one tab-separated line. Each element is
    converted with ``.item()`` and then ``str()`` before being written, so
    elements must expose ``.item()`` (e.g. numpy scalars or tensor elements).
    An existing ``tensors.tsv`` is overwritten.
    """
    target = os.path.join(save_path, 'tensors.tsv')
    with open(target, 'w') as out:
        for row in matlist:
            cells = (str(value.item()) for value in row)
            out.write('\t'.join(cells) + '\n')
|