"""Mint: a small Flask app, served by Tornado, that plots text dumps with nvd3.

Files are read from the folder given by ``--root``. Names ending in
``summary`` are plotted as min/max/mean/std curves; names ending in ``log``
are plotted as one curve per column.
"""

import argparse
import glob
import os
import sys

import flask
import numpy as np
import nvd3
import tornado.httpserver
import tornado.ioloop
import tornado.wsgi

__folder__ = os.path.abspath(os.path.dirname(__file__))

app = flask.Flask(
    __name__,
    template_folder=os.path.join(__folder__, "templates"),
    static_folder=os.path.join(__folder__, "static")
)
# Parsed command-line options; filled in by main() before the server starts.
args = None

# NOTE(review): these tag literals were reconstructed (the markup had been
# stripped from this file) -- confirm they match what python-nvd3 emits.
_SCRIPT_OPEN = '<script>'
_SCRIPT_CLOSE = '</script>'


def _json_message(message):
    """Return a JSON response carrying only a text message and no script."""
    return flask.jsonify(result=message, script='')


def _sample_step(num_points):
    """Return the positive integer stride used to subsample ``num_points`` rows.

    A negative ``args.sample`` means "about that many points in total"; a
    non-negative value is used as the stride itself. The result is clamped to
    at least 1 so that ``np.arange`` never receives a zero step.
    """
    if args.sample < 0:
        # Floor division keeps the stride an int on Python 3 as well.
        step = num_points // -args.sample
    else:
        step = args.sample
    return max(int(step), 1)


def _new_line_chart(name):
    """Create an nvd3 line chart using the configured chart height."""
    return nvd3.lineChart(
        name=name,
        height=args.chart_height,
        y_axis_format='.03g'
    )


def _add_series(chart, x, y, name):
    """Add one curve to ``chart``, passing plain Python lists to nvd3."""
    # tolist() turns NumPy scalars into native ints/floats before the chart
    # library gets to serialize them.
    chart.add_serie(
        x=np.asarray(x).tolist(),
        y=np.asarray(y).tolist(),
        name=name
    )


def jsonify_nvd3(chart):
    """Build ``chart`` and return its container and script as a JSON response.

    Args:
        chart: an nvd3 chart object with all series already added.

    Returns:
        A Flask JSON response with keys ``result`` (``chart.container``) and
        ``script`` (the text between the script tags of
        ``chart.htmlcontent``, or an empty string if no such tags are found).
    """
    chart.buildcontent()
    # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
    # and the script part. Luckily, it seems to be the case that the HTML part
    # is only a <div>, which can be accessed by chart.container, while the
    # script part occupies the rest of the html content, which we can then
    # locate by searching chart.htmlcontent for the <script> tags.
    html = chart.htmlcontent
    open_at = html.find(_SCRIPT_OPEN)
    script = ''
    if open_at >= 0:
        script_start = open_at + len(_SCRIPT_OPEN)
        script_end = html.find(_SCRIPT_CLOSE, script_start)
        if script_end >= 0:
            script = html[script_start:script_end].strip()
    return flask.jsonify(result=chart.container, script=script)


def visualize_summary(filename):
    """Plot a summary file whose columns are min, max, mean and std.

    Args:
        filename: path of a text file readable by ``np.loadtxt``.

    Returns:
        A Flask JSON response: the chart on success, or a message with an
        empty script if the file cannot be loaded or has too few columns.
    """
    try:
        # ndmin=2 keeps a single-row file as shape (1, columns).
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        return _json_message(
            'Cannot load file {}: {}'.format(filename, str(e))
        )
    if data.shape[1] < 4:
        return _json_message(
            'Cannot load file {}: expected 4 columns, got {}'.format(
                filename, data.shape[1]
            )
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = _new_line_chart(chart_name + '_summary_chart')
    xdata = np.arange(0, data.shape[0], _sample_step(data.shape[0]))
    # Columns 0..3 are read as min, max, mean and std respectively.
    sampled = data[xdata]
    _add_series(chart, xdata, sampled[:, 0], 'min')
    _add_series(chart, xdata, sampled[:, 1], 'max')
    _add_series(chart, xdata, sampled[:, 2], 'mean')
    _add_series(chart, xdata, sampled[:, 2] + sampled[:, 3], 'm+std')
    _add_series(chart, xdata, sampled[:, 2] - sampled[:, 3], 'm-std')
    return jsonify_nvd3(chart)


def visualize_print_log(filename):
    """Plot a log file with one curve per column.

    A one-dimensional file is treated as a single column. For a single curve
    the minimum and maximum over each sampling window are plotted as well;
    otherwise at most ``args.max_curves`` columns are plotted.

    Args:
        filename: path of a text file readable by ``np.loadtxt``.

    Returns:
        A Flask JSON response: the chart on success, or a message with an
        empty script if the file cannot be loaded.
    """
    try:
        # atleast_1d covers a file holding a single number (0-d result).
        data = np.atleast_1d(np.loadtxt(filename))
        if data.ndim == 1:
            data = data[:, np.newaxis]
    except Exception as e:
        return _json_message(
            'Cannot load file {}: {}'.format(filename, str(e))
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = _new_line_chart(chart_name + '_log_chart')
    num_rows = data.shape[0]
    step = _sample_step(num_rows)
    xdata = np.arange(0, num_rows, step)
    if data.shape[1] == 1:
        # With a single curve, also show the min and max of every window of
        # `step` consecutive rows; trailing rows that do not fill a whole
        # window are dropped from these two curves.
        trunc_size = num_rows // step
        running_mat = data[:trunc_size * step].reshape((trunc_size, step))
        _add_series(
            chart, xdata[:trunc_size], running_mat.min(axis=1), 'running_min'
        )
        _add_series(
            chart, xdata[:trunc_size], running_mat.max(axis=1), 'running_max'
        )
        _add_series(chart, xdata, data[xdata, 0], chart_name)
    else:
        # One curve per column, capped at args.max_curves.
        for i in range(min(data.shape[1], args.max_curves)):
            _add_series(
                chart, xdata, data[xdata, i],
                '{}[{}]'.format(chart_name, i)
            )
    return jsonify_nvd3(chart)


def visualize_file(filename):
    """Dispatch ``filename`` (relative to ``args.root``) by its name suffix.

    Returns:
        The JSON response of the matching visualizer, or a JSON message with
        an empty script when the suffix is neither ``summary`` nor ``log``.
    """
    fullname = os.path.join(args.root, filename)
    if filename.endswith('summary'):
        return visualize_summary(fullname)
    if filename.endswith('log'):
        return visualize_print_log(fullname)
    return flask.jsonify(
        result='Unsupport file: {}'.format(filename),
        script=''
    )


@app.route('/')
def index():
    """Render the index page listing every ``*.*`` file in the root folder."""
    files = sorted(glob.glob(os.path.join(args.root, "*.*")))
    names = [os.path.basename(f) for f in files]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names
    )


# NOTE(review): the URL variable was reconstructed (it had been stripped from
# this file); the view function requires a `name` argument.
@app.route('/visualization/<name>')
def visualization(name):
    """Return the chart JSON for the file called ``name``."""
    return visualize_file(name)


def main(argv):
    """Parse ``argv``, store the options globally and run the server.

    Args:
        argv: command-line arguments without the program name.
    """
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p', '--port', type=int, default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r', '--root', type=str, default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves', type=int, default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height', type=int, default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s', '--sample', type=int, default=-200,
        help="Sample every given number of data points. A negative "
             "number means the total points we will sample on the "
             "whole curve. Default 200 points."
    )
    global args
    args = parser.parse_args(argv)
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    # Blocks until the IOLoop is stopped.
    tornado.ioloop.IOLoop.current().start()


if __name__ == '__main__':
    main(sys.argv[1:])