mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: The observer framework can now be used from Python, along with a short write-up on how to use it. This is D6035393 with a fix for ct-scan. Reviewed By: salexspb. Differential Revision: D6066380. fbshipit-source-id: 896c4c580d4387240b81ac2dbbc43db51d4bfeb9
34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import numpy as np
|
|
import unittest
|
|
|
|
from caffe2.python import model_helper, brew
|
|
import caffe2.python.workspace as ws
|
|
|
|
|
|
class TestObservers(unittest.TestCase):
    """Exercise attaching and detaching a net observer from Python."""

    def setUp(self):
        """Build and create a one-layer FC net in a fresh workspace.

        The net maps blob "data" (fed as a float32 zero vector of length 4)
        to blob "y" through a fully-connected layer with dim_in=4,
        dim_out=2, weights filled with 1.0 and bias filled with 0.0.
        """
        ws.ResetWorkspace()
        self.model = model_helper.ModelHelper()
        brew.fc(self.model, "data", "y",
                dim_in=4, dim_out=2,
                weight_init=('ConstantFill', dict(value=1.0)),
                bias_init=('ConstantFill', dict(value=0.0)),
                axis=0)
        ws.FeedBlob("data", np.zeros([4], dtype='float32'))

        # Initialize the parameters once, then create the main net up front;
        # the test attaches observers to it and runs it with ws.RunNet.
        ws.RunNetOnce(self.model.param_init_net)
        ws.CreateNet(self.model.net)

    def testObserver(self):
        """Attach a TimeObserver, run the net, then remove the observer.

        Checks that the observer reports a non-negative average time and
        that removing it decreases the net's observer count by exactly one.
        """
        ob = self.model.net.AddObserver("TimeObserver")
        ws.RunNet(self.model.net)

        # Previously this value was only printed; check it instead.
        # NOTE(review): assumes average_time() returns a real number.
        self.assertGreaterEqual(ob.average_time(), 0)

        num = self.model.net.NumObservers()
        # At least the observer added above must be registered.
        self.assertGreaterEqual(num, 1)

        self.model.net.RemoveObserver(ob)
        # Use a unittest assertion rather than a bare `assert`, which is
        # stripped under `python -O` and reports nothing useful on failure.
        self.assertEqual(self.model.net.NumObservers(), num - 1)
|