mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
migrate mtml to dper2
Summary: (1) migrate the basic mtml model to dper 2; (2) test the dper 2 mtml model; (3) test all optimizers. Reviewed By: kittipatv. Differential Revision: D4680215. fbshipit-source-id: 7aac5c59bdac22fcad8ed869b98e9e62dca1d337
This commit is contained in:
parent
cc2e915461
commit
ad4ae4528f
|
|
@ -3,7 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, schema
|
||||
from caffe2.python import schema
|
||||
from caffe2.python.layers.layers import (
|
||||
ModelLayer,
|
||||
)
|
||||
|
|
@ -15,10 +15,13 @@ import numpy as np
|
|||
|
||||
class BatchLRLoss(ModelLayer):
|
||||
|
||||
def __init__(self, model, input_record, name='batch_lr_loss', **kwargs):
|
||||
def __init__(self, model, input_record, name='batch_lr_loss',
|
||||
average_loss=True, **kwargs):
|
||||
super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
|
||||
|
||||
schema.is_schema_subset(
|
||||
self.average_loss = average_loss
|
||||
|
||||
assert schema.is_schema_subset(
|
||||
schema.Struct(
|
||||
('label', schema.Scalar()),
|
||||
('prediction', schema.Scalar())
|
||||
|
|
@ -46,4 +49,13 @@ class BatchLRLoss(ModelLayer):
|
|||
[class_probabilities] + label,
|
||||
net.NextScopedBlob('cross_entropy'),
|
||||
)
|
||||
net.AveragedLoss(xent, self.output_schema.field_blobs())
|
||||
if 'weight' in self.input_record.fields:
|
||||
xent = net.Mul(
|
||||
[xent, self.input_record.weight()],
|
||||
net.NextScopedBlob('weighted_scross_entropy'),
|
||||
)
|
||||
|
||||
if self.average_loss:
|
||||
net.AveragedLoss(xent, self.output_schema.field_blobs())
|
||||
else:
|
||||
net.ReduceFrontSum(xent, self.output_schema.field_blobs())
|
||||
|
|
|
|||
|
|
@ -635,20 +635,25 @@ def Map(
|
|||
)
|
||||
|
||||
|
||||
def NamedTuple(name_prefix, *fields):
|
||||
return Struct(* [('%s_%d' % (name_prefix, i), field)
|
||||
for i, field in enumerate(fields)])
|
||||
|
||||
|
||||
def Tuple(*fields):
|
||||
"""
|
||||
Creates a Struct with default, sequential, field names of given types.
|
||||
"""
|
||||
return Struct(* [('field_%d' % i, field) for i, field in enumerate(fields)])
|
||||
return NamedTuple('field', *fields)
|
||||
|
||||
|
||||
def RawTuple(num_fields):
|
||||
def RawTuple(num_fields, name_prefix='field'):
|
||||
"""
|
||||
Creates a tuple of `num_fields` untyped scalars.
|
||||
"""
|
||||
assert isinstance(num_fields, int)
|
||||
assert num_fields >= 0
|
||||
return Tuple(*([np.void] * num_fields))
|
||||
return NamedTuple(name_prefix, *([np.void] * num_fields))
|
||||
|
||||
|
||||
def from_dtype(dtype, _outer_shape=()):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user