pytorch/caffe2/python/operator_test/adagrad_test.py
Lu Fang 664fe34e0a
[Caffe2][fbcode=>GH sync] Update from facebook 4323b18ce13c (#7116)
* [fix] Re-enable events in RNN ops

We have earlier added event disabling in RNN ops as back then we didn't use
events, with current use cases this is no longer true
(https://fburl.com/8vd0lp8y)

* Use ops with CUDA impl

* Revert D7729695: [caffe2][fix] Re-enable events in RNN ops

This reverts commit 4b215c7496fb724656ff4c776933a15bdbbcde5e

@bypass-lint

An infra SEV is better than not reverting this diff.
If you copy this password, see you in SEV Review!
@cause_a_sev_many_files

* [observer] Clean up observer_config.h

#accept2ship

* [1/n] Refactor dataio_test.py

Replace code duplication with a common function

* Add barrier net that runs before training nets

Add a synchronize barrier net that is run before the training nets. With this net, faster shards wait for the other shards before starting training. This reduces the chance of the faster shards timing out during the Gloo AllReduce.

Removed the explicit data_parallel_model.py.synchronize call in the holmes workflow. A similar change in the speech/asr_training workflow will come in another diff.
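The barrier idea can be illustrated outside Caffe2 with Python's `threading.Barrier`. This is a minimal sketch under stated assumptions: threads stand in for shards, `shard_worker` and the per-shard setup delays are hypothetical, and nothing here uses the Caffe2 barrier net or Gloo.

```python
import threading
import time

# Sketch only: threading.Barrier stands in for the barrier net that runs
# before training. shard_worker and the delays are hypothetical.
NUM_SHARDS = 3
barrier = threading.Barrier(NUM_SHARDS)
start_times = {}
lock = threading.Lock()


def shard_worker(shard_id, setup_delay_s):
    time.sleep(setup_delay_s)  # shards finish setup at different times
    barrier.wait()             # faster shards block here until all arrive
    with lock:
        start_times[shard_id] = time.monotonic()
    # ... training (and its AllReduce) would start here ...


threads = [threading.Thread(target=shard_worker, args=(i, 0.05 * i))
           for i in range(NUM_SHARDS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# All shards leave the barrier at nearly the same moment.
spread = max(start_times.values()) - min(start_times.values())
print(f"start-time spread: {spread:.4f}s")
```

Setup finishes up to 0.1 s apart across shards, but no shard proceeds past `barrier.wait()` until every shard has arrived, so their training start times line up.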

* Support the dnnlowp backend in caffe2_benchmark

This is for SHARE operator latency evaluation

* Migrate integral_image_op to main caffe2

Migrate integral_image_op (GPU version) from https://fburl.com/yvqezigi
to caffe2/caffe2/operators and implement its CPU version. Add a test
using the hypothesis_test mechanism.
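A minimal NumPy sketch of what a CPU integral image (summed-area table) computes. The leading zero row and column, and the names `integral_image` and `box_sum`, are assumptions for illustration; the exact output layout of the operator added by this change is not confirmed here.

```python
import numpy as np


def integral_image(img):
    """Summed-area table with a zero first row and column (assumed layout).

    out[i, j] == img[:i, :j].sum(), so out has shape (H + 1, W + 1).
    """
    h, w = img.shape
    out = np.zeros((h + 1, w + 1), dtype=np.float64)
    out[1:, 1:] = img.cumsum(axis=0).cumsum(axis=1)
    return out


def box_sum(ii, r0, c0, r1, c1):
    """Sum of img[r0:r1, c0:c1] using four lookups into the integral image."""
    return ii[r1, c1] - ii[r0, c1] - ii[r1, c0] + ii[r0, c0]


img = np.arange(12, dtype=np.float64).reshape(3, 4)
ii = integral_image(img)
print(box_sum(ii, 1, 1, 3, 3))  # → 30.0, same as img[1:3, 1:3].sum()
```

The zero padding removes the edge special cases: any rectangular sum becomes four table lookups, regardless of where the rectangle sits.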

* [pos_disc, fbcode] Implement unjoined lr loss

As explained in https://our.intern.facebook.com/intern/wiki/Model_Based_Calibration/, when the dataset is a joined dataset, where labels might change later, we need to use the unjoined logloss.

The implementation is almost the same as in Sigrid (https://fburl.com/1trngsls), where
    loss = y (log(p) - log(1-p)) + (1-y)(log(1-p)) = xy - (1-y)x - (1-y)log(1+exp(-x))

For x < 0, to ensure stability and avoid overflow, we reformulate the above expression using log(1+exp(-x)) = -x + log(1+exp(x)):
    loss = xy - (1-y)x + (1-y)x - (1-y)log(1+exp(x)) = xy - (1-y)log(1+exp(x))

Then the final expression becomes
    loss = xy + (y - 1) x (x >= 0) - (1 - y) log(1 + exp(x - 2 x (x >= 0)))

where y is the true label, x is the dot product and p = logistic(x).

This implementation is aligned with the current implementation of the original cross entropy in
https://phabricator.intern.facebook.com/diffusion/FBS/browse/master/fbcode/caffe2/caffe2/operators/cross_entropy_op.cc;0bae3b5d0f825897c5e0dd0ff10f489d7271bf25$7-13
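A minimal NumPy sketch comparing the direct formula with the stable form above. It assumes `x` and `y` are float arrays and p = logistic(x), as defined in the message. The function names are hypothetical, and this is not the Caffe2 operator code. The exponent x - 2x(x >= 0) equals -|x|, so `exp` never overflows.

```python
import numpy as np


def unjoined_logloss_direct(x, y):
    # Direct form: y*(log(p) - log(1-p)) + (1-y)*log(1-p), p = logistic(x).
    # Breaks down for large |x| because p rounds to exactly 0 or 1.
    p = 1.0 / (1.0 + np.exp(-x))
    return y * (np.log(p) - np.log(1.0 - p)) + (1.0 - y) * np.log(1.0 - p)


def unjoined_logloss_stable(x, y):
    # Stable form: xy + (y-1)*x*[x>=0] - (1-y)*log(1 + exp(x - 2x*[x>=0])).
    # The exponent x - 2x*[x>=0] equals -|x|, so exp() stays in (0, 1].
    pos = (x >= 0).astype(x.dtype)
    return (x * y + (y - 1.0) * x * pos
            - (1.0 - y) * np.log1p(np.exp(x - 2.0 * x * pos)))


x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
y = np.array([1.0, 0.0, 1.0, 0.0, 1.0])
print(np.allclose(unjoined_logloss_direct(x, y), unjoined_logloss_stable(x, y)))

big_x = np.array([800.0, -800.0])
big_y = np.array([0.0, 1.0])
print(unjoined_logloss_stable(big_x, big_y))  # finite, no overflow
```

For moderate |x| both forms agree. For |x| = 800, the direct form would take log(0), but the stable form still returns finite values.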

* Keep the array to fix the conflict

* [C2] Compute Adagrad effective LR

The AdagradWithLR op outputs an extra blob that contains the average effective learning rate across all weights in the parameter blob.
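A minimal NumPy sketch of the effective learning rate, assuming the same update rule as `ref_adagrad` in the test file (moment += grad**2; effective_lr = lr / (sqrt(moment) + epsilon)). The reference in the test returns the rate element-wise. The mean printed at the end only illustrates the "average" described above and is not the operator's code. The name `adagrad_step_with_effective_lr` is hypothetical.

```python
import numpy as np


def adagrad_step_with_effective_lr(param, moment, grad, lr, epsilon):
    # Dense Adagrad step following the same rule as ref_adagrad.
    moment_out = moment + np.square(grad)
    effective_lr = lr / (np.sqrt(moment_out) + epsilon)  # per-element rate
    update = effective_lr * grad
    # The update is added, as in ref_adagrad, so a descent step needs lr < 0.
    param_out = param + update
    return param_out, moment_out, effective_lr, update


param = np.zeros(4, dtype=np.float32)
moment = np.zeros(4, dtype=np.float32)
grad = np.array([1.0, 2.0, 0.5, 4.0], dtype=np.float32)
_, _, eff_lr, _ = adagrad_step_with_effective_lr(param, moment, grad,
                                                 lr=0.1, epsilon=1e-8)
print(eff_lr)         # per-element effective learning rates
print(eff_lr.mean())  # one scalar summary: their average
```

Starting from a zero moment, each rate is roughly lr / |grad|, so weights with larger gradients get proportionally smaller steps.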

* Open-source extractMetaNetDef & runGlobalInitialization, add new Predictor constructor from db file, and add run_map_outputs

1. Open-source extractMetaNetDef and runGlobalInitialization, for use in (2).
2. Add a new Predictor constructor from a db file.
3. Add a new run function that returns outputs as a TensorMap.

* Disable eigen cpu

Disable eigen cpu in transpose and reduce

* Introduce request_only/object_only property of ModelLayer

by default this is False

* A simple TC Caffe2 benchmark

We can run the tuner, get MappingOptions, and then use them to
compare against cuBLAS.

Currently broken due to LLVM issues. How to run:

hg checkout eec1ab31b59c03b8deded1c755a9abaf8c45be01
add D7401202
add D7434625
add D7506031
add D7540728

buck run @mode/dev-nosan tc/tc/benchmarks_python:caffe2_benchmark

* Move Caffe2 feature_maps_ops to open source

Need feature maps operators in open source project facebookresearch/BlueWhale

* Manually fix the conflicts in channel shuffle op

* Fix inconsistencies between GitHub and fbcode

* Skip Adagrad GPU test (some GPU implementations are missing)

* Fix another test to make sure it won't run on GPU when the implementation is not available yet
2018-05-01 20:49:00 -07:00

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
import hypothesis
from hypothesis import given, settings, HealthCheck
import hypothesis.strategies as st
import numpy as np
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import unittest


class TestAdagrad(hu.HypothesisTestCase):

    @staticmethod
    def ref_adagrad(param_in, mom_in, grad, lr, epsilon, using_fp16=False,
                    output_effective_lr=False,
                    output_effective_lr_and_update=False):
        mom_in_f32 = mom_in
        param_in_f32 = param_in
        if using_fp16:
            mom_in_f32 = mom_in.astype(np.float32)
            param_in_f32 = param_in.astype(np.float32)

        mom_out = mom_in_f32 + np.square(grad)
        effective_lr = lr / (np.sqrt(mom_out) + epsilon)
        grad_adj = effective_lr * grad
        param_out = param_in_f32 + grad_adj

        if output_effective_lr_and_update:
            if using_fp16:
                return (param_out.astype(np.float16), mom_out.astype(np.float16),
                        effective_lr.astype(np.float16),
                        grad_adj.astype(np.float16))
            else:
                return (param_out.astype(np.float32), mom_out.astype(np.float32),
                        effective_lr.astype(np.float32),
                        grad_adj.astype(np.float32))
        elif output_effective_lr:
            if using_fp16:
                return (param_out.astype(np.float16), mom_out.astype(np.float16),
                        effective_lr.astype(np.float16))
            else:
                return (param_out.astype(np.float32), mom_out.astype(np.float32),
                        effective_lr.astype(np.float32))

        if using_fp16:
            return (param_out.astype(np.float16), mom_out.astype(np.float16))
        else:
            return (param_out.astype(np.float32), mom_out.astype(np.float32))

    @staticmethod
    def ref_row_wise_adagrad(param_in, mom_in, grad, lr, epsilon):
        mom_out = mom_in + np.mean(np.square(grad))
        grad_adj = lr * grad / (np.sqrt(mom_out) + epsilon)
        param_out = param_in + grad_adj
        return (param_out, mom_out)

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs)
    def test_adagrad(self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(self.ref_adagrad, epsilon=epsilon))

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs_cpu_only)
    def test_adagrad_output_effective_lr(self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum", "effective_lr"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(self.ref_adagrad, epsilon=epsilon,
                              output_effective_lr=True))

    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           **hu.gcs_cpu_only)
    def test_adagrad_output_effective_lr_and_update(
            self, inputs, lr, epsilon, gc, dc):
        param, momentum, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        op = core.CreateOperator(
            "Adagrad",
            ["param", "momentum", "grad", "lr"],
            ["param", "momentum", "effective_lr", "update"],
            epsilon=epsilon,
            device_option=gc,
        )

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, grad, lr],
            functools.partial(self.ref_adagrad, epsilon=epsilon,
                              output_effective_lr_and_update=True))

    # Suppress filter_too_much health check.
    # Likely caused by `assume` call falling through too often.
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
    @given(inputs=hu.tensors(n=3),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_sparse_adagrad(self, inputs, lr, epsilon,
                            data_strategy, gc, dc):
        param, momentum, grad = inputs
        momentum = np.abs(momentum)
        lr = np.array([lr], dtype=np.float32)

        # Create an indexing array containing values that are lists of indices,
        # which index into grad
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64,
                      elements=st.sampled_from(np.arange(grad.shape[0]))),
        )
        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # For now, the indices must be unique
        hypothesis.assume(np.array_equal(np.unique(indices.flatten()),
                                         np.sort(indices.flatten())))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "SparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_sparse(param, momentum, indices, grad, lr, ref_using_fp16=False):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            for i, index in enumerate(indices):
                param_out[index], momentum_out[index] = self.ref_adagrad(
                    param[index],
                    momentum[index],
                    grad[i],
                    lr,
                    epsilon,
                    using_fp16=ref_using_fp16
                )
            return (param_out, momentum_out)

        ref_using_fp16_values = [False]
        if dc == hu.gpu_do:
            ref_using_fp16_values.append(True)

        for ref_using_fp16 in ref_using_fp16_values:
            if ref_using_fp16:
                print('test_sparse_adagrad with half precision embedding')
                momentum_i = momentum.astype(np.float16)
                param_i = param.astype(np.float16)
            else:
                print('test_sparse_adagrad with full precision embedding')
                momentum_i = momentum.astype(np.float32)
                param_i = param.astype(np.float32)

            # The trailing ref_using_fp16 entry is not an operator input (the op
            # declares five inputs); it is passed through to ref_sparse.
            self.assertReferenceChecks(
                gc, op, [param_i, momentum_i, indices, grad, lr, ref_using_fp16],
                ref_sparse
            )

    @given(inputs=hu.tensors(n=2),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                  data_strategy, gc, dc):
        param, momentum = inputs
        momentum = np.abs(momentum)
        lr = np.array([lr], dtype=np.float32)

        grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0,), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "SparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        ref_using_fp16_values = [False]
        if dc == hu.gpu_do:
            ref_using_fp16_values.append(True)

        for ref_using_fp16 in ref_using_fp16_values:
            if ref_using_fp16:
                print('test_sparse_adagrad_empty with half precision embedding')
                momentum_i = momentum.astype(np.float16)
                param_i = param.astype(np.float16)
            else:
                print('test_sparse_adagrad_empty with full precision embedding')
                momentum_i = momentum.astype(np.float32)
                param_i = param.astype(np.float32)

            self.assertReferenceChecks(
                gc, op, [param_i, momentum_i, indices, grad, lr], ref_sparse
            )

    # Suppress filter_too_much health check.
    # Likely caused by `assume` call falling through too often.
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
    @given(inputs=hu.tensors(n=2),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adagrad(self, inputs, lr, epsilon,
                                     data_strategy, gc, dc):
        param, grad = inputs
        lr = np.array([lr], dtype=np.float32)

        # Create a 1D row-wise average sum of squared gradients tensor.
        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        # Create an indexing array containing values which index into grad
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64,
                      elements=st.sampled_from(np.arange(grad.shape[0]))),
        )

        # Note that unlike SparseAdagrad, RowWiseSparseAdagrad uses a moment
        # tensor that is strictly 1-dimensional and equal in length to the
        # first dimension of the parameters, so indices must also be
        # 1-dimensional.
        indices = indices.flatten()

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        # The indices must be unique
        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))

        # Sparsify grad
        grad = grad[indices]

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            for i, index in enumerate(indices):
                param_out[index], momentum_out[index] = self.ref_row_wise_adagrad(
                    param[index], momentum[index], grad[i], lr, epsilon)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)

    @given(inputs=hu.tensors(n=1),
           lr=st.floats(min_value=0.01, max_value=0.99,
                        allow_nan=False, allow_infinity=False),
           epsilon=st.floats(min_value=0.01, max_value=0.99,
                             allow_nan=False, allow_infinity=False),
           data_strategy=st.data(),
           **hu.gcs)
    def test_row_wise_sparse_adagrad_empty(self, inputs, lr, epsilon,
                                           data_strategy, gc, dc):
        param = inputs[0]
        lr = np.array([lr], dtype=np.float32)

        momentum = data_strategy.draw(
            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
                        elements=hu.elements_of_type(dtype=np.float32))
        )
        momentum = np.abs(momentum)

        grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
        indices = np.empty(shape=(0,), dtype=np.int64)

        hypothesis.note('indices.shape: %s' % str(indices.shape))

        op = core.CreateOperator(
            "RowWiseSparseAdagrad",
            ["param", "momentum", "indices", "grad", "lr"],
            ["param", "momentum"],
            epsilon=epsilon,
            device_option=gc)

        def ref_row_wise_sparse(param, momentum, indices, grad, lr):
            param_out = np.copy(param)
            momentum_out = np.copy(momentum)
            return (param_out, momentum_out)

        self.assertReferenceChecks(
            gc, op,
            [param, momentum, indices, grad, lr],
            ref_row_wise_sparse)
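The sparse tests above restrict `indices` to unique values. Here is a minimal NumPy sketch of one reason that is convenient for a reference check. It uses the row-wise rule from `ref_row_wise_adagrad`, where each row has one scalar accumulator fed by the mean squared gradient of that row. A sequential loop that accumulates into its output and a gather/scatter variant agree when the indices are unique. They diverge when an index repeats, because the repeats either accumulate or overwrite. Both function names are hypothetical. The sketch says nothing about how the Caffe2 kernels handle duplicates.

```python
import numpy as np


def row_wise_sparse_adagrad_loop(param, moment, indices, grad, lr, epsilon):
    # Hypothetical sequential variant: each repeat of a row sees the
    # moment and param already updated by earlier repeats.
    param_out = param.copy()
    moment_out = moment.copy()
    for i, row in enumerate(indices):
        moment_out[row] += np.mean(np.square(grad[i]))
        param_out[row] += lr * grad[i] / (np.sqrt(moment_out[row]) + epsilon)
    return param_out, moment_out


def row_wise_sparse_adagrad_vectorized(param, moment, indices, grad, lr, epsilon):
    # Hypothetical gather/update/scatter variant: every repeat reads the
    # original row, and the scatter keeps only one of the writes.
    param_out = param.copy()
    moment_out = moment.copy()
    new_moment = moment[indices] + np.mean(np.square(grad), axis=1)
    step = lr * grad / (np.sqrt(new_moment) + epsilon)[:, None]
    param_out[indices] = param[indices] + step
    moment_out[indices] = new_moment
    return param_out, moment_out


param = np.ones((3, 2))
moment = np.zeros(3)
grad = np.array([[1.0, 1.0], [2.0, 2.0]])

for idx in (np.array([0, 2]), np.array([0, 0])):
    p_loop, m_loop = row_wise_sparse_adagrad_loop(param, moment, idx, grad, 0.1, 1e-8)
    p_vec, m_vec = row_wise_sparse_adagrad_vectorized(param, moment, idx, grad, 0.1, 1e-8)
    print(idx, np.allclose(p_loop, p_vec) and np.allclose(m_loop, m_vec))
```

With unique indices `[0, 2]` both variants produce the same result. With `[0, 0]` the loop accumulates a moment of 5.0 for row 0, while the scatter variant keeps a single write. The "expected" output therefore depends on a semantic choice, and requiring unique indices keeps the reference unambiguous.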