mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: makes life a bit easier Reviewed By: xianjiec Differential Revision: D4514640 fbshipit-source-id: b39f9cb05d31d2e5fa957bc072cf18eda13cff89
156 lines
5.3 KiB
Python
156 lines
5.3 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import schema
|
|
import numpy as np
|
|
|
|
import unittest
|
|
import pickle
|
|
|
|
|
|
class TestDB(unittest.TestCase):
|
|
def testPicklable(self):
|
|
s = schema.Struct(
|
|
('field1', schema.Scalar(dtype=np.int32)),
|
|
('field2', schema.List(schema.Scalar(dtype=str)))
|
|
)
|
|
s2 = pickle.loads(pickle.dumps(s))
|
|
for r in (s, s2):
|
|
self.assertTrue(isinstance(r.field1, schema.Scalar))
|
|
self.assertTrue(isinstance(r.field2, schema.List))
|
|
self.assertTrue(getattr(r, 'non_existent', None) is None)
|
|
|
|
def testNormalizeField(self):
|
|
s = schema.Struct(('field1', np.int32), ('field2', str))
|
|
self.assertEquals(
|
|
s,
|
|
schema.Struct(
|
|
('field1', schema.Scalar(dtype=np.int32)),
|
|
('field2', schema.Scalar(dtype=str))
|
|
)
|
|
)
|
|
|
|
def testTuple(self):
|
|
s = schema.Tuple(np.int32, str, np.float32)
|
|
s2 = schema.Struct(
|
|
('field_0', schema.Scalar(dtype=np.int32)),
|
|
('field_1', schema.Scalar(dtype=np.str)),
|
|
('field_2', schema.Scalar(dtype=np.float32))
|
|
)
|
|
self.assertEquals(s, s2)
|
|
self.assertEquals(s[0], schema.Scalar(dtype=np.int32))
|
|
self.assertEquals(s[1], schema.Scalar(dtype=np.str))
|
|
self.assertEquals(s[2], schema.Scalar(dtype=np.float32))
|
|
self.assertEquals(
|
|
s[2, 0],
|
|
schema.Struct(
|
|
('field_2', schema.Scalar(dtype=np.float32)),
|
|
('field_0', schema.Scalar(dtype=np.int32)),
|
|
)
|
|
)
|
|
# test iterator behavior
|
|
for i, (v1, v2) in enumerate(zip(s, s2)):
|
|
self.assertEquals(v1, v2)
|
|
self.assertEquals(s[i], v1)
|
|
self.assertEquals(s2[i], v1)
|
|
|
|
def testRawTuple(self):
|
|
s = schema.RawTuple(2)
|
|
self.assertEquals(
|
|
s, schema.Struct(
|
|
('field_0', schema.Scalar()), ('field_1', schema.Scalar())
|
|
)
|
|
)
|
|
self.assertEquals(s[0], schema.Scalar())
|
|
self.assertEquals(s[1], schema.Scalar())
|
|
|
|
def testStructIndexing(self):
|
|
s = schema.Struct(
|
|
('field1', schema.Scalar(dtype=np.int32)),
|
|
('field2', schema.List(schema.Scalar(dtype=str)))
|
|
)
|
|
self.assertEquals(s['field2'], s.field2)
|
|
self.assertEquals(s['field2'], schema.List(schema.Scalar(dtype=str)))
|
|
self.assertEquals(
|
|
s['field2', 'field1'],
|
|
schema.Struct(
|
|
('field2', schema.List(schema.Scalar(dtype=str))),
|
|
('field1', schema.Scalar(dtype=np.int32)),
|
|
)
|
|
)
|
|
|
|
def testPreservesMetadata(self):
|
|
s = schema.Struct(
|
|
('a', schema.Scalar(np.float32)), (
|
|
'b', schema.Scalar(
|
|
np.int32,
|
|
metadata=schema.Metadata(categorical_limit=5)
|
|
)
|
|
), (
|
|
'c', schema.List(
|
|
schema.Scalar(
|
|
np.int32,
|
|
metadata=schema.Metadata(categorical_limit=6)
|
|
)
|
|
)
|
|
)
|
|
)
|
|
# attach metadata to lengths field
|
|
s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))
|
|
|
|
self.assertEqual(None, s.a.metadata)
|
|
self.assertEqual(5, s.b.metadata.categorical_limit)
|
|
self.assertEqual(6, s.c.value.metadata.categorical_limit)
|
|
self.assertEqual(7, s.c.lengths.metadata.categorical_limit)
|
|
sc = s.clone()
|
|
self.assertEqual(None, sc.a.metadata)
|
|
self.assertEqual(5, sc.b.metadata.categorical_limit)
|
|
self.assertEqual(6, sc.c.value.metadata.categorical_limit)
|
|
self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)
|
|
sv = schema.from_blob_list(
|
|
s, [
|
|
np.array([3.4]), np.array([2]), np.array([3]),
|
|
np.array([1, 2, 3])
|
|
]
|
|
)
|
|
self.assertEqual(None, sv.a.metadata)
|
|
self.assertEqual(5, sv.b.metadata.categorical_limit)
|
|
self.assertEqual(6, sv.c.value.metadata.categorical_limit)
|
|
self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)
|
|
|
|
def testDupField(self):
|
|
with self.assertRaises(ValueError):
|
|
schema.Struct(
|
|
('a', schema.Scalar()),
|
|
('a', schema.Scalar()))
|
|
|
|
def testPreservesEmptyFields(self):
|
|
s = schema.Struct(
|
|
('a', schema.Scalar(np.float32)),
|
|
('b', schema.Struct()),
|
|
)
|
|
sc = s.clone()
|
|
self.assertIn("a", sc.fields)
|
|
self.assertIn("b", sc.fields)
|
|
sv = schema.from_blob_list(s, [np.array([3.4])])
|
|
self.assertIn("a", sv.fields)
|
|
self.assertIn("b", sv.fields)
|
|
self.assertEqual(0, len(sv.b.fields))
|
|
|
|
def testStructAddition(self):
|
|
s1 = schema.Struct(
|
|
('a', schema.Scalar())
|
|
)
|
|
s2 = schema.Struct(
|
|
('b', schema.Scalar())
|
|
)
|
|
s = s1 + s2
|
|
self.assertIn("a", s.fields)
|
|
self.assertIn("b", s.fields)
|
|
with self.assertRaises(ValueError):
|
|
s1 + s1
|
|
with self.assertRaises(TypeError):
|
|
s1 + schema.Scalar()
|