mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: These return views in Python 3 which would not do anything in a lot of usages currently present in Caffe2. This diff simply removes (almost) all usages of these two in Caffe2 and sub projects in favor of comprehensions which are also easier to read/understand Reviewed By: akyrola Differential Revision: D5142049 fbshipit-source-id: e800631d2df7d0823fed698cae46c486038007dc
91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import net_printer
|
|
from caffe2.python.checkpoint import Job
|
|
from caffe2.python.net_builder import ops
|
|
from caffe2.python.task import Task, final_output
|
|
import unittest
|
|
|
|
|
|
def example_loop():
    """Build a Task containing a nested loop that buckets values by threshold.

    The outer loop is ``ops.loop(10)``; for each outer iteration the inner
    loop is ``ops.loop(loop.iter())`` (presumably running as many times as the
    current outer index -- confirm against net_builder). Each inner step
    computes ``val = outer_iter * 10 + inner_iter`` and adds it to exactly one
    of three bucket accumulators:

    * ``total_large`` when ``val >= 80``
    * ``total_small`` when ``val >= 50`` (and not ``>= 80``)
    * ``total_tiny`` otherwise

    Every ``val`` is also added to ``total``.

    Nothing is returned. The ops are recorded into the ``Task`` created here,
    which attaches to whatever job/group context is active at call time.
    """
    with Task():
        total = ops.Const(0)
        total_large = ops.Const(0)
        total_small = ops.Const(0)
        total_tiny = ops.Const(0)
        with ops.loop(10) as loop:
            outer = ops.Mul([loop.iter(), ops.Const(10)])
            with ops.loop(loop.iter()) as inner:
                val = ops.Add([outer, inner.iter()])
                with ops.If(ops.GE([val, ops.Const(80)])) as c:
                    ops.Add([total_large, val], [total_large])
                # Rebinding `c` here so the Else below chains off this Elif.
                with c.Elif(ops.GE([val, ops.Const(50)])) as c:
                    ops.Add([total_small, val], [total_small])
                with c.Else():
                    ops.Add([total_tiny, val], [total_tiny])
                # In-place accumulate; output given as a list for
                # consistency with the bucket accumulations above.
                ops.Add([total, val], [total])
|
|
|
|
|
|
def example_task():
    """Build a Task that exercises ``ops.task_init`` / ``ops.task_exit`` ordering.

    Returns:
        A 3-tuple of ``final_output`` handles, in this order:

        * the body-time ``acc + 1``, taken before ``acc`` is bumped;
        * the ``task_exit``-time ``acc + 1``;
        * the body-time ``acc + 1``, taken after ``acc`` is bumped.

        The original names (six / seven / seven) suggest expected values of
        6, 7 and 7.
    """
    with Task():
        with ops.task_init():
            unit = ops.Const(1)
        pair = ops.Add([unit, unit])
        with ops.task_init():
            trio = ops.Const(3)
        acc = ops.Add([pair, trio])
        # At this point acc is expected to hold 5 (2 + 3).
        with ops.task_exit():
            # Exit ops run after the statements that follow this block, so acc
            # has already been incremented to 6 when this executes.
            exit_seven = ops.Add([acc, unit])
        body_six = ops.Add([acc, unit])
        ops.Add([acc, unit], [acc])
        body_seven = ops.Add([acc, unit])
        # The generator preserves the original final_output call order.
        handles = tuple(
            final_output(blob) for blob in (body_six, exit_seven, body_seven)
        )
    return handles
|
|
|
|
|
|
def example_job():
    """Construct and return a sample ``Job``.

    ``example_loop`` is built inside ``job.init_group``. ``example_task`` is
    built afterwards in the job's own context, outside the init group.
    """
    with Job() as built:
        with built.init_group:
            example_loop()
        example_task()
    return built
|
|
|
|
|
|
class TestNetPrinter(unittest.TestCase):
    """Tests for ``net_printer.to_string`` and ``net_printer.analyze``."""

    def test_print(self):
        """Printing the example job yields a non-empty string."""
        self.assertGreater(len(net_printer.to_string(example_job())), 0)

    def test_valid_job(self):
        """Inputs prefixed ``distributed_ctx_init_`` do not trip the analyzer."""
        job = example_job()
        with job:
            with Task():
                # distributed_ctx_init_* ignored by analyzer
                ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
        # Analyze the job that actually contains the extra Task. Analyzing a
        # fresh example_job() would never exercise the ignored prefixes.
        net_printer.analyze(job)

    def test_undefined_blob(self):
        """Reading a blob that nothing defines raises AssertionError."""
        job = example_job()
        with job:
            with Task():
                ops.Add(['a', 'b'])
        with self.assertRaises(AssertionError) as e:
            net_printer.analyze(job)
        self.assertEqual("Blob undefined: a", str(e.exception))

    def test_multiple_definition(self):
        """Two Tasks writing the same output blob make analyze() raise."""
        job = example_job()
        with job:
            with Task():
                ops.Add([ops.Const(0), ops.Const(1)], 'out1')
            with Task():
                ops.Add([ops.Const(2), ops.Const(3)], 'out1')
        with self.assertRaises(AssertionError):
            net_printer.analyze(job)
|