mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix torchrec multiprocess tests (#158159)
Summary: The new version of `get_device_tflops` imported something from testing, which imported common_utils.py, which disabled global flags. Test Plan: Fixing existing tests Rollback Plan: Differential Revision: D78192700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158159 Approved by: https://github.com/nipung90, https://github.com/huydhn
This commit is contained in:
parent
058fb1790f
commit
9cd521de4d
|
|
@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float:
|
||||||
|
|
||||||
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
|
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
|
||||||
|
|
||||||
from torch.testing._internal.common_cuda import SM80OrLater
|
SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
|
||||||
|
8,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
assert dtype in (torch.float16, torch.bfloat16, torch.float32)
|
assert dtype in (torch.float16, torch.bfloat16, torch.float32)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user