TBB: Use static partitioner to match OpenMP scheduling (#65327)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65327

TBB's default auto_partitioner subdivides the iteration range adaptively and lets idle threads steal work; tbb::static_partitioner instead splits the range into fixed chunks up front, matching OpenMP's static scheduling.

Should fix https://github.com/pytorch/pytorch/issues/64571

Test Plan: Imported from OSS

Reviewed By: dagitses

Differential Revision: D31474116

Pulled By: malfet

fbshipit-source-id: 8c4264d4778c6caf58261e3f70d72decd134128d
Author: Peter Bell, 2021-10-07 19:09:24 -07:00 (committed by Facebook GitHub Bot)
parent d5033410b1
commit bd9eee4e65
3 changed files with 4 additions and 6 deletions


@@ -41,7 +41,7 @@ inline void invoke_parallel(
           eptr = std::current_exception();
         }
       }
-    });
+    }, tbb::static_partitioner{});
   if (eptr) {
     std::rethrow_exception(eptr);
   }
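
For context, here is a minimal standalone sketch (not PyTorch code; the function and parameter names are illustrative) of the tbb::parallel_for overload the hunk above switches to, with the partitioner passed explicitly:

// Sketch: parallel_for with an explicit static partitioner.
// tbb::static_partitioner splits [0, n) into roughly equal chunks up front,
// one per worker, like OpenMP's schedule(static). The default auto_partitioner
// instead subdivides adaptively and lets idle threads steal work.
#include <cstdint>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/partitioner.h>

void scale_by_two(const float* src, float* dst, int64_t n, int64_t grain_size) {
  tbb::parallel_for(
      tbb::blocked_range<int64_t>(0, n, grain_size),
      [=](const tbb::blocked_range<int64_t>& r) {
        for (int64_t i = r.begin(); i < r.end(); ++i) {
          dst[i] = src[i] * 2.0f;
        }
      },
      tbb::static_partitioner{});  // fixed, deterministic chunk assignment
}

The trade-off is that static partitioning gives up dynamic load balancing, so unevenly sized iterations are no longer redistributed; in exchange, the division of work is deterministic and matches the OpenMP backend's behavior.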


@@ -11,7 +11,7 @@ import random
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, skipIfTBB, slowTest,
+    torch_to_numpy_dtype_dict, slowTest,
     TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
@@ -1205,7 +1205,6 @@ class TestTensorCreation(TestCase):
         self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
 
     # TODO: update to work on CUDA, too
-    @skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
     @onlyCPU
     def test_trilu_indices(self, device):
         for test_args in tri_tests_args:


@@ -36,7 +36,7 @@ from torch.testing._internal.common_utils import \
     random_fullrank_matrix_distinct_singular_value,
     TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
     torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
-    GRADCHECK_NONDET_TOL, skipIfTBB)
+    GRADCHECK_NONDET_TOL)
 import torch.testing._internal.opinfo_helper as opinfo_helper
 
 from setuptools import distutils
@@ -7961,8 +7961,7 @@ op_db: List[OpInfo] = [
             DecorateInfo(
                 toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
                 'TestCommon', 'test_reference_testing'
-            ),
-            skipIfTBB(),
+            )
         ],
         sample_inputs_func=sample_inputs_layer_norm,),
     OpInfo('nn.functional.pad',