from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from caffe2.python import net_printer from caffe2.python.checkpoint import Job from caffe2.python.net_builder import ops from caffe2.python.task import Task, final_output import unittest def example_loop(): with Task(): total = ops.Const(0) total_large = ops.Const(0) total_small = ops.Const(0) total_tiny = ops.Const(0) with ops.loop(10) as loop: outer = ops.Mul([loop.iter(), ops.Const(10)]) with ops.loop(loop.iter()) as inner: val = ops.Add([outer, inner.iter()]) with ops.If(ops.GE([val, ops.Const(80)])) as c: ops.Add([total_large, val], [total_large]) with c.Elif(ops.GE([val, ops.Const(50)])) as c: ops.Add([total_small, val], [total_small]) with c.Else(): ops.Add([total_tiny, val], [total_tiny]) ops.Add([total, val], total) def example_task(): with Task(): with ops.task_init(): one = ops.Const(1) two = ops.Add([one, one]) with ops.task_init(): three = ops.Const(3) accum = ops.Add([two, three]) # here, accum should be 5 with ops.task_exit(): # here, accum should be 6, since this executes after lines below seven_1 = ops.Add([accum, one]) six = ops.Add([accum, one]) ops.Add([accum, one], [accum]) seven_2 = ops.Add([accum, one]) o6 = final_output(six) o7_1 = final_output(seven_1) o7_2 = final_output(seven_2) return o6, o7_1, o7_2 def example_job(): with Job() as job: with job.init_group: example_loop() example_task() return job class TestNetPrinter(unittest.TestCase): def test_print(self): self.assertTrue(len(net_printer.to_string(example_job())) > 0) def test_valid_job(self): job = example_job() with job: with Task(): # distributed_ctx_init_* ignored by analyzer ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b']) net_printer.analyze(example_job()) def test_undefined_blob(self): job = example_job() with job: with Task(): ops.Add(['a', 'b']) with self.assertRaises(AssertionError): net_printer.analyze(job) def test_multiple_definition(self): job = example_job() with job: with Task(): ops.Add([ops.Const(0), ops.Const(1)], 'out1') with Task(): ops.Add([ops.Const(2), ops.Const(3)], 'out1') with self.assertRaises(AssertionError): net_printer.analyze(job)