Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29304

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized. It keeps optimizer instances on the remote workers, and calling step() on the distributed optimizer calls step() on each of the remote optimizers in parallel.

ghstack-source-id: 93564364
Test Plan: unit tests.
Differential Revision: D18354586
fbshipit-source-id: 85d4c8bfec4aa38d2863cda704d024692511cff5
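As a rough illustration of the behavior described in the summary, here is a minimal sketch of how such a distributed optimizer is driven over RPC, based on the public torch.distributed.optim.DistributedOptimizer API as it later stabilized. The worker name "worker1", the choice of SGD, and the already-initialized RPC group are assumptions for illustration, not part of this PR.

    # Minimal sketch; assumes rpc.init_rpc() has already been called on every
    # process and that a peer named "worker1" exists in the RPC group.
    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    from torch import optim
    from torch.distributed.optim import DistributedOptimizer

    with dist_autograd.context() as context_id:
        # Forward pass: the tensors live on "worker1"; locally we only hold RRefs.
        rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        rref2 = rpc.remote("worker1", torch.add, args=(torch.zeros(2), 1))
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass through the distributed autograd context.
        dist_autograd.backward(context_id, [loss.sum()])

        # One local optim.SGD instance is kept on each worker that owns
        # parameters; step() invokes them in parallel.
        dist_optim = DistributedOptimizer(optim.SGD, [rref1, rref2], lr=0.05)
        dist_optim.step(context_id)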
19 lines · 604 B · Python · Executable File
#!/usr/bin/env python3

from __future__ import absolute_import, division, print_function, unicode_literals

from dist_optimizer_test import DistOptimizerTest
from common_distributed import MultiProcessTestCase
from common_utils import TEST_WITH_ASAN, run_tests

import unittest


@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
class DistOptimizerTestWithSpawn(MultiProcessTestCase, DistOptimizerTest):

    def setUp(self):
        super(DistOptimizerTestWithSpawn, self).setUp()
        self._spawn_processes()


if __name__ == '__main__':
    run_tests()