mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
migrate mtml to dper2
Summary: 1. migrate the basic mtml model to dper 2 2. test dper 2 mtml model 3. test all optimizers Reviewed By: kittipatv Differential Revision: D4680215 fbshipit-source-id: 7aac5c59bdac22fcad8ed869b98e9e62dca1d337
This commit is contained in:
parent
cc2e915461
commit
ad4ae4528f
|
|
@ -3,7 +3,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from caffe2.python import core, schema
|
from caffe2.python import schema
|
||||||
from caffe2.python.layers.layers import (
|
from caffe2.python.layers.layers import (
|
||||||
ModelLayer,
|
ModelLayer,
|
||||||
)
|
)
|
||||||
|
|
@ -15,10 +15,13 @@ import numpy as np
|
||||||
|
|
||||||
class BatchLRLoss(ModelLayer):
|
class BatchLRLoss(ModelLayer):
|
||||||
|
|
||||||
def __init__(self, model, input_record, name='batch_lr_loss', **kwargs):
|
def __init__(self, model, input_record, name='batch_lr_loss',
|
||||||
|
average_loss=True, **kwargs):
|
||||||
super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
|
super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
|
||||||
|
|
||||||
schema.is_schema_subset(
|
self.average_loss = average_loss
|
||||||
|
|
||||||
|
assert schema.is_schema_subset(
|
||||||
schema.Struct(
|
schema.Struct(
|
||||||
('label', schema.Scalar()),
|
('label', schema.Scalar()),
|
||||||
('prediction', schema.Scalar())
|
('prediction', schema.Scalar())
|
||||||
|
|
@ -46,4 +49,13 @@ class BatchLRLoss(ModelLayer):
|
||||||
[class_probabilities] + label,
|
[class_probabilities] + label,
|
||||||
net.NextScopedBlob('cross_entropy'),
|
net.NextScopedBlob('cross_entropy'),
|
||||||
)
|
)
|
||||||
net.AveragedLoss(xent, self.output_schema.field_blobs())
|
if 'weight' in self.input_record.fields:
|
||||||
|
xent = net.Mul(
|
||||||
|
[xent, self.input_record.weight()],
|
||||||
|
net.NextScopedBlob('weighted_scross_entropy'),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.average_loss:
|
||||||
|
net.AveragedLoss(xent, self.output_schema.field_blobs())
|
||||||
|
else:
|
||||||
|
net.ReduceFrontSum(xent, self.output_schema.field_blobs())
|
||||||
|
|
|
||||||
|
|
@ -635,20 +635,25 @@ def Map(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def NamedTuple(name_prefix, *fields):
|
||||||
|
return Struct(* [('%s_%d' % (name_prefix, i), field)
|
||||||
|
for i, field in enumerate(fields)])
|
||||||
|
|
||||||
|
|
||||||
def Tuple(*fields):
|
def Tuple(*fields):
|
||||||
"""
|
"""
|
||||||
Creates a Struct with default, sequential, field names of given types.
|
Creates a Struct with default, sequential, field names of given types.
|
||||||
"""
|
"""
|
||||||
return Struct(* [('field_%d' % i, field) for i, field in enumerate(fields)])
|
return NamedTuple('field', *fields)
|
||||||
|
|
||||||
|
|
||||||
def RawTuple(num_fields):
|
def RawTuple(num_fields, name_prefix='field'):
|
||||||
"""
|
"""
|
||||||
Creates a tuple of `num_field` untyped scalars.
|
Creates a tuple of `num_field` untyped scalars.
|
||||||
"""
|
"""
|
||||||
assert isinstance(num_fields, int)
|
assert isinstance(num_fields, int)
|
||||||
assert num_fields >= 0
|
assert num_fields >= 0
|
||||||
return Tuple(*([np.void] * num_fields))
|
return NamedTuple(name_prefix, *([np.void] * num_fields))
|
||||||
|
|
||||||
|
|
||||||
def from_dtype(dtype, _outer_shape=()):
|
def from_dtype(dtype, _outer_shape=()):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user