pytorch/caffe2/python/experiment_util.py
Aapo Kyrola 42279a610c use Pieter-MPI and fb.distributed
Summary:
Remove MPI and use fb.distributed rendezvous and Pieter's new Ops.

One can now pass a 'rendezvous' struct to data_parallel_model to initiate distributed SyncSGD. The provided rendezvous implementation uses the kv-store handler of fb.distributed to disseminate information about the other hosts. We can easily add other rendezvous implementations, such as a file-based one, but that is a topic for another diff.

Removing MPI also allowed simplifying the Xray startup scripts, which are included in this diff.

Once accepted, I will work on simple example code so others can use this as well. The Flow implementation will be a topic for next week.

Differential Revision: D4180012

fbshipit-source-id: 9e74f1fb43eaf7d4bb3e5ac6718d76bef2dfd731
2016-11-29 15:18:36 -08:00


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import datetime
import time
import logging
from collections import OrderedDict
'''
Utilities for logging experiment run stats, such as accuracy
and loss over time for different runs. Runtime arguments are stored
in the log.
'''
class ModelTrainerLog():

    def __init__(self, expname, runtime_args):
        now = datetime.datetime.fromtimestamp(time.time())
        self.experiment_id = now.strftime('%Y%m%d_%H%M%S')
        self.filename = "%s_%s.log" % (expname, self.experiment_id)
        self.logstr("# %s" % str(runtime_args))
        self.headers = None
        self.start_time = time.time()
        self.last_time = self.start_time
        self.last_input_count = 0

    def logstr(self, line):
        # The 'with' block closes the file automatically; no explicit
        # close() is needed. Parameter renamed from 'str' to avoid
        # shadowing the builtin.
        with open(self.filename, "a") as f:
            f.write(line + "\n")
        logging.getLogger("experiment_logger").info(line)
    def log(self, input_count, batch_count, additional_values):
        logdict = OrderedDict()
        delta_t = time.time() - self.last_time
        delta_count = input_count - self.last_input_count
        self.last_time = time.time()
        self.last_input_count = input_count

        logdict['time'] = time.time() - self.start_time
        logdict['input_counter'] = input_count
        logdict['batch_count'] = batch_count
        if delta_t > 0:
            logdict['inputs_per_sec'] = delta_count / delta_t
        else:
            logdict['inputs_per_sec'] = 0.0

        for k in sorted(additional_values.keys()):
            logdict[k] = additional_values[k]

        # Write the headers if they have not been written yet
        if self.headers is None:
            # list(...) instead of keys()[:] so this works on Python 3,
            # where dict.keys() returns a view rather than a list.
            self.headers = list(logdict.keys())
            self.logstr(",".join(self.headers))
        self.logstr(",".join([str(v) for v in logdict.values()]))
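Since `log()` emits a comma-joined header line followed by comma-joined value lines, the resulting log file can be read back with the standard `csv` module. A minimal sketch; the field names and values below are illustrative sample data, not output from a real run:

```python
import csv
import io

# Illustrative contents of a log file produced by ModelTrainerLog.log():
# first line is the comma-joined headers, subsequent lines are values.
# 'loss' stands in for a key passed via additional_values.
sample_log = (
    "time,input_counter,batch_count,inputs_per_sec,loss\n"
    "1.5,256,4,170.7,2.31\n"
    "3.0,512,8,170.7,1.98\n"
)

# DictReader maps each value row to the header names, so metrics can be
# recovered by column name (all values come back as strings).
rows = list(csv.DictReader(io.StringIO(sample_log)))
for row in rows:
    print(row["input_counter"], row["loss"])
```

Note that any `# ...` comment line written by `__init__` for the runtime arguments would need to be skipped before parsing, since `DictReader` treats the first line it sees as the header.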