mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
TBB: Use static partitioner to match OpenMP scheduling (#65327)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65327 Should fix https://github.com/pytorch/pytorch/issues/64571 Test Plan: Imported from OSS Reviewed By: dagitses Differential Revision: D31474116 Pulled By: malfet fbshipit-source-id: 8c4264d4778c6caf58261e3f70d72decd134128d
This commit is contained in:
parent
d5033410b1
commit
bd9eee4e65
|
|
@ -41,7 +41,7 @@ inline void invoke_parallel(
|
|||
eptr = std::current_exception();
|
||||
}
|
||||
}
|
||||
});
|
||||
}, tbb::static_partitioner{});
|
||||
if (eptr) {
|
||||
std::rethrow_exception(eptr);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import random
|
|||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
||||
torch_to_numpy_dtype_dict, skipIfTBB, slowTest,
|
||||
torch_to_numpy_dtype_dict, slowTest,
|
||||
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
|
||||
|
|
@ -1205,7 +1205,6 @@ class TestTensorCreation(TestCase):
|
|||
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
|
||||
|
||||
# TODO: update to work on CUDA, too
|
||||
@skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
|
||||
@onlyCPU
|
||||
def test_trilu_indices(self, device):
|
||||
for test_args in tri_tests_args:
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from torch.testing._internal.common_utils import \
|
|||
random_fullrank_matrix_distinct_singular_value,
|
||||
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
|
||||
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
|
||||
GRADCHECK_NONDET_TOL, skipIfTBB)
|
||||
GRADCHECK_NONDET_TOL)
|
||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
|
||||
from setuptools import distutils
|
||||
|
|
@ -7961,8 +7961,7 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(
|
||||
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
|
||||
'TestCommon', 'test_reference_testing'
|
||||
),
|
||||
skipIfTBB(),
|
||||
)
|
||||
],
|
||||
sample_inputs_func=sample_inputs_layer_norm,),
|
||||
OpInfo('nn.functional.pad',
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user