migrate mtml to dper2

Summary:
1. migrate the basic mtml model to dper2
2. test the dper2 mtml model
3. test all optimizers

Reviewed By: kittipatv

Differential Revision: D4680215

fbshipit-source-id: 7aac5c59bdac22fcad8ed869b98e9e62dca1d337
This commit is contained in:
Huazhong Ning 2017-03-16 17:43:48 -07:00 committed by Facebook Github Bot
parent cc2e915461
commit ad4ae4528f
2 changed files with 24 additions and 7 deletions

View File

@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from caffe2.python import core, schema from caffe2.python import schema
from caffe2.python.layers.layers import ( from caffe2.python.layers.layers import (
ModelLayer, ModelLayer,
) )
@ -15,10 +15,13 @@ import numpy as np
class BatchLRLoss(ModelLayer): class BatchLRLoss(ModelLayer):
def __init__(self, model, input_record, name='batch_lr_loss', **kwargs): def __init__(self, model, input_record, name='batch_lr_loss',
average_loss=True, **kwargs):
super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs) super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
schema.is_schema_subset( self.average_loss = average_loss
assert schema.is_schema_subset(
schema.Struct( schema.Struct(
('label', schema.Scalar()), ('label', schema.Scalar()),
('prediction', schema.Scalar()) ('prediction', schema.Scalar())
@ -46,4 +49,13 @@ class BatchLRLoss(ModelLayer):
[class_probabilities] + label, [class_probabilities] + label,
net.NextScopedBlob('cross_entropy'), net.NextScopedBlob('cross_entropy'),
) )
net.AveragedLoss(xent, self.output_schema.field_blobs()) if 'weight' in self.input_record.fields:
xent = net.Mul(
[xent, self.input_record.weight()],
net.NextScopedBlob('weighted_scross_entropy'),
)
if self.average_loss:
net.AveragedLoss(xent, self.output_schema.field_blobs())
else:
net.ReduceFrontSum(xent, self.output_schema.field_blobs())

View File

@ -635,20 +635,25 @@ def Map(
) )
def NamedTuple(name_prefix, *fields):
    """
    Creates a Struct with sequential field names built from `name_prefix`.

    The i-th entry of `fields` (counting from 0) is registered under the
    name '<name_prefix>_<i>', keeping the order in which it was passed.
    """
    named_fields = []
    for index, field_type in enumerate(fields):
        field_name = '%s_%d' % (name_prefix, index)
        named_fields.append((field_name, field_type))
    return Struct(*named_fields)
def Tuple(*fields): def Tuple(*fields):
""" """
Creates a Struct with default, sequential, field names of given types. Creates a Struct with default, sequential, field names of given types.
""" """
return Struct(* [('field_%d' % i, field) for i, field in enumerate(fields)]) return NamedTuple('field', *fields)
def RawTuple(num_fields): def RawTuple(num_fields, name_prefix='field'):
""" """
Creates a tuple of `num_field` untyped scalars. Creates a tuple of `num_field` untyped scalars.
""" """
assert isinstance(num_fields, int) assert isinstance(num_fields, int)
assert num_fields >= 0 assert num_fields >= 0
return Tuple(*([np.void] * num_fields)) return NamedTuple(name_prefix, *([np.void] * num_fields))
def from_dtype(dtype, _outer_shape=()): def from_dtype(dtype, _outer_shape=()):