Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: Move the open-source version of build_ftrl to the open-source directory. build_ftrl can use several engines, and the SIMD engine is fb-specific, so we keep an fb-specific build_ftrl in fb/optimizers/sgd.py. If a caller only uses the open-source engine, it can import the open-source build_ftrl; if it may use the SIMD engine, it needs to import the fb-specific build_ftrl. Also move the tests to the python directory.

Reviewed By: salexspb

Differential Revision: D4560384

fbshipit-source-id: 84fc915d3bbe42fd19503ef132d3277088f6fab3
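To make the split concrete, here is a minimal sketch of the two import paths, assuming an already-constructed caffe2 model object named model. The open-source module path matches the test file below; the fb-internal path fb/optimizers/sgd.py is taken from the summary, and the engine name 'SIMD' is an assumption rather than something stated in this diff.

    # Open-source callers import the builder that only needs the
    # open-source engine (engine=None, as in the tests below):
    from caffe2.python.sgd import build_ftrl

    build_ftrl(model, engine=None, dedup_indices=False,
               alpha=1.0, beta=0.1, lambda1=0.0, lambda2=0.0)

    # Callers that may use the fb-internal SIMD engine would instead
    # import the fb-specific builder (module path from the summary;
    # the engine string 'SIMD' is assumed):
    # from fb.optimizers.sgd import build_ftrl
    # build_ftrl(model, engine='SIMD', dedup_indices=False,
    #            alpha=1.0, beta=0.1, lambda1=0.0, lambda2=0.0)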
28 lines · 866 B · Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from caffe2.python.sgd import build_sgd, build_ftrl, build_adagrad, build_adam
from caffe2.python.sgd_test_util import TestBase
from caffe2.python.test_util import TestCase


class TestSgd(TestBase, TestCase):
    def build_optimizer(self, model):
        build_sgd(model, base_learning_rate=0.1)


class TestFtrl(TestBase, TestCase):
    def build_optimizer(self, model):
        # engine=None exercises the open-source FTRL engine; per the
        # commit summary, the SIMD engine is fb-internal only.
        build_ftrl(model, engine=None, dedup_indices=False,
                   alpha=1.0, beta=0.1, lambda1=0.0, lambda2=0.0)


class TestAdagrad(TestBase, TestCase):
    def build_optimizer(self, model):
        build_adagrad(model, base_learning_rate=1.0)


class TestAdam(TestBase, TestCase):
    def build_optimizer(self, model):
        build_adam(model, base_learning_rate=0.1)
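Each test class follows the same pattern: TestBase presumably supplies the shared test logic, TestCase provides the unittest machinery, and the subclass only overrides build_optimizer. As a hedged sketch of how the pattern extends, a new builder would be wired up like this (build_rmsprop is hypothetical and not part of this diff):

    class TestRmsProp(TestBase, TestCase):
        def build_optimizer(self, model):
            # build_rmsprop is hypothetical; any builder taking
            # (model, base_learning_rate) would slot in the same way.
            build_rmsprop(model, base_learning_rate=0.1)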