pytorch/caffe2/python/net_builder_test.py
Thomas Dudziak 47e921ba49 Remove map() and filter() in favor of comprehensions
Summary: These return views in Python 3 which would not do anything in a lot of usages currently present in Caffe2. This diff simply removes (almost) all usages of these two in Caffe2 and sub projects in favor of comprehensions which are also easier to read/understand

Reviewed By: akyrola

Differential Revision: D5142049

fbshipit-source-id: e800631d2df7d0823fed698cae46c486038007dc
2017-05-30 15:32:58 -07:00

158 lines
4.7 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import workspace
from caffe2.python.core import Plan, to_execution_step
from caffe2.python.task import Task, final_output
from caffe2.python.net_builder import ops, NetBuilder
from caffe2.python.session import LocalSession
import unittest
def _test_loop():
x = ops.Const(5)
y = ops.Const(0)
with ops.loop():
ops.stop_if(ops.EQ([x, ops.Const(0)]))
ops.Add([x, ops.Const(-1)], [x])
ops.Add([y, ops.Const(1)], [y])
return y
def _test_inner_stop(x):
ops.stop_if(ops.LT([x, ops.Const(5)]))
def _test_outer():
x = ops.Const(10)
# test stop_if(False)
with ops.stop_guard() as g1:
_test_inner_stop(x)
# test stop_if(True)
y = ops.Const(3)
with ops.stop_guard() as g2:
_test_inner_stop(y)
# test no stop
with ops.stop_guard() as g4:
ops.Const(0)
# test empty clause
with ops.stop_guard() as g3:
pass
return (
g1.has_stopped(), g2.has_stopped(), g3.has_stopped(), g4.has_stopped())
def _test_if(x):
y = ops.Const(1)
with ops.If(ops.GT([x, ops.Const(50)])):
ops.Const(2, blob_out=y)
with ops.If(ops.LT([x, ops.Const(50)])):
ops.Const(3, blob_out=y)
ops.stop()
ops.Const(4, blob_out=y)
return y
class TestNetBuilder(unittest.TestCase):
def test_ops(self):
with NetBuilder() as nb:
y = _test_loop()
z, w, a, b = _test_outer()
p = _test_if(ops.Const(75))
q = _test_if(ops.Const(25))
plan = Plan('name')
plan.AddStep(to_execution_step(nb))
ws = workspace.C.Workspace()
ws.run(plan)
expected = [
(y, 5),
(z, False),
(w, True),
(a, False),
(b, False),
(p, 2),
(q, 3),
]
for b, expected in expected:
actual = ws.blobs[str(b)].fetch()
self.assertEquals(actual, expected)
def _expected_loop(self):
total = 0
total_large = 0
total_small = 0
total_tiny = 0
for loop_iter in range(10):
outer = loop_iter * 10
for inner_iter in range(loop_iter):
val = outer + inner_iter
if val >= 80:
total_large += val
elif val >= 50:
total_small += val
else:
total_tiny += val
total += val
return total, total_large, total_small, total_tiny
def _actual_loop(self):
total = ops.Const(0)
total_large = ops.Const(0)
total_small = ops.Const(0)
total_tiny = ops.Const(0)
with ops.loop(10) as loop:
outer = ops.Mul([loop.iter(), ops.Const(10)])
with ops.loop(loop.iter()) as inner:
val = ops.Add([outer, inner.iter()])
with ops.If(ops.GE([val, ops.Const(80)])) as c:
ops.Add([total_large, val], [total_large])
with c.Elif(ops.GE([val, ops.Const(50)])) as c:
ops.Add([total_small, val], [total_small])
with c.Else():
ops.Add([total_tiny, val], [total_tiny])
ops.Add([total, val], total)
return [
final_output(x)
for x in [total, total_large, total_small, total_tiny]
]
def test_loops(self):
with Task() as task:
out_actual = self._actual_loop()
with LocalSession() as session:
session.run(task)
expected = self._expected_loop()
actual = [o.fetch() for o in out_actual]
for e, a in zip(expected, actual):
self.assertEquals(e, a)
def test_setup(self):
with Task() as task:
with ops.task_init():
one = ops.Const(1)
two = ops.Add([one, one])
with ops.task_init():
three = ops.Const(3)
accum = ops.Add([two, three])
# here, accum should be 5
with ops.task_exit():
# here, accum should be 6, since this executes after lines below
seven_1 = ops.Add([accum, one])
six = ops.Add([accum, one])
ops.Add([accum, one], [accum])
seven_2 = ops.Add([accum, one])
o6 = final_output(six)
o7_1 = final_output(seven_1)
o7_2 = final_output(seven_2)
with LocalSession() as session:
session.run(task)
self.assertEquals(o6.fetch(), 6)
self.assertEquals(o7_1.fetch(), 7)
self.assertEquals(o7_2.fetch(), 7)