from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from caffe2.python import workspace, brew from caffe2.python.model_helper import ModelHelper import unittest import numpy as np class BrewTest(unittest.TestCase): def setUp(self): def myhelper(model, val=-1): return val if not brew.has_helper(myhelper): brew.Register(myhelper) self.myhelper = myhelper def myhelper2(model, val=-1): return val if not brew.has_helper(myhelper2): brew.Register(myhelper2) self.myhelper2 = myhelper2 def test_dropout(self): p = 0.2 X = np.ones((100, 100)).astype(np.float32) - p workspace.FeedBlob("x", X) model = ModelHelper(name="test_model") brew.dropout(model, "x", "out") workspace.RunNetOnce(model.param_init_net) workspace.RunNetOnce(model.net) out = workspace.FetchBlob("out") self.assertLess(abs(out.mean() - (1 - p)), 0.05) def test_fc(self): m, n, k = (15, 15, 15) X = np.random.rand(m, k).astype(np.float32) - 0.5 workspace.FeedBlob("x", X) model = ModelHelper(name="test_model") out = brew.fc(model, "x", "out_1", k, n) out = brew.packed_fc(model, out, "out_2", n, n) out = brew.fc_decomp(model, out, "out_3", n, n) out = brew.fc_prune(model, out, "out_4", n, n) workspace.RunNetOnce(model.param_init_net) workspace.RunNetOnce(model.net) def test_arg_scope(self): myhelper = self.myhelper myhelper2 = self.myhelper2 n = 15 with brew.arg_scope([myhelper], val=n): res = brew.myhelper(None) self.assertEqual(n, res) with brew.arg_scope([myhelper, myhelper2], val=n): res1 = brew.myhelper(None) res2 = brew.myhelper2(None) self.assertEqual([n, n], [res1, res2]) def test_arg_scope_single(self): X = np.random.rand(64, 3, 32, 32).astype(np.float32) - 0.5 workspace.FeedBlob("x", X) model = ModelHelper(name="test_model") with brew.arg_scope( brew.conv, stride=2, pad=2, weight_init=('XavierFill', {}), bias_init=('ConstantFill', {}) ): brew.conv( model=model, blob_in="x", blob_out="out", dim_in=3, dim_out=64, kernel=3, ) workspace.RunNetOnce(model.param_init_net) workspace.RunNetOnce(model.net) out = workspace.FetchBlob("out") self.assertEqual(out.shape, (64, 64, 17, 17)) def test_arg_scope_nested(self): myhelper = self.myhelper n = 16 with brew.arg_scope([myhelper], val=-3), \ brew.arg_scope([myhelper], val=-2): with brew.arg_scope([myhelper], val=n): res = brew.myhelper(None) self.assertEqual(n, res) res = brew.myhelper(None) self.assertEqual(res, -2) res = brew.myhelper(None, val=15) self.assertEqual(res, 15) def test_double_register(self): myhelper = self.myhelper with self.assertRaises(AttributeError): brew.Register(myhelper) def test_has_helper(self): self.assertTrue(brew.has_helper(brew.conv)) self.assertTrue(brew.has_helper("conv")) def myhelper3(): pass self.assertFalse(brew.has_helper(myhelper3))