pytorch/caffe2/python/net_printer_test.py
Thomas Dudziak 47e921ba49 Remove map() and filter() in favor of comprehensions
Summary: In Python 3, map() and filter() return lazy iterators, so many of the usages currently present in Caffe2 (which rely on eager evaluation) would do nothing. This diff simply removes (almost) all usages of these two in Caffe2 and its subprojects in favor of comprehensions, which are also easier to read and understand.

Reviewed By: akyrola

Differential Revision: D5142049

fbshipit-source-id: e800631d2df7d0823fed698cae46c486038007dc
2017-05-30 15:32:58 -07:00

91 lines
2.9 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import net_printer
from caffe2.python.checkpoint import Job
from caffe2.python.net_builder import ops
from caffe2.python.task import Task, final_output
import unittest
def example_loop():
    """Build a Task containing two nested loops that accumulate sums.

    The outer loop is created with ``ops.loop(10)``; the inner loop is
    created from the outer loop's ``iter()`` value. Every inner step
    computes ``outer_iter * 10 + inner_iter`` and adds it to exactly one
    of three bucket accumulators (value >= 80, value >= 50, otherwise)
    as well as to an overall accumulator.

    Returns nothing; callers in this file invoke it while a job context
    is active.
    """
    with Task():
        sum_all = ops.Const(0)
        sum_large = ops.Const(0)
        sum_small = ops.Const(0)
        sum_tiny = ops.Const(0)
        with ops.loop(10) as outer_loop:
            base = ops.Mul([outer_loop.iter(), ops.Const(10)])
            with ops.loop(outer_loop.iter()) as inner_loop:
                value = ops.Add([base, inner_loop.iter()])
                # The condition blobs are named, but the ops are still
                # emitted in the same order as the nested calls would be.
                is_large = ops.GE([value, ops.Const(80)])
                with ops.If(is_large) as branch:
                    ops.Add([sum_large, value], [sum_large])
                is_small = ops.GE([value, ops.Const(50)])
                with branch.Elif(is_small) as branch:
                    ops.Add([sum_small, value], [sum_small])
                with branch.Else():
                    ops.Add([sum_tiny, value], [sum_tiny])
                ops.Add([sum_all, value], sum_all)
def example_task():
    """Build a Task that exercises ``task_init`` and ``task_exit`` sections.

    Returns:
        A 3-tuple of ``final_output`` results, in order: the blob computed
        as ``accum + 1`` before the in-place increment, the blob computed
        inside the ``task_exit`` section, and the blob computed after the
        in-place increment.
    """
    with Task():
        with ops.task_init():
            const_one = ops.Const(1)
        const_two = ops.Add([const_one, const_one])
        with ops.task_init():
            const_three = ops.Const(3)
        running = ops.Add([const_two, const_three])
        # at this point, running is expected to be 5
        with ops.task_exit():
            # this section executes after the statements below, so
            # running is expected to be 6 here
            exit_plus_one = ops.Add([running, const_one])
        before_increment = ops.Add([running, const_one])
        ops.Add([running, const_one], [running])
        after_increment = ops.Add([running, const_one])
        out_six = final_output(before_increment)
        out_seven_exit = final_output(exit_plus_one)
        out_seven_body = final_output(after_increment)
    return out_six, out_seven_exit, out_seven_body
def example_job():
    """Assemble the sample Job shared by the tests in this file.

    ``example_loop`` is called while the job's ``init_group`` is active,
    and ``example_task`` is called directly under the Job context.

    Returns:
        The object bound by ``with Job() as ...``.
    """
    with Job() as sample:
        with sample.init_group:
            example_loop()
        example_task()
    return sample
class TestNetPrinter(unittest.TestCase):
    """Tests for ``net_printer.to_string`` and ``net_printer.analyze``."""

    def test_print(self):
        """Printing the example job produces a non-empty string."""
        # assertGreater reports the actual length on failure.
        self.assertGreater(len(net_printer.to_string(example_job())), 0)

    def test_valid_job(self):
        """Inputs prefixed with ``distributed_ctx_init_`` pass analysis."""
        job = example_job()
        with job:
            with Task():
                # distributed_ctx_init_* ignored by analyzer
                ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
        # Analyze the job that was just extended. A fresh example_job()
        # would not contain the Task added above, so the ignored-prefix
        # behaviour would never be exercised.
        net_printer.analyze(job)

    def test_undefined_blob(self):
        """Reading a blob that was never defined fails with a clear message."""
        job = example_job()
        with job:
            with Task():
                ops.Add(['a', 'b'])
        with self.assertRaises(AssertionError) as e:
            net_printer.analyze(job)
        self.assertEqual("Blob undefined: a", str(e.exception))

    def test_multiple_definition(self):
        """Two Tasks writing the same output blob are rejected."""
        job = example_job()
        with job:
            with Task():
                ops.Add([ops.Const(0), ops.Const(1)], 'out1')
            with Task():
                ops.Add([ops.Const(2), ops.Const(3)], 'out1')
        with self.assertRaises(AssertionError):
            net_printer.analyze(job)