mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
TBB: Use static partitioner to match OpenMP scheduling (#65327)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65327 Should fix https://github.com/pytorch/pytorch/issues/64571 Test Plan: Imported from OSS Reviewed By: dagitses Differential Revision: D31474116 Pulled By: malfet fbshipit-source-id: 8c4264d4778c6caf58261e3f70d72decd134128d
This commit is contained in:
parent
d5033410b1
commit
bd9eee4e65
|
|
@ -41,7 +41,7 @@ inline void invoke_parallel(
|
||||||
eptr = std::current_exception();
|
eptr = std::current_exception();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
}, tbb::static_partitioner{});
|
||||||
if (eptr) {
|
if (eptr) {
|
||||||
std::rethrow_exception(eptr);
|
std::rethrow_exception(eptr);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import random
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
||||||
torch_to_numpy_dtype_dict, skipIfTBB, slowTest,
|
torch_to_numpy_dtype_dict, slowTest,
|
||||||
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
|
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
|
instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
|
||||||
|
|
@ -1205,7 +1205,6 @@ class TestTensorCreation(TestCase):
|
||||||
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
|
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
|
||||||
|
|
||||||
# TODO: update to work on CUDA, too
|
# TODO: update to work on CUDA, too
|
||||||
@skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_trilu_indices(self, device):
|
def test_trilu_indices(self, device):
|
||||||
for test_args in tri_tests_args:
|
for test_args in tri_tests_args:
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from torch.testing._internal.common_utils import \
|
||||||
random_fullrank_matrix_distinct_singular_value,
|
random_fullrank_matrix_distinct_singular_value,
|
||||||
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
|
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
|
||||||
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
|
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
|
||||||
GRADCHECK_NONDET_TOL, skipIfTBB)
|
GRADCHECK_NONDET_TOL)
|
||||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||||
|
|
||||||
from setuptools import distutils
|
from setuptools import distutils
|
||||||
|
|
@ -7961,8 +7961,7 @@ op_db: List[OpInfo] = [
|
||||||
DecorateInfo(
|
DecorateInfo(
|
||||||
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
|
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
|
||||||
'TestCommon', 'test_reference_testing'
|
'TestCommon', 'test_reference_testing'
|
||||||
),
|
)
|
||||||
skipIfTBB(),
|
|
||||||
],
|
],
|
||||||
sample_inputs_func=sample_inputs_layer_norm,),
|
sample_inputs_func=sample_inputs_layer_norm,),
|
||||||
OpInfo('nn.functional.pad',
|
OpInfo('nn.functional.pad',
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user