diff --git a/test/conftest.py b/test/conftest.py index d742430f886..078e4b3b2b8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -21,6 +21,16 @@ from _pytest.terminal import _get_raw_skip_reason from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin +try: + from torch.testing._internal.common_utils import parse_cmd_line_args +except ImportError: + # Temporary workaround needed until parse_cmd_line_args makes it into a nightlye because + # main / PR's tests are sometimes run against the previous day's nightly which won't + # have this function. + def parse_cmd_line_args(): + pass + + if TYPE_CHECKING: from _pytest._code.code import ReprFileLocation @@ -83,6 +93,7 @@ def pytest_addoption(parser: Parser) -> None: def pytest_configure(config: Config) -> None: + parse_cmd_line_args() xmlpath = config.option.xmlpath_reruns # Prevent opening xmllog on worker nodes (xdist). if xmlpath and not hasattr(config, "workerinput"): diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index f42aa7f8f43..f128f9c7eec 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -27,6 +27,9 @@ from torch.testing._internal.jit_utils import ( ) +assert GRAPH_EXECUTOR is not None + + @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) diff --git a/test/test_jit.py b/test/test_jit.py index 093753851f5..83407e25d0b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3,6 +3,13 @@ import torch +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + # This is how we include tests located in test/jit/... # They are included here so that they are invoked when you call `test_jit.py`, # do not run these test files directly. @@ -97,7 +104,7 @@ import torch.nn.functional as F from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ - suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ + GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -158,6 +165,7 @@ def doAutodiffCheck(testname): if "test_t_" in testname or testname == "test_t": return False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: return False @@ -201,6 +209,7 @@ def doAutodiffCheck(testname): return testname not in test_exceptions +assert GRAPH_EXECUTOR # TODO: enable TE in PE when all tests are fixed torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index b3cf4d9bee8..dcdf78ff4b8 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -5,12 +5,17 @@ from torch.cuda.amp import autocast from typing import Optional import unittest -from test_jit import JitTestCase from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo from torch.testing import FileCheck from jit.test_models import MnistNet +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + +from test_jit import JitTestCase TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() @skipIfTorchDynamo("Not a TorchDynamo suitable test") diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 1ac7803a9d4..5446770695c 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -9,6 +9,13 @@ import torch.nn.functional as F from torch.testing import FileCheck from unittest import skipIf +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py index 3bd8c9497ce..4100bcc3e18 100644 --- a/test/test_jit_fuser_legacy.py +++ b/test/test_jit_fuser_legacy.py @@ -2,6 +2,14 @@ import sys sys.argv.append("--jit-executor=legacy") + +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from test_jit_fuser import * # noqa: F403 if __name__ == '__main__': diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3e26d37da1..1bda41f7f8f 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -22,6 +22,13 @@ from torch.testing import FileCheck torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from itertools import combinations, permutations, product from textwrap import dedent diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 5576f164534..480b57a55bd 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,7 +2,14 @@ import sys sys.argv.append("--jit-executor=legacy") -from test_jit import * # noqa: F403 +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests + +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + +from test_jit import * # noqa: F403, F401 if __name__ == '__main__': run_tests() diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index b445f4ad853..22767567ac1 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -33,6 +33,7 @@ import torch.nn as nn from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory from torch._logging._internal import trace_log +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( FILE_SCHEMA, find_free_port, @@ -772,7 +773,12 @@ class MultiProcessTestCase(TestCase): process = proc( target=self.__class__._run, name="process " + str(rank), - args=(rank, self._current_test_name(), self.file_name, child_conn), + args=( + rank, + self._current_test_name(), + self.file_name, + child_conn, + ), kwargs={ "fake_pg": getattr(self, "fake_pg", False), }, @@ -849,6 +855,7 @@ class MultiProcessTestCase(TestCase): torch._C._set_print_stack_traces_on_fatal_signal(True) # Show full C++ stacktraces when a Python error originating from C++ is raised. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" + common_utils.set_rng_seed() # self.id() == e.g. '__main__.TestDistributed.test_get_rank' # We're retrieving a corresponding test and executing it. @@ -1670,6 +1677,10 @@ class MultiProcContinuousTest(TestCase): self.rank = cls.rank self.world_size = cls.world_size test_fn = getattr(self, test_name) + + # Ensure all the ranks use the same seed. + common_utils.set_rng_seed() + # Run the test function test_fn(**kwargs) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index c7274fddd6d..1d8f7470297 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -58,6 +58,7 @@ from torch.testing._internal.common_distributed import ( from torch.testing._internal.common_utils import ( FILE_SCHEMA, get_cycles_per_ms, + set_rng_seed, TEST_CUDA, TEST_HPU, TEST_XPU, @@ -1228,6 +1229,7 @@ class FSDPTest(MultiProcessTestCase): dist.barrier(device_ids=device_ids) torch._dynamo.reset() + set_rng_seed() self.run_test(test_name, pipe) torch._dynamo.reset() diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 135cc6a7bd6..3b8f277ceef 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -15,6 +15,7 @@ import torch.cuda import torch.nn as nn import torch.nn.functional as F from torch.nn import _reduction as _Reduction +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater @@ -1078,6 +1079,7 @@ def single_batch_reference_fn(input, parameters, module): def get_new_module_tests(): + common_utils.set_rng_seed() new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 93a6352831f..be1e30d0f18 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -101,9 +101,35 @@ except ImportError: has_pytest = False +SEED = 1234 MI300_ARCH = ("gfx942",) MI200_ARCH = ("gfx90a") +class ProfilingMode(Enum): + LEGACY = 1 + SIMPLE = 2 + PROFILING = 3 + +# Set by parse_cmd_line_args() if called +CI_FUNCTORCH_ROOT = "" +CI_PT_ROOT = "" +CI_TEST_PREFIX = "" +DISABLED_TESTS_FILE = "" +GRAPH_EXECUTOR : Optional[ProfilingMode] = None +LOG_SUFFIX = "" +PYTEST_SINGLE_TEST = "" +REPEAT_COUNT = 0 +RERUN_DISABLED_TESTS = False +RUN_PARALLEL = 0 +SHOWLOCALS = False +SLOW_TESTS_FILE = "" +TEST_BAILOUTS = False +TEST_DISCOVER = False +TEST_IN_SUBPROCESS = False +TEST_SAVE_XML = "" +UNITTEST_ARGS : list[str] = [] +USE_PYTEST = False + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -838,11 +864,6 @@ class decorateIf(_TestParametrizer): yield (test_wrapper, test_name, {}, decorator_fn) -class ProfilingMode(Enum): - LEGACY = 1 - SIMPLE = 2 - PROFILING = 3 - def cppProfilingFlagsToProfilingMode(): old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -861,6 +882,7 @@ def cppProfilingFlagsToProfilingMode(): def enable_profiling_mode_for_profiling_tests(): old_prof_exec_state = False old_prof_mode_state = False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -895,6 +917,7 @@ meth_call = torch._C.ScriptMethod.__call__ def prof_callable(callable, *args, **kwargs): if 'profile_and_replay' in kwargs: del kwargs['profile_and_replay'] + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: with enable_profiling_mode_for_profiling_tests(): callable(*args, **kwargs) @@ -924,72 +947,91 @@ def _get_test_report_path(): test_source = override if override is not None else 'python-unittest' return os.path.join('test-reports', test_source) -is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") -parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) -parser.add_argument('--subprocess', action='store_true', - help='whether to run each test in a subprocess') -parser.add_argument('--seed', type=int, default=1234) -parser.add_argument('--accept', action='store_true') -parser.add_argument('--jit-executor', '--jit_executor', type=str) -parser.add_argument('--repeat', type=int, default=1) -parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') -parser.add_argument('--use-pytest', action='store_true') -parser.add_argument('--save-xml', nargs='?', type=str, - const=_get_test_report_path(), - default=_get_test_report_path() if IS_CI else None) -parser.add_argument('--discover-tests', action='store_true') -parser.add_argument('--log-suffix', type=str, default="") -parser.add_argument('--run-parallel', type=int, default=1) -parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) -parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) -parser.add_argument('--rerun-disabled-tests', action='store_true') -parser.add_argument('--pytest-single-test', type=str, nargs=1) -parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) +def parse_cmd_line_args(): + global CI_FUNCTORCH_ROOT + global CI_PT_ROOT + global CI_TEST_PREFIX + global DISABLED_TESTS_FILE + global GRAPH_EXECUTOR + global LOG_SUFFIX + global PYTEST_SINGLE_TEST + global REPEAT_COUNT + global RERUN_DISABLED_TESTS + global RUN_PARALLEL + global SHOWLOCALS + global SLOW_TESTS_FILE + global TEST_BAILOUTS + global TEST_DISCOVER + global TEST_IN_SUBPROCESS + global TEST_SAVE_XML + global UNITTEST_ARGS + global USE_PYTEST + + is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") + parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) + parser.add_argument('--subprocess', action='store_true', + help='whether to run each test in a subprocess') + parser.add_argument('--accept', action='store_true') + parser.add_argument('--jit-executor', '--jit_executor', type=str) + parser.add_argument('--repeat', type=int, default=1) + parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') + parser.add_argument('--use-pytest', action='store_true') + parser.add_argument('--save-xml', nargs='?', type=str, + const=_get_test_report_path(), + default=_get_test_report_path() if IS_CI else None) + parser.add_argument('--discover-tests', action='store_true') + parser.add_argument('--log-suffix', type=str, default="") + parser.add_argument('--run-parallel', type=int, default=1) + parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) + parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) + parser.add_argument('--rerun-disabled-tests', action='store_true') + parser.add_argument('--pytest-single-test', type=str, nargs=1) + parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) # Only run when -h or --help flag is active to display both unittest and parser help messages. -def run_unittest_help(argv): - unittest.main(argv=argv) + def run_unittest_help(argv): + unittest.main(argv=argv) -if '-h' in sys.argv or '--help' in sys.argv: - help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) - help_thread.start() - help_thread.join() + if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() -args, remaining = parser.parse_known_args() -if args.jit_executor == 'legacy': - GRAPH_EXECUTOR = ProfilingMode.LEGACY -elif args.jit_executor == 'profiling': - GRAPH_EXECUTOR = ProfilingMode.PROFILING -elif args.jit_executor == 'simple': - GRAPH_EXECUTOR = ProfilingMode.SIMPLE -else: - # infer flags based on the default settings - GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() + args, remaining = parser.parse_known_args() + if args.jit_executor == 'legacy': + GRAPH_EXECUTOR = ProfilingMode.LEGACY + elif args.jit_executor == 'profiling': + GRAPH_EXECUTOR = ProfilingMode.PROFILING + elif args.jit_executor == 'simple': + GRAPH_EXECUTOR = ProfilingMode.SIMPLE + else: + # infer flags based on the default settings + GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() -RERUN_DISABLED_TESTS = args.rerun_disabled_tests + RERUN_DISABLED_TESTS = args.rerun_disabled_tests -SLOW_TESTS_FILE = args.import_slow_tests -DISABLED_TESTS_FILE = args.import_disabled_tests -LOG_SUFFIX = args.log_suffix -RUN_PARALLEL = args.run_parallel -TEST_BAILOUTS = args.test_bailouts -USE_PYTEST = args.use_pytest -PYTEST_SINGLE_TEST = args.pytest_single_test -TEST_DISCOVER = args.discover_tests -TEST_IN_SUBPROCESS = args.subprocess -TEST_SAVE_XML = args.save_xml -REPEAT_COUNT = args.repeat -SEED = args.seed -SHOWLOCALS = args.showlocals -if not getattr(expecttest, "ACCEPT", False): - expecttest.ACCEPT = args.accept -UNITTEST_ARGS = [sys.argv[0]] + remaining -torch.manual_seed(SEED) + SLOW_TESTS_FILE = args.import_slow_tests + DISABLED_TESTS_FILE = args.import_disabled_tests + LOG_SUFFIX = args.log_suffix + RUN_PARALLEL = args.run_parallel + TEST_BAILOUTS = args.test_bailouts + USE_PYTEST = args.use_pytest + PYTEST_SINGLE_TEST = args.pytest_single_test + TEST_DISCOVER = args.discover_tests + TEST_IN_SUBPROCESS = args.subprocess + TEST_SAVE_XML = args.save_xml + REPEAT_COUNT = args.repeat + SHOWLOCALS = args.showlocals + if not getattr(expecttest, "ACCEPT", False): + expecttest.ACCEPT = args.accept + UNITTEST_ARGS = [sys.argv[0]] + remaining + + set_rng_seed() # CI Prefix path used only on CI environment -CI_TEST_PREFIX = str(Path(os.getcwd())) -CI_PT_ROOT = str(Path(os.getcwd()).parent) -CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) + CI_TEST_PREFIX = str(Path(os.getcwd())) + CI_PT_ROOT = str(Path(os.getcwd()).parent) + CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p, timeout=None): try: @@ -1138,7 +1180,9 @@ def lint_test_case_extension(suite): return succeed -def get_report_path(argv=UNITTEST_ARGS, pytest=False): +def get_report_path(argv=None, pytest=False): + if argv is None: + argv = UNITTEST_ARGS test_filename = sanitize_test_filename(argv[0]) test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = os.path.join(test_report_path, test_filename) @@ -1189,7 +1233,11 @@ def get_pytest_test_cases(argv: list[str]) -> list[str]: return test_collector_plugin.tests -def run_tests(argv=UNITTEST_ARGS): +def run_tests(argv=None): + parse_cmd_line_args() + if argv is None: + argv = UNITTEST_ARGS + # import test files. if SLOW_TESTS_FILE: if os.path.exists(SLOW_TESTS_FILE): @@ -1759,6 +1807,7 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe if not isinstance(fn, type): @wraps(fn) def wrapper(*args, **kwargs): + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.LEGACY: raise unittest.SkipTest(msg) else: @@ -2379,7 +2428,9 @@ def get_function_arglist(func): return inspect.getfullargspec(func).args -def set_rng_seed(seed): +def set_rng_seed(seed=None): + if seed is None: + seed = SEED torch.manual_seed(seed) random.seed(seed) if TEST_NUMPY: @@ -3402,7 +3453,7 @@ class TestCase(expecttest.TestCase): def setUp(self): check_if_enable(self) - set_rng_seed(SEED) + set_rng_seed() # Save global check sparse tensor invariants state that can be # restored from tearDown: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index c4701432d81..cc47b91db54 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,16 +137,18 @@ class Foo: f = Foo(10) f.bar = 1 -foo_cpu_tensor = Foo(torch.randn(3, 3)) +# Defer instantiation until the seed is set so that randn() returns the same +# values in all processes. +def create_collectives_object_test_list(): + return [ + {"key1": 3, "key2": 4, "key3": {"nested": True}}, + f, + Foo(torch.randn(3, 3)), + "foo", + [1, 2, True, "string", [4, 5, "nested"]], + ] -COLLECTIVES_OBJECT_TEST_LIST = [ - {"key1": 3, "key2": 4, "key3": {"nested": True}}, - f, - foo_cpu_tensor, - "foo", - [1, 2, True, "string", [4, 5, "nested"]], -] # Allowlist of distributed backends where profiling collectives is supported. PROFILING_SUPPORTED_BACKENDS = [ @@ -396,12 +398,6 @@ class ControlFlowToyModel(nn.Module): return F.relu(self.lin1(x)) -DDP_NET = Net() -BN_NET = BatchNormNet() -BN_NET_NO_AFFINE = BatchNormNet(affine=False) -ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) - - def get_timeout(test_id): test_name = test_id.split(".")[-1] if test_name in CUSTOMIZED_TIMEOUT: @@ -4293,7 +4289,7 @@ class DistributedTest: # as baseline # cpu training setup - model = DDP_NET + model = Net() # single gpu training setup model_gpu = copy.deepcopy(model) @@ -4348,7 +4344,7 @@ class DistributedTest: _group, _group_id, rank = self._init_global_test() # cpu training setup - model_base = DDP_NET + model_base = Net() # DDP-CPU training setup model_DDP = copy.deepcopy(model_base) @@ -5497,7 +5493,7 @@ class DistributedTest: def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): torch.manual_seed(31415) # Creates model and optimizer in default precision - model = copy.deepcopy(DDP_NET).cuda() + model = Net().cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.03) # Creates a GradScaler once at the beginning of training. @@ -5582,7 +5578,7 @@ class DistributedTest: # as baseline # cpu training setup - model = BN_NET if affine else BN_NET_NO_AFFINE + model = BatchNormNet() if affine else BatchNormNet(affine=False) # single gpu training setup model_gpu = copy.deepcopy(model) @@ -5632,6 +5628,7 @@ class DistributedTest: def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): learning_rate = 0.03 + DDP_NET = Net() net = torch.nn.parallel.DistributedDataParallel( copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank], @@ -5698,7 +5695,7 @@ class DistributedTest: learning_rate = 0.03 net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( - copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank] + Net().cuda(), device_ids=[self.rank] ) averager = create_averager() @@ -5848,7 +5845,7 @@ class DistributedTest: bs_offset = int(rank * 2) global_bs = int(num_processes * 2) - model = ONLY_SBN_NET + model = nn.SyncBatchNorm(2, momentum=0.99) model_gpu = copy.deepcopy(model).cuda(rank) model_DDP = nn.parallel.DistributedDataParallel( model_gpu, device_ids=[rank] @@ -6058,6 +6055,7 @@ class DistributedTest: def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( self, ): + ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) _group, _group_id, rank = self._init_global_test() model = nn.parallel.DistributedDataParallel( ONLY_SBN_NET.cuda(rank), device_ids=[rank] @@ -6125,7 +6123,7 @@ class DistributedTest: def test_DistributedDataParallel_SyncBatchNorm_half(self): _group, _group_id, rank = self._init_global_test() - model = copy.deepcopy(BN_NET) + model = BatchNormNet() model = model.half() model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = nn.parallel.DistributedDataParallel( @@ -6141,7 +6139,7 @@ class DistributedTest: def _test_ddp_logging_data(self, is_gpu): rank = dist.get_rank() - model_DDP = copy.deepcopy(DDP_NET) + model_DDP = Net() if is_gpu: model_DDP = nn.parallel.DistributedDataParallel( model_DDP.cuda(rank), device_ids=[rank] @@ -6417,7 +6415,7 @@ class DistributedTest: BACKEND == "nccl", "nccl does not support DDP on CPU models" ) def test_static_graph_api_cpu(self): - model_DDP = nn.parallel.DistributedDataParallel(DDP_NET) + model_DDP = nn.parallel.DistributedDataParallel(Net()) expected_err = "should be called before training loop starts" with self.assertRaisesRegex(RuntimeError, expected_err): local_bs = 2 @@ -6650,7 +6648,7 @@ class DistributedTest: def _test_allgather_object(self, subgroup=None): # Only set device for NCCL backend since it must use GPUs. - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() backend = os.environ["BACKEND"] if backend == "nccl": @@ -6694,7 +6692,7 @@ class DistributedTest: def _test_gather_object(self, pg=None): # Ensure stateful objects can be gathered - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() my_rank = dist.get_rank(pg) backend = os.environ["BACKEND"] @@ -7264,7 +7262,7 @@ class DistributedTest: return x torch.cuda.set_device(self.rank) - model_bn = BN_NET + model_bn = BatchNormNet() model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( copy.deepcopy(model_bn) ).cuda(self.rank) @@ -7560,7 +7558,7 @@ class DistributedTest: loss.backward() def _test_broadcast_object_list(self, group=None): - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() # Only set device for NCCL backend since it must use GPUs. # Case where rank != GPU device. @@ -8284,10 +8282,11 @@ class DistributedTest: @require_backend_is_available({"gloo"}) def test_scatter_object_list(self): src_rank = 0 + collectives_object_test_list = create_collectives_object_test_list() scatter_list = ( - COLLECTIVES_OBJECT_TEST_LIST + collectives_object_test_list if self.rank == src_rank - else [None for _ in COLLECTIVES_OBJECT_TEST_LIST] + else [None for _ in collectives_object_test_list] ) world_size = dist.get_world_size() scatter_list = scatter_list[:world_size] @@ -8300,8 +8299,8 @@ class DistributedTest: dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) self.assertEqual( output_obj_list[0], - COLLECTIVES_OBJECT_TEST_LIST[ - self.rank % len(COLLECTIVES_OBJECT_TEST_LIST) + collectives_object_test_list[ + self.rank % len(collectives_object_test_list) ], ) # Ensure errors are raised upon incorrect arguments. @@ -9987,7 +9986,7 @@ class DistributedTest: "Only Nccl & Gloo backend support DistributedDataParallel", ) def test_sync_bn_logged(self): - model = BN_NET + model = BatchNormNet() rank = self.rank # single gpu training setup model_gpu = model.cuda(rank)