[7/N] Don't skip ASAN on some tests (#139675)

Follows #139565
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139675
Approved by: https://github.com/ezyang
This commit is contained in:
cyy 2024-11-05 14:01:01 +00:00 committed by PyTorch MergeBot
parent f551d90552
commit 546318e559
3 changed files with 5 additions and 4 deletions

View File

@ -35,6 +35,8 @@ Scalar item(const Tensor& self) {
#endif
Scalar _local_scalar_dense_cpu(const Tensor& self) {
// Don't use bool*, since it may take out-of-range byte as bool.
// Instead, we cast explicitly to avoid ASAN error.
if (self.scalar_type() == kBool) {
return Scalar(static_cast<bool>(*reinterpret_cast<const uint8_t*>(self.const_data_ptr<bool>())));
}

View File

@ -545,7 +545,6 @@ class TestDecomp(TestCase):
# NB: This actually overlaps with test_comprehensive, but it only
# runs on things that are definitely decomposed so it's a lot faster
# to run
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
@suppress_warnings
@ -553,7 +552,6 @@ class TestDecomp(TestCase):
def test_quick(self, device, dtype, op):
self.do_cross_ref(device, dtype, op, run_all=False)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
@onlyNativeDeviceTypes
@skipIfCrossRef
@ -663,7 +661,6 @@ class TestDecomp(TestCase):
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@suppress_warnings
@tf32_off()
# only tests RNNs since we have py dispsatcher decomps for them

View File

@ -137,7 +137,9 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
return PyComplex_FromCComplex(
*reinterpret_cast<Py_complex*>((c10::complex<double>*)data));
case at::kBool:
return PyBool_FromLong(*(bool*)data);
// Don't use bool*, since it may take out-of-range byte as bool.
// Instead, we cast explicitly to avoid ASAN error.
return PyBool_FromLong(static_cast<bool>(*(uint8_t*)data));
case at::kBFloat16:
return PyFloat_FromDouble(
at::convert<double, at::BFloat16>(*(at::BFloat16*)data));