mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Hi, Team! The PR is motivated from https://github.com/pytorch/pytorch/pull/71153#discussion_r782446738. It aims to replace `get_all` type macros with the ATen dispatch macros. The files it iterates over are: (Thanks, Lezcano, for the idea!!) <details> <summary> `test/test_autograd.py`</summary> <p> ```python 43:from torch.testing._internal.common_dtype import get_all_dtypes 8506: floating_dt = [dt for dt in get_all_dtypes() if dt.is_floating_point] ``` </p> </details> <details> <summary> `test/test_binary_ufuncs.py`</summary> <p> ```python 26: all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes, 27: get_all_complex_dtypes, get_all_fp_dtypes, 935: dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) 1035: dtypes(*get_all_dtypes( 1488: dtypes(*(get_all_dtypes(include_bool=False, include_bfloat16=False))) 1879: dtypes(*product(get_all_dtypes(include_complex=False), get_all_dtypes(include_complex=False))) 1887: dtypes(*(get_all_int_dtypes() + [torch.bool])) 1913: dtypes(*(get_all_fp_dtypes())) 1941: dtypes(*(get_all_fp_dtypes())) 1977: dtypes(*product(get_all_complex_dtypes(), get_all_dtypes())) 2019: dtypes(*product(get_all_fp_dtypes(), get_all_fp_dtypes())) 2048: dtypes(*get_all_dtypes()) 2110: dtypes(*product(get_all_dtypes(include_complex=False), 2111: get_all_dtypes(include_complex=False))) 2128: types = [torch.bool, torch.bfloat16] + get_all_int_dtypes() 2173: if dtypes[1] in get_all_fp_dtypes(): 2178: dtypes(*product(get_all_fp_dtypes(), 2179: get_all_fp_dtypes())) 2260: dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) 2261: dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) 2273: dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) 2274: dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) 2307: dtypes(*get_all_math_dtypes('cpu')) 2319: 
dtypes(*get_all_fp_dtypes(include_bfloat16=False)) 2331: dtypes(*get_all_int_dtypes()) 2356: dtypes(*get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) 2393: if dtype in get_all_int_dtypes(): 2614: dtypes(*get_all_dtypes()) 2624: dtypes(*tuple(itertools.combinations_with_replacement(get_all_dtypes(), 2))) 2806: dtypes(*list(product(get_all_dtypes(include_complex=False), 2807: get_all_dtypes(include_complex=False)))) 2866: dtypes(*list(product(get_all_complex_dtypes(), 2867: get_all_complex_dtypes()))) 2902: dtypes(*product(get_all_dtypes(), get_all_dtypes())) 2906: dtypes(*product(get_all_dtypes(), get_all_dtypes())) 2910: dtypes(*product(get_all_dtypes(), get_all_dtypes())) 3019: dtypes = [torch.float, torch.double] + get_all_complex_dtypes() 3221: dtypes(*get_all_dtypes(include_complex=False)) 3407: dtypes(*list(product(get_all_dtypes(include_bool=False), 3408: get_all_dtypes(include_bool=False)))) 3504: dtypes(*product(get_all_dtypes(include_complex=False, include_bfloat16=False), 3505: get_all_dtypes(include_complex=False, include_bfloat16=False))) 3516: if x.dtype in get_all_int_dtypes() + [torch.bool]: 3643: dtypes(*product(get_all_dtypes(include_complex=False, 3645: get_all_dtypes(include_complex=False, ``` </p> </details> <details> <summary> `test/test_complex.py`</summary> <p> ```python 6:from torch.testing._internal.common_dtype import get_all_complex_dtypes 11: dtypes(*get_all_complex_dtypes()) ``` </p> </details> <details> <summary> `test/test_foreach.py`</summary> <p> ```python 18: get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes, 142: if dtype in get_all_int_dtypes(): 179: disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] 201: disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] 205: disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool] 211: disable_fastpath |= dtype not in get_all_complex_dtypes() 241: 
bool_int_div = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] 246: disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool] 248: disable_fastpath |= dtype not in get_all_complex_dtypes() 250: disable_fastpath |= True and dtype not in get_all_complex_dtypes() 307: disable_fastpath = dtype in get_all_int_dtypes() + [torch.bool] 365: if opinfo.name == "_foreach_abs" and dtype in get_all_complex_dtypes(): 376: ops(foreach_unary_op_db, dtypes=get_all_dtypes()) 393: dtypes=get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False)) 401: ops(foreach_minmax_op_db, dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True)) 426: if ord in (1, 2) and dtype in torch.testing.get_all_fp_dtypes(): 439: dtypes(*get_all_dtypes()) 449: ops(foreach_binary_op_db, dtypes=get_all_dtypes()) 481: ops(foreach_binary_op_db, dtypes=get_all_dtypes()) 536: if dtype in get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div: 545: ops(foreach_binary_op_db, dtypes=get_all_dtypes()) 637: ops(foreach_pointwise_op_db, allowed_dtypes=get_all_fp_dtypes(include_half=False, include_bfloat16=False)) ``` </p> </details> <details> <summary> `test/test_linalg.py`</summary> <p> ```python 29: all_types, floating_types, floating_and_complex_types, get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, 30: get_all_fp_dtypes, 111: dtypes(*(get_all_dtypes())) 794: float_and_complex_dtypes = get_all_fp_dtypes() + get_all_complex_dtypes() 807: dtypes(*(get_all_int_dtypes())) 828: dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) 841: if dtype in get_all_complex_dtypes(): 844: dtypes(*itertools.product(get_all_dtypes(), 845: get_all_dtypes())) 855: for dtypes0, dtypes1, dtypes2 in product(get_all_dtypes(), repeat=3): 5607: *get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater))) 5608: dtypes(*(set(get_all_dtypes()) - {torch.half, torch.bool})) 5644: dtypes(*(get_all_complex_dtypes() + 
get_all_fp_dtypes())) 6255: dtypesIfCUDA(*get_all_complex_dtypes(), 6256: *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)), 6292: dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) 6323: dtypesIfCUDA(*get_all_complex_dtypes(), 6324: *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) 6325: dtypes(*get_all_complex_dtypes(), *get_all_fp_dtypes()) 6358: dtypesIfCUDA(*([torch.float, torch.double] + get_all_complex_dtypes())) 6556: dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) 6668: dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) 6741: dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) ``` </p> </details> <details> <summary> `test/test_nn.py`</summary> <p> ```python 37:from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes 50: onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types, \ 8862: for device in get_all_device_types(): 9629: for dt1 in get_all_math_dtypes(device): 9630: for dt2 in get_all_math_dtypes(device): 9631: for dt3 in get_all_math_dtypes(device): 9648: for input_dtype in get_all_math_dtypes(device): 9664: for input_dtype in get_all_math_dtypes(device): 13015: dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) 13034: dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) 13159: dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) 17400: dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) 17768: dtypesIfCUDA(*get_all_fp_dtypes()) 17773: dtypesIfCUDA(*get_all_fp_dtypes()) 17778: dtypesIfCUDA(*get_all_fp_dtypes()) 17783: dtypesIfCUDA(*get_all_fp_dtypes()) 17788: dtypesIfCUDA(*get_all_fp_dtypes()) 17793: dtypesIfCUDA(*get_all_fp_dtypes()) 17798: dtypesIfCUDA(*get_all_fp_dtypes()) 17963: dtypesIfCUDA(*get_all_fp_dtypes()) 17977: dtypesIfCUDA(*get_all_fp_dtypes()) 18684: 
def test_cross_entropy_loss_prob_target_all_reductions(self, device): ``` </p> </details> <details> <summary> `test/test_numpy_interop.py`</summary> <p> ```python 12:from torch.testing._internal.common_dtype import get_all_dtypes 399: dtypes(*get_all_dtypes()) ``` </p> </details> <details> <summary> `test/test_ops.py`</summary> <p> ```python 12:from torch.testing._internal.common_dtype import floating_and_complex_types_and, get_all_dtypes 86: for dtype in get_all_dtypes(): ``` </p> </details> <details> <summary> `test/test_reductions.py`</summary> <p> ```python 16: get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes, 360: allowed_dtypes=get_all_dtypes(include_bfloat16=False)) 366: allowed_dtypes=get_all_dtypes(include_bfloat16=False)) 394: allowed_dtypes=get_all_dtypes(include_bfloat16=False)) 750: for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]: 1404: dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) 1457: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + 1458: get_all_complex_dtypes())) 1465: return dtype in get_all_int_dtypes() 1494: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) 1501: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) 1507: dtypes(*(get_all_complex_dtypes())) 1514: dtypes = list(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)) 1523: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) 1531: if dtype in get_all_fp_dtypes(): 1608: dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False, 1837: dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) 1855: dtypes(*(set(get_all_dtypes(include_bool=False, include_complex=False)) - {torch.uint8})) 3219: for dtype in get_all_dtypes(include_half=True, include_bfloat16=False, ``` </p> </details> <details> <summary> `test/test_serialization.py`</summary> <p> 
```python 26:from torch.testing._internal.common_dtype import get_all_dtypes 586: for device, dtype in product(devices, get_all_dtypes()): 589: for other_dtype in get_all_dtypes(): ``` </p> </details> <details> <summary> `test/test_shape_ops.py`</summary> <p> ```python 18:from torch.testing._internal.common_dtype import get_all_dtypes 230: dtypes(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False, 232: dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False)) 344: dtypes(*get_all_dtypes()) 443: dtypes(*get_all_dtypes()) 461: dtypes(*get_all_dtypes()) 570: dtypes(*get_all_dtypes(include_complex=False)) ``` </p> </details> <details> <summary> `test/test_sort_and_select.py`</summary> <p> ```python 12: all_types, all_types_and, floating_types_and, get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, 136: dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) 231: dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) 296: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 647: dtypesIfCUDA(*get_all_fp_dtypes()) 678: dtypesIfCUDA(*(get_all_dtypes(include_complex=False, 682: dtypes(*(get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))) 739: dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128}) 740: dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) 799: dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128}) 800: dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) ``` </p> </details> <details> <summary> `test/test_sparse.py`</summary> <p> ```python 20:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes 29: floating_and_complex_types, floating_and_complex_types_and, get_all_dtypes, get_all_int_dtypes, 1963: return dtype in get_all_int_dtypes() 1994: 
dtypes(*get_all_dtypes(include_bool=False, include_half=False, 2103: return dtype in get_all_int_dtypes() 2138: dtypes(*get_all_dtypes(include_bool=False, include_half=False, 2626: all_sparse_dtypes = get_all_dtypes(include_complex=True) 2633: all_sparse_dtypes = get_all_dtypes(include_complex=True) 3230: dtypes(*get_all_complex_dtypes(), 3231: *get_all_fp_dtypes(include_half=False, include_bfloat16=False)) 3234: *get_all_fp_dtypes( ``` </p> </details> <details> <summary> `test/test_sparse_csr.py`</summary> <p> ```python 7:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes, floating_and_complex_types, make_tensor 17:from torch.testing._internal.common_dtype import floating_types, get_all_dtypes 120: dtypes(*get_all_dtypes()) 133: dtypes(*get_all_dtypes()) 150: dtypes(*get_all_dtypes()) 180: dtypes(*get_all_dtypes()) 201: dtypes(*get_all_dtypes()) 210: dtypes(*get_all_dtypes()) 225: dtypes(*get_all_dtypes()) 244: dtypes(*get_all_dtypes()) 263: dtypes(*get_all_dtypes()) 285: dtypes(*get_all_dtypes()) 411: dtypes(*get_all_dtypes()) 482: dtypes(*get_all_dtypes()) 502: dtypes(*get_all_dtypes()) 562: dtypes(*get_all_dtypes()) 588: dtypesIfCUDA(*get_all_complex_dtypes(), 589: *get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater)) 745: dtypesIfCUDA(*get_all_complex_dtypes(), 746: *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC, 765: dtypesIfCUDA(*get_all_complex_dtypes(), 766: *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC, 801: *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater, 841: *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater, 1182: dtypes(*get_all_dtypes()) 1276: dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_bfloat16=False)) 1286: dtypes(*get_all_dtypes()) ``` </p> </details> <details> <summary> `test/test_tensor_creation_ops.py`</summary> <p> ```python 21: onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta, get_all_device_types) 23: 
get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes 150: for dt in get_all_dtypes(): 160: for dt in get_all_dtypes(): 314: dtypes = [dtype for dtype in get_all_dtypes() if dtype != torch.bfloat16] 1012: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + 1013: get_all_complex_dtypes())) 1032: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + 1033: get_all_complex_dtypes())) 1050: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + 1051: get_all_complex_dtypes())) 1745: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 1779: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 1868: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 1926: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 1954: do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device) 1956: do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, None) 1957: do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device) 2538: for device in get_all_device_types(): 2645: for dtype in get_all_dtypes(): 2678: dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False) + 2679: get_all_complex_dtypes())) 2716: dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False)) 2827: for dt in get_all_dtypes(): 2913: dtypes(*get_all_dtypes(include_bool=False, include_half=False)) 2914: dtypesIfCUDA(*get_all_dtypes(include_bool=False, include_half=True)) 3028: dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) 3033: dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) 3074: dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_complex=False)) 3075: dtypesIfCUDA(*((get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16]) 3077: else get_all_dtypes(include_bool=False, include_half=True, include_complex=False))) 3873: dtypes(*get_all_dtypes()) 3884: 
dtypes(*get_all_dtypes(include_bool=False)) 3916: for other in get_all_dtypes(): 3922: dtypes(*get_all_dtypes()) 3932: dtypes(*get_all_dtypes(include_bool=False)) 3955: dtypes(*get_all_dtypes(include_bool=False)) 3961: dtypes(*get_all_dtypes(include_bool=False)) 3965: dtypes(*get_all_dtypes()) ``` </p> </details> <details> <summary> `test/test_testing.py`</summary> <p> ```python 25:from torch.testing._internal.common_dtype import get_all_dtypes 31: dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False, ``` </p> </details> <details> <summary> `test/test_torch.py`</summary> <p> ```python 51: expectedAlertNondeterministic, get_all_device_types, skipXLA) 57: get_all_fp_dtypes, get_all_int_dtypes, get_all_math_dtypes, get_all_dtypes, get_all_complex_dtypes 296: for d in get_all_device_types(): 323: for device in get_all_device_types(): 324: for dt1 in get_all_dtypes(): 325: for dt2 in get_all_dtypes(): 343: all_dtypes = get_all_dtypes() 350: all_dtypes = get_all_dtypes() 781: for dtype in get_all_dtypes(): 986: for device in get_all_device_types(): 1017: for device in get_all_device_types(): 1018: for dtype in get_all_math_dtypes(device): 2792: for device in get_all_device_types(): 3186: dtypes(*get_all_dtypes()) 3195: for error_dtype in get_all_dtypes(): 3203: dtypes(*get_all_dtypes()) 3212: for error_dtype in get_all_dtypes(): 4539: dtypes(*get_all_fp_dtypes()) 4545: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 4577: dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False)) 4578: dtypesIfCPU(*(get_all_fp_dtypes(include_half=False, include_bfloat16=True))) 4579: dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False))) 4599: dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False))) 4600: dtypesIfCPU(*(get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False))) 4601: dtypesIfCUDA(*(get_all_dtypes(include_bfloat16=False, include_complex=False))) 4613: for p_dtype in 
get_all_fp_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): 4628: dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False))) 4629: dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False))) 4640: dtypes(*get_all_fp_dtypes()) 4723: dtypes(*get_all_fp_dtypes()) 4735: dtypes(*get_all_fp_dtypes(include_bfloat16=False)) 4736: dtypesIfCUDA(*get_all_fp_dtypes()) 4747: dtypes(*get_all_fp_dtypes()) 4761: dtypes(*get_all_fp_dtypes()) 4771: dtypes(*get_all_fp_dtypes()) 4792: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 5302: dtypes(*get_all_dtypes(include_bfloat16=False)) 5322: dtypes(*get_all_dtypes(include_half=False, include_bfloat16=False)) 5323: dtypesIfCPU(*get_all_dtypes(include_bfloat16=False)) 5324: dtypesIfCUDA(*get_all_dtypes(include_bfloat16=False)) 5591: for dt in get_all_dtypes(): 5611: for dt in get_all_dtypes(): 5678: for dt in get_all_dtypes(): 5696: dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) 5697: dtypes(*set(get_all_math_dtypes('cpu'))) 5746: dtypes(*get_all_dtypes()) 5780: dtypes(*get_all_dtypes()) 5885: dtypes(*get_all_dtypes()) 5902: dtypes(*get_all_dtypes()) 5945: dtypes(*get_all_dtypes()) 5979: dtypes(*get_all_dtypes(include_bool=False)) 6049: dtypes(*get_all_dtypes(include_bool=False)) 6092: dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + 6093: get_all_complex_dtypes())) 6094: dtypesIfCPU(*get_all_dtypes()) 6095: dtypesIfCUDA(*get_all_dtypes()) 6122: dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + 6123: get_all_complex_dtypes())) 6124: dtypesIfCPU(*get_all_dtypes()) 6125: dtypesIfCUDA(*get_all_dtypes()) 6163: dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + 6164: get_all_complex_dtypes())) 6165: dtypesIfCPU(*get_all_dtypes()) 6166: dtypesIfCUDA(*get_all_dtypes()) 6190: dtypes(*(get_all_complex_dtypes() + 6191: get_all_int_dtypes())) 6238: dtypes(*get_all_dtypes()) 6323: dtypes(*get_all_dtypes()) 6389: 
dtypes(*product(get_all_dtypes(), (torch.uint8, torch.bool))) 6699: dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) 6700: dtypes(*set(get_all_math_dtypes('cpu'))) 7452: dtypes(*get_all_dtypes(include_bool=False)) 7461: dtypes(*get_all_dtypes(include_bool=False)) 7477: dtypes(*get_all_dtypes(include_bool=False)) 7496: dtypes(*get_all_dtypes(include_bool=False)) 7538: dtypes(*get_all_dtypes(include_bool=False)) 8162: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() + 8163: get_all_complex_dtypes())) 8175: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() + 8176: get_all_complex_dtypes())) ``` </p> </details> <details> <summary> `test/test_type_promotion.py`</summary> <p> ```python 14: get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes 187: for dtype in get_all_dtypes(): 262: dtypes1 = get_all_math_dtypes('cuda') 263: dtypes2 = get_all_math_dtypes(device) 339: dtypes(*itertools.product(get_all_dtypes(), get_all_dtypes())) 468: for dt1 in get_all_math_dtypes(device): 469: for dt2 in get_all_math_dtypes(device): 519: for dt1 in get_all_math_dtypes(device): 520: for dt2 in get_all_math_dtypes(device): 528: for dt in get_all_math_dtypes(device): 561: for dtype in get_all_dtypes(): 766: dtypes=get_all_math_dtypes(device)) 771: dtypes=get_all_math_dtypes(device)) 782: dtypes=get_all_math_dtypes(device)) 879: dtypes = get_all_dtypes(include_bfloat16=False) 898: dtypes = get_all_dtypes(include_bfloat16=False, include_bool=False) 965: dtypesIfCUDA(*itertools.product(get_all_dtypes(include_bfloat16=False, include_complex=False), 966: get_all_dtypes(include_bfloat16=False, include_complex=False))) 967: dtypes(*itertools.product(get_all_dtypes(include_half=False, include_bfloat16=False, 969: get_all_dtypes(include_half=False, include_bfloat16=False, 976: return dtype in get_all_int_dtypes() + [torch.bool] 979: return dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False) ``` </p> </details> <details> <summary> 
`test/test_unary_ufuncs.py`</summary> <p> ```python 24: floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, get_all_dtypes, get_all_math_dtypes, 25: get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes 517: dtypes(*(get_all_int_dtypes() + [torch.bool] + 518: get_all_fp_dtypes(include_bfloat16=False))) 596: dtypes(*get_all_fp_dtypes(include_half=True, include_bfloat16=False)) 611: invalid_input_dtypes = get_all_int_dtypes() + \ 612: get_all_complex_dtypes() + \ 619: for dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False): 1048: dtypes(*get_all_math_dtypes('cpu')) 1182: dtypesIfCUDA(*get_all_fp_dtypes()) 1190: dtypesIfCUDA(*get_all_fp_dtypes()) 1205: dtypesIfCUDA(*get_all_fp_dtypes()) 1215: dtypesIfCUDA(*get_all_fp_dtypes()) 1307: dtypes(*(get_all_dtypes(include_bool=False))) 1349: dtypes(*(get_all_fp_dtypes(include_half=False) + 1350: get_all_complex_dtypes())) 1351: dtypesIfCUDA(*(get_all_fp_dtypes(include_half=True) + 1352: get_all_complex_dtypes())) ``` </p> </details> <details> <summary> `test/test_view_ops.py`</summary> <p> ```python 19: get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes 124: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 131: dtypes(*get_all_dtypes(include_bfloat16=False)) 213: for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]: 220: dtypes(*get_all_dtypes()) 224: for view_dtype in get_all_dtypes(): 305: dtypes(*get_all_complex_dtypes(include_complex32=True)) 343: dtypes(*get_all_dtypes()) 354: dtypes(*get_all_dtypes()) 364: dtypes(*get_all_dtypes()) 374: dtypes(*get_all_dtypes()) 384: dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) 395: dtypes(*get_all_complex_dtypes()) 426: dtypes(*get_all_complex_dtypes()) 451: dtypes(*product(get_all_complex_dtypes(), get_all_dtypes())) 1263: dtypes(*(torch.testing.get_all_dtypes())) 1279: dtypes(*(torch.testing.get_all_dtypes())) 1405: dtypes(*(get_all_int_dtypes() + 
get_all_fp_dtypes(include_bfloat16=False) + 1406: get_all_complex_dtypes())) 1471: dtypes(*get_all_dtypes(include_bfloat16=False)) 1574: dtypes(*get_all_dtypes()) 1601: dtypes(*get_all_dtypes(include_bfloat16=False)) 1632: dtypes(*get_all_dtypes(include_bfloat16=False)) 1711: for dt in get_all_dtypes(): 1717: for dt in get_all_dtypes(): 1724: for dt in get_all_dtypes(): ``` </p> </details> I'm looking forward to your viewpoints. Thanks :) cc: mruberry kshitij12345 anjali411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/71561 Reviewed By: samdow Differential Revision: D34856571 Pulled By: mruberry fbshipit-source-id: 0dca038bcad5cf69906245c496d2e61ac3876335 (cherry picked from commit b058f67b4313143efa714ab105f36e74083131b9)
662 lines
33 KiB
Python
662 lines
33 KiB
Python
# Owner(s): ["module: mta"]
|
|
|
|
import itertools
|
|
from numbers import Number
|
|
import random
|
|
import re
|
|
import torch
|
|
import unittest
|
|
|
|
from torch.testing import make_tensor
|
|
from torch.testing._comparison import default_tolerances
|
|
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, dtypes, onlyCUDA, skipCUDAIfRocm, skipMeta, ops)
|
|
from torch.testing._internal.common_methods_invocations import (
|
|
foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db,
|
|
foreach_reduce_op_db)
|
|
from torch.testing._internal.common_dtype import (
|
|
all_types_and_complex_and, all_types_and, integral_types, complex_types,
|
|
floating_types_and, floating_types, integral_types_and,
|
|
)
|
|
|
|
# Includes some values such that N * N won't be a multiple of 4,
|
|
# which should ensure we test the vectorized and non-vectorized
|
|
# kernel code paths.
|
|
# Slow test runs sweep a wider set of tensor-list lengths than the default run.
N_values = [23, 30, 300] if TEST_WITH_SLOW else [20, 23]
|
|
# One scalar per Python numeric kind, in the order int, float, bool, complex.
# `random.random()` is in [0, 1), so `1.0 - random.random()` is in (0, 1] and
# the float value and both complex components are never zero.
Scalars = (
    random.randint(1, 10),                                  # int
    1.0 - random.random(),                                  # float
    True,                                                   # bool
    complex(1.0 - random.random(), 1.0 - random.random()),  # complex
)
|
|
|
|
def getScalarLists(N):
    """Return labelled scalar lists used as ScalarList arguments.

    The result is a tuple of ``(label, scalars)`` pairs covering int, float,
    complex, bool and two "mixed" lists. The homogeneous lists have ``N``
    entries; the mixed lists start with a fixed prefix and are padded with
    ``3.0`` (no padding is added once ``N`` is not larger than the prefix).
    """
    # The random draws happen in this order: ints, floats, then complex values.
    int_scalars = [random.randint(0, 9) + 1 for _ in range(N)]
    float_scalars = [1.0 - random.random() for _ in range(N)]
    complex_scalars = [complex(1.0 - random.random(), 1.0 - random.random()) for _ in range(N)]
    bool_scalars = [True] * N
    mixed_scalars = [1, 2.0, 3.0 + 4.5j] + [3.0] * (N - 3)
    mixed_with_bool = [True, 1, 2.0, 3.0 + 4.5j] + [3.0] * (N - 4)
    return (
        ("int", int_scalars),
        ("float", float_scalars),
        ("complex", complex_scalars),
        ("bool", bool_scalars),
        ("mixed", mixed_scalars),
        ("mixed", mixed_with_bool),
    )
|
|
|
|
# Fragment of an error message about the subtraction (`-`) operator.
# NOTE(review): presumably matched against the error raised when subtracting
# bool tensors; the use site is outside this view -- confirm.
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
|
|
|
|
class RegularFuncWrapper:
    """Apply a per-element function across parallel lists of arguments.

    This is the reference counterpart of a foreach function: ``func`` is
    called once per position with the elements found at that position in
    each list of ``inputs``.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, values=None, **kwargs):
        """Call ``func`` position by position and return the list of results.

        Args:
            inputs: sequence of argument lists. With ``values`` it must hold
                exactly three lists. Without ``values`` it may instead be
                ``[list, number]``, in which case the number is broadcast to
                every position.
            values: optional number (broadcast) or per-position sequence,
                forwarded to ``func`` as the ``value`` keyword.
            **kwargs: forwarded unchanged to every ``func`` call.

        The caller's ``inputs`` container is never modified.
        """
        if values is not None:
            assert len(inputs) == 3
            if isinstance(values, Number):
                values = [values for _ in range(len(inputs[0]))]
            return [self.func(*i, value=values[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
        if len(inputs) == 2 and isinstance(inputs[1], Number):
            # Binary op with tensorlist and scalar.
            # Broadcast into a *new* list instead of assigning to `inputs[1]`:
            # callers reuse the same `inputs` for later foreach calls, and
            # overwriting the scalar would silently turn those into
            # tensorlist-and-scalarlist calls.
            inputs = [inputs[0], [inputs[1] for _ in range(len(inputs[0]))]]
        return [self.func(*i, **kwargs) for i in zip(*inputs)]
|
|
|
|
|
|
class ForeachFuncWrapper:
    """Call a foreach function and, when possible, check its kernel-launch count.

    Attributes are set in ``__init__``:
      * ``func`` -- the wrapped callable (may be ``None``).
      * ``n_expected_cudaLaunchKernels`` -- the count that profiled
        ``cudaLaunchKernel`` events are compared against.
      * ``_is_inplace`` -- True when ``func.__name__`` ends with ``_``.
    """

    def __init__(self, func, n_expected_cudaLaunchKernels):
        self.func = func
        self.n_expected_cudaLaunchKernels = n_expected_cudaLaunchKernels
        # Some foreach functions don't have in-place implementations, in which
        # case `func` is None and the wrapper is treated as out-of-place.
        self._is_inplace = func is not None and func.__name__.endswith('_')

    def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
        """Run ``func(*inputs, **kwargs)`` and return its result.

        When ``is_cuda`` is true, kineto is available and the profiler lists
        CUDA among its supported activities, the call runs under the profiler.
        Each ``cudaLaunchKernel`` event must then have a count equal to the
        expected number on the fast path, or greater than it otherwise.

        For in-place functions ``inputs[0]`` is returned instead of the
        function's own return value.
        """
        can_profile = (
            is_cuda
            and torch.autograd.kineto_available()
            and torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities()
        )
        if not can_profile:
            result = self.func(*inputs, **kwargs)
        else:
            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as prof:
                result = self.func(*inputs, **kwargs)
            for event in prof.key_averages():
                if event.key != 'cudaLaunchKernel':
                    continue
                if is_fastpath:
                    assert event.count == self.n_expected_cudaLaunchKernels
                else:
                    assert event.count > self.n_expected_cudaLaunchKernels
        # note(mkozuki): inplace foreach functions are void functions.
        if self._is_inplace:
            return inputs[0]
        return result
|
|
|
|
class TestForeach(TestCase):
|
|
|
|
@property
def is_cuda(self):
    """True when this test instance's ``device_type`` is ``'cuda'``."""
    on_cuda = self.device_type == 'cuda'
    return on_cuda
|
|
|
|
# note(mkozuki): It might be the case that the expected number of `cudaLaunchKernel`s
|
|
# is greater than 1 once foreach functions internally separate their input `TensorList`s by
|
|
# devices & dtypes into vectors of tensors.
|
|
def _get_funcs(self, op, n_expected_cudaLaunchKernels):
    """Wrap the four callables of ``op`` for comparison testing.

    Returns a 4-tuple ``(foreach, reference, foreach_inplace, reference_inplace)``
    built from ``op.method_variant``, ``op.ref``, ``op.inplace_variant`` and
    ``op.ref_inplace``. Both foreach wrappers receive the same expected
    kernel-launch count.
    """
    foreach_fn = ForeachFuncWrapper(op.method_variant, n_expected_cudaLaunchKernels)
    reference_fn = RegularFuncWrapper(op.ref)
    foreach_inplace_fn = ForeachFuncWrapper(op.inplace_variant, n_expected_cudaLaunchKernels)
    reference_inplace_fn = RegularFuncWrapper(op.ref_inplace)
    return foreach_fn, reference_fn, foreach_inplace_fn, reference_inplace_fn
|
|
|
|
def _binary_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, alpha=None):
    """Check a foreach binary op against its per-tensor reference.

    ``op`` is called as ``op(inputs, self.is_cuda, is_fastpath)`` and ``ref``
    as ``ref(ref_inputs)``. If ``op`` raises ``RuntimeError``, ``ref`` must
    raise the same exception type with the same message; otherwise both
    results must compare equal. When ``alpha`` is given, the same comparison
    is repeated with ``alpha=alpha`` forwarded to both callables.

    Args:
        dtype: dtype under test; only used to relax tolerances for
            float16/bfloat16 when ``TEST_WITH_ROCM`` is set.
        op: foreach wrapper to exercise.
        ref: reference wrapper to compare against.
        inputs: ``[first_tensor_list, second_operand]``.
        is_fastpath: forwarded to ``op``.
        is_inplace: whether ``op``/``ref`` mutate ``inputs[0]``.
        alpha: optional ``alpha`` keyword for the second round.
    """
    # An in-place op mutates `inputs[0]`, so the reference must work on
    # independent copies taken *before* the op runs.
    ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
    try:
        actual = op(inputs, self.is_cuda, is_fastpath)
    except RuntimeError as e:
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            ref(ref_inputs)
    else:
        expected = ref(ref_inputs)
        self.assertEqual(actual, expected)
    if alpha is not None:
        kwargs = {'alpha': alpha}
        # Take fresh copies for the in-place case here as well. Reusing
        # `inputs` directly would make the in-place reference mutate and
        # return the very tensors the foreach op just returned, so the
        # comparison below would compare each tensor with itself.
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
        try:
            actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e))):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                self.assertEqual(expected, actual, atol=1.e-3, rtol=default_tolerances(dtype)[0])
            else:
                self.assertEqual(expected, actual)
|
|
|
|
def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, disable_fastpath):
    """Run the tensorlist-vs-tensorlist checks for one foreach binary op.

    Three rounds are run, each for the out-of-place and the in-place variant:
    plain inputs, inputs with an ``alpha`` keyword (only when
    ``opinfo.supports_alpha_param``), and inputs whose sizes do not match.
    Inputs are requested as non-contiguous whenever ``is_fastpath`` is false.
    """
    noncontiguous = not is_fastpath
    n_expected_cudaLaunchKernels = 1 if not disable_fastpath else N
    op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
    # (foreach callable, reference callable, is_inplace)
    variants = ((op, ref, False), (inplace_op, inplace_ref, True))

    first_list = opinfo.sample_inputs(device, dtype, N, noncontiguous=noncontiguous)
    second_list = opinfo.sample_inputs(device, dtype, N, noncontiguous=noncontiguous)
    inputs = [first_list, second_list]
    for fn, fn_ref, inplace in variants:
        self._binary_test(dtype, fn, fn_ref, inputs, is_fastpath, is_inplace=inplace)

    if opinfo.supports_alpha_param:
        # Pick an alpha of the same numeric kind as the dtype under test.
        if dtype in integral_types():
            alpha = 3
        elif dtype.is_complex:
            alpha = complex(3, 3)
        else:
            alpha = 3.14
        for fn, fn_ref, inplace in variants:
            self._binary_test(dtype, fn, fn_ref, inputs, is_fastpath, is_inplace=inplace, alpha=alpha)

    # Tests of implicit broadcasting.
    # When sizes of tensors don't match, foreach functions are supposed to choose the slow path
    # even if this method's argument `is_fastpath` is True.
    # `cudaLaunchKernel` will be equal to `N`. For the assert in `ForeachFuncWrapper` to pass,
    # we pass `is_fastpath and disable_fastpath` as `_binary_test`'s `is_fastpath` argument,
    # as n_expected_cudaLaunchKernels is N if disable_fastpath.
    mismatched_inputs = [
        opinfo.sample_inputs(device, dtype, N, noncontiguous=noncontiguous),
        [
            make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=noncontiguous)
            for i in range(N)
        ],
    ]
    expect_fastpath = is_fastpath and disable_fastpath
    for fn, fn_ref, inplace in variants:
        self._binary_test(dtype, fn, fn_ref, mismatched_inputs, expect_fastpath, is_inplace=inplace)
|
|
|
|
# note(mkozuki): Why ROCm?
|
|
# ROCm is supposed to compile slow path as in
|
|
# https://github.com/pytorch/pytorch/blob/7e032f18cf1405804c4f787b05ea2de5e08a091e/aten/src/ATen/native/ForeachUtils.h#L148-L164, # noqa: E501
|
|
# Therefore `[torch.add(*args, alpha=alpha) for args in zip(tensors1, tensors2)]` and
|
|
# `torch._foreach_add(tensors1, tensors2, alpha=alpha)`
|
|
# are expected to return the same outputs, however, the outputs look unstable for torch.bfloat16 and torch.half.
|
|
# log: https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm4.2-py3.6-test1/2741/console
|
|
@skipCUDAIfRocm
@skipMeta
@ops(foreach_binary_op_db)
def test_binary_op_tensorlists_fastpath(self, device, dtype, op):
    """Run the tensor-list binary-op checks with ``is_fastpath=True`` for each ``N`` in ``N_values``.

    ``disable_fastpath`` is passed as true when ``op.ref`` is ``torch.div`` on an
    integral/bool dtype, or when ``op.ref`` is ``torch.add`` on ``torch.bool``.
    """
    # Neither condition depends on N, so they are evaluated once up front.
    int_or_bool_div = op.ref == torch.div and dtype in integral_types_and(torch.bool)
    bool_add = op.ref == torch.add and dtype == torch.bool
    disable_fastpath = True if bool_add else int_or_bool_div
    for num_tensors in N_values:
        self._test_binary_op_tensorlists(device, dtype, op, num_tensors, True, disable_fastpath)
|
|
|
|
@ops(foreach_binary_op_db)
def test_binary_op_tensorlists_slowpath(self, device, dtype, op):
    """Run the tensor-list binary-op checks with ``is_fastpath=False`` for each ``N`` in ``N_values``."""
    is_fastpath = False
    disable_fastpath = False
    for num_tensors in N_values:
        self._test_binary_op_tensorlists(device, dtype, op, num_tensors, is_fastpath, disable_fastpath)
|
|
|
|
def _test_binary_op_scalar(self, device, dtype, opinfo, N, scalar, is_fastpath, disable_fastpath):
    """Check a foreach binary op applied to a tensor list and a single scalar.

    Both the out-of-place and the in-place variants returned by
    ``self._get_funcs`` are run through ``self._binary_test`` on the same inputs.
    The expected kernel-launch count handed to ``_get_funcs`` is ``N`` when
    ``disable_fastpath`` is set and ``1`` otherwise.
    """
    if disable_fastpath:
        expected_launches = N
    else:
        expected_launches = 1
    foreach_fn, native_fn, foreach_fn_, native_fn_ = self._get_funcs(opinfo, expected_launches)

    # Non-contiguous samples are requested whenever the slow path is targeted.
    tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath)
    inputs = [tensors, scalar]

    variants = (
        (foreach_fn, native_fn, False),
        (foreach_fn_, native_fn_, True),
    )
    for fn, reference, inplace in variants:
        self._binary_test(dtype, fn, reference, inputs, is_fastpath, is_inplace=inplace)
|
|
|
|
@skipCUDAIfRocm
@skipMeta
@ops(foreach_binary_op_db)
def test_binary_op_scalar_fastpath(self, device, dtype, op):
    """Run the tensor-list/scalar checks with ``is_fastpath=True`` over ``N_values`` x ``Scalars``.

    ``disable_fastpath`` is derived per scalar from the scalar's Python type,
    the tensor ``dtype`` and ``op.ref``; see the inline comments.
    """
    int_or_bool_dtype = dtype in integral_types_and(torch.bool)
    bool_dtype = dtype == torch.bool
    for num_tensors, scalar in itertools.product(N_values, Scalars):
        # Division on integral/bool tensors always disables the fast path.
        disable_fastpath = op.ref == torch.div and int_or_bool_dtype
        # `bool` is a subclass of `int`, so a bool scalar also enters this branch.
        if isinstance(scalar, int):
            disable_fastpath = disable_fastpath or bool_dtype
        if isinstance(scalar, float):
            disable_fastpath = disable_fastpath or int_or_bool_dtype
        if isinstance(scalar, bool):
            if op.ref in (torch.add, torch.mul):
                # add/mul with a bool scalar resets the flag regardless of the above.
                disable_fastpath = False
            else:
                disable_fastpath = disable_fastpath or bool_dtype
        if isinstance(scalar, complex):
            disable_fastpath = disable_fastpath or dtype not in complex_types()
        self._test_binary_op_scalar(device, dtype, op, num_tensors, scalar, True, disable_fastpath)
|
|
|
|
@ops(foreach_binary_op_db)
def test_binary_op_scalar_slowpath(self, device, dtype, op):
    """Run the tensor-list/scalar checks with ``is_fastpath=False`` over ``N_values`` x ``Scalars``."""
    combos = itertools.product(N_values, Scalars)
    for num_tensors, scalar in combos:
        self._test_binary_op_scalar(device, dtype, op, num_tensors, scalar, False, False)
|
|
|
|
def _test_binary_op_scalarlist(self, device, dtype, opinfo, N, scalarlist, is_fastpath, disable_fastpath):
    """Check a foreach binary op applied to a tensor list and a list of scalars.

    Both the out-of-place and the in-place variants returned by
    ``self._get_funcs`` are run through ``self._binary_test`` on the same inputs.
    The expected kernel-launch count handed to ``_get_funcs`` is ``N`` when
    ``disable_fastpath`` is set and ``1`` otherwise.
    """
    if disable_fastpath:
        expected_launches = N
    else:
        expected_launches = 1
    foreach_fn, native_fn, foreach_fn_, native_fn_ = self._get_funcs(opinfo, expected_launches)

    # Non-contiguous samples are requested whenever the slow path is targeted.
    tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath)
    inputs = [tensors, scalarlist]

    variants = (
        (foreach_fn, native_fn, False),
        (foreach_fn_, native_fn_, True),
    )
    for fn, reference, inplace in variants:
        self._binary_test(dtype, fn, reference, inputs, is_fastpath, is_inplace=inplace)
|
|
|
|
# note(mkozuki): Why two functions depending on with/without bool?
|
|
# `foreach_sub` & `foreach_sub_` do `sub_check(tensors[i], scalars[i])` from i=1...N.
|
|
# So, if scalarlist has one or more bool values, `foreach_sub` and `foreach_sub_`
|
|
# raise bool subtraction error before doing any math.
|
|
# While regular `sub` and `sub_` do some math until they encounter bool.
|
|
# So, foreach sub's throw bool sub error first. However, regular sub's throw different
|
|
# errors depending on the order of scalarlist. To keep actual unit test impl simple,
|
|
# separating mixed scalarlist tests. By setting the first element of scalarlist to bool,
|
|
# they are expected to throw bool sub error even in inplace test.
|
|
@skipCUDAIfRocm
@skipMeta
@ops(foreach_binary_op_db)
def test_binary_op_scalarlist_fastpath(self, device, dtype, op):
    """Run the tensor-list/scalar-list checks with ``is_fastpath=True``.

    For every ``N`` in ``N_values`` and every ``(type_str, scalarlist)`` pair
    from ``getScalarLists(N)``, ``disable_fastpath`` is the integral/bool
    division condition OR-ed with a condition selected by ``type_str``.
    """
    int_or_bool_dtype = dtype in integral_types_and(torch.bool)
    bool_int_div = op.ref == torch.div and int_or_bool_dtype
    for num_tensors in N_values:
        for type_str, scalarlist in getScalarLists(num_tensors):
            if type_str == "int":
                type_specific = dtype == torch.bool
            elif type_str == "float":
                type_specific = int_or_bool_dtype
            elif type_str in ("complex", "mixed"):
                type_specific = dtype not in complex_types()
            else:
                # Any other kind of scalar list adds no extra condition.
                type_specific = False
            disable_fastpath = bool_int_div or type_specific
            self._test_binary_op_scalarlist(device, dtype, op, num_tensors, scalarlist, True, disable_fastpath)
|
|
|
|
@ops(foreach_binary_op_db)
def test_binary_op_scalarlist_slowpath(self, device, dtype, op):
    """Run the tensor-list/scalar-list checks with ``is_fastpath=False`` for every list from ``getScalarLists(N)``."""
    for num_tensors in N_values:
        # Only the scalar list itself is needed here; its type tag is ignored.
        for _type_str, scalarlist in getScalarLists(num_tensors):
            self._test_binary_op_scalarlist(device, dtype, op, num_tensors, scalarlist, False, False)
|
|
|
|
def _pointwise_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, values=None):
    """Compare a foreach pointwise function against its reference.

    ``op`` is called as ``op(inputs, self.is_cuda, is_fastpath)`` and ``ref`` as
    ``ref(ref_inputs)``. If ``op`` raises ``RuntimeError``, ``ref`` must raise the
    same exception type with the same message; otherwise the results must be equal.
    When ``values`` is given, the comparison is repeated with ``values`` appended
    to the op inputs and passed to ``ref`` as the ``values`` keyword.
    """
    if is_inplace:
        # The reference gets its own copy of the first tensor list.
        ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]]
    else:
        ref_inputs = inputs

    def compare(op_inputs, **ref_kwargs):
        try:
            actual = op(op_inputs, self.is_cuda, is_fastpath)
        except RuntimeError as err:
            with self.assertRaisesRegex(type(err), re.escape(str(err))):
                ref(ref_inputs, **ref_kwargs)
        else:
            expected = ref(ref_inputs, **ref_kwargs)
            self.assertEqual(expected, actual)

    compare(inputs)
    if values is not None:
        compare(inputs + [values], values=values)
|
|
|
|
def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fastpath, *, values=None):
    """Check a foreach pointwise op (three tensor lists) out-of-place and in-place.

    The first round uses three sample lists from ``opinfo.sample_inputs``. The
    second round exercises implicit broadcasting, pairing a same-size sample
    list with lists of ``(N - i, 1)`` and ``(1, N - i)`` tensors; there
    ``is_fastpath and disable_fastpath`` is forwarded as the fast-path flag.
    """
    expected_launches = N if disable_fastpath else 1
    foreach_fn, native_fn, foreach_fn_, native_fn_ = self._get_funcs(opinfo, expected_launches)
    noncontiguous = not is_fastpath

    inputs = [
        opinfo.sample_inputs(device, dtype, N, noncontiguous=noncontiguous)
        for _ in range(3)
    ]
    self._pointwise_test(dtype, foreach_fn, native_fn, inputs, is_fastpath, is_inplace=False, values=values)
    self._pointwise_test(dtype, foreach_fn_, native_fn_, inputs, is_fastpath, is_inplace=True, values=values)

    # Tests of implicit broadcasting
    same_size_samples = opinfo.sample_inputs(device, dtype, N, noncontiguous=noncontiguous, same_size=True)
    column_tensors = [
        make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=noncontiguous) for i in range(N)
    ]
    row_tensors = [
        make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=noncontiguous) for i in range(N)
    ]
    inputs = [same_size_samples, column_tensors, row_tensors]
    broadcast_fastpath = is_fastpath and disable_fastpath
    self._pointwise_test(
        dtype, foreach_fn, native_fn, inputs, broadcast_fastpath, is_inplace=False, values=values)
    self._pointwise_test(
        dtype, foreach_fn_, native_fn_, inputs, broadcast_fastpath, is_inplace=True, values=values)
|
|
|
|
@skipMeta
@ops(foreach_pointwise_op_db)
def test_pointwise_op_fastpath(self, device, dtype, op):
    """Run the pointwise checks with ``is_fastpath=True``.

    For each ``N`` in ``N_values`` the op is checked without ``values``, then
    with every scalar in ``Scalars``, then with every scalar list from
    ``getScalarLists(N)``. ``disable_fastpath`` is set for integral and bool dtypes.
    """
    disable_fastpath = dtype in integral_types_and(torch.bool)

    def check(num_tensors, **kwargs):
        self._test_pointwise_op(device, dtype, op, num_tensors, True, disable_fastpath, **kwargs)

    for num_tensors in N_values:
        check(num_tensors)
        for scalar in Scalars:
            check(num_tensors, values=scalar)
        for _type_str, scalarlist in getScalarLists(num_tensors):
            check(num_tensors, values=scalarlist)
|
|
|
|
@ops(foreach_pointwise_op_db)
def test_pointwise_op_slowpath(self, device, dtype, op):
    """Run the pointwise checks with ``is_fastpath=False`` and ``disable_fastpath=False``.

    For each ``N`` in ``N_values`` the op is checked without ``values``, then
    with every scalar in ``Scalars``, then with every scalar list from
    ``getScalarLists(N)``.
    """
    def check(num_tensors, **kwargs):
        self._test_pointwise_op(device, dtype, op, num_tensors, False, False, **kwargs)

    for num_tensors in N_values:
        check(num_tensors)
        for scalar in Scalars:
            check(num_tensors, values=scalar)
        for _type_str, scalarlist in getScalarLists(num_tensors):
            check(num_tensors, values=scalarlist)
|
|
|
|
# note(mkozuki): fastpath test uses dtypes which fastpath implementation supports.
|
|
# To confirm the dtypes of `OpInfo` cover the dtypes that the function supports,
|
|
# this test does not use `try-except` for fastpath.
|
|
def _regular_unary_test(self, dtype, op, ref, inputs, is_fastpath):
    """Compare an out-of-place foreach unary function against its reference.

    On the slow path, a ``RuntimeError`` from ``op`` must be reproduced by
    ``ref`` (same type and message). On the fast path no exception handling is
    applied, so any error from either function propagates to the caller.
    """
    if not is_fastpath:
        try:
            actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as err:
            with self.assertRaisesRegex(type(err), re.escape(str(err))):
                ref(inputs)
        else:
            expected = ref(inputs)
            self.assertEqual(actual, expected)
        return

    # Fast path: the reference is evaluated first, then the foreach function.
    expected = ref(inputs)
    actual = op(inputs, self.is_cuda, is_fastpath)
    self.assertEqual(expected, actual)
|
|
|
|
# note(mkozuki): why `try-except` for both fastpath?
|
|
# - inputs for fastpath can be integer tensors.
|
|
# - this is because opinfo dtypes are configured for out-of-place implementation
|
|
# - for integer inputs, trigonometric functions and exponential function returns float outputs,
|
|
# which causes "result type Float can't be cast to the desired type" error.
|
|
# Thus, `try-except` is used even if `is_fastpath` is `True`.
|
|
def _inplace_unary_test(self, dtype, inplace, inplace_ref, inputs, is_fastpath):
    """Compare an in-place foreach unary function against its in-place reference.

    Args:
        dtype: dtype under test (not read by this helper).
        inplace: callable invoked as ``inplace(inputs, self.is_cuda, is_fastpath)``;
            expected to modify the tensors of ``inputs`` in place.
        inplace_ref: reference callable invoked as ``inplace_ref(copied_inputs)``.
        inputs: iterable of tensor lists.
        is_fastpath: forwarded to ``inplace``.

    If ``inplace`` raises ``RuntimeError``, ``inplace_ref`` must raise the same
    exception type with the same message. Otherwise the reference is applied to
    a detached clone of the inputs and the two results are compared.
    """
    # The reference operates on its own copy so both sides start from equal data.
    copied_inputs = [[t.clone().detach() for t in tensors] for tensors in inputs]
    try:
        inplace(inputs, self.is_cuda, is_fastpath)
    except RuntimeError as e:
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            inplace_ref(copied_inputs)
    else:
        # Plain call: the original had a stray trailing comma here that wrapped
        # the result in a throwaway 1-tuple.
        inplace_ref(copied_inputs)
        self.assertEqual(copied_inputs, inputs)
|
|
|
|
def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
    """Check a foreach unary op out-of-place and in-place on ``N`` sample tensors.

    The sample list is wrapped in a 1-tuple because the unary helpers take a
    tuple of tensor lists.
    """
    foreach_fn, native_fn, foreach_fn_, native_fn_ = self._get_funcs(opinfo, 1)
    tensors = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath)
    inputs = (tensors,)

    # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
    complex_abs = opinfo.name == "_foreach_abs" and dtype in complex_types()
    use_fastpath = False if complex_abs else is_fastpath

    self._regular_unary_test(dtype, foreach_fn, native_fn, inputs, use_fastpath)
    self._inplace_unary_test(dtype, foreach_fn_, native_fn_, inputs, use_fastpath)
|
|
|
|
@skipMeta
@ops(foreach_unary_op_db)
def test_unary_fastpath(self, device, dtype, op):
    """Run the unary checks with ``is_fastpath=True`` for each ``N`` in ``N_values``."""
    for num_tensors in N_values:
        self._test_unary(device, dtype, op, num_tensors, is_fastpath=True)
|
|
|
|
@ops(foreach_unary_op_db, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_unary_slowpath(self, device, dtype, op):
    """Run the unary checks with ``is_fastpath=False`` for each ``N`` in ``N_values``."""
    for num_tensors in N_values:
        self._test_unary(device, dtype, op, num_tensors, is_fastpath=False)
|
|
|
|
def _minmax_test(self, opinfo, inputs, is_fastpath, n_expected_cudaLaunchKernels):
    """Assert that the out-of-place foreach min/max function matches its reference on ``inputs``."""
    # Only the out-of-place pair is used; the in-place pair is discarded.
    foreach_fn, native_fn, _, _ = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
    expected = native_fn(inputs)
    actual = foreach_fn(inputs, self.is_cuda, is_fastpath)
    self.assertEqual(expected, actual)
|
|
|
|
# note(mkozuki): in-place of foreach_minimum and foreach_maximum aren't implemented.
|
|
@ops(foreach_minmax_op_db)
def test_minmax_fastpath(self, device, dtype, op):
    """Check min/max on two sample tensor lists with ``is_fastpath=True``.

    The expected kernel-launch count is ``N`` for ``torch.bool`` and ``1`` for
    every other dtype.
    """
    for num_tensors in N_values:
        first = op.sample_inputs(device, dtype, num_tensors)
        second = op.sample_inputs(device, dtype, num_tensors)
        if dtype == torch.bool:
            expected_launches = num_tensors
        else:
            expected_launches = 1
        self._minmax_test(op, (first, second), True, expected_launches)
|
|
|
|
@ops(foreach_minmax_op_db,
     dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool))
def test_minmax_slowpath(self, device, dtype, op):
    """Check min/max on two non-contiguous sample tensor lists with ``is_fastpath=False``."""
    for num_tensors in N_values:
        first = op.sample_inputs(device, dtype, num_tensors, noncontiguous=True)
        second = op.sample_inputs(device, dtype, num_tensors, noncontiguous=True)
        self._minmax_test(op, (first, second), False, 1)
|
|
|
|
# note(mkozuki): ForeachFuncInfo's of both `_foreach_maximum` and `_foreach_minimum` include integer types.
|
|
# so, manually limit dtypes to fp types for inf&nan tests.
|
|
@ops(foreach_minmax_op_db, dtypes=floating_types_and(torch.half, torch.bfloat16))
def test_minmax_float_inf_nan(self, device, dtype, op):
    """Check min/max on one-element tensors holding infinities and NaNs.

    The pairs compared position by position are (inf, -inf), (-inf, inf),
    (nan, inf) and (nan, nan).
    """
    inf = float('inf')
    nan = float('nan')
    lhs_values = (inf, -inf, nan, nan)
    rhs_values = (-inf, inf, inf, nan)

    def as_tensors(values):
        return [torch.tensor([v], device=device, dtype=dtype) for v in values]

    inputs = (as_tensors(lhs_values), as_tensors(rhs_values))
    self._minmax_test(op, inputs, True, 1)
|
|
|
|
def _reduce_test(self, opinfo, inputs, ord, is_fastpath, n_expected_cudaLaunchKernels):
    """Assert that the foreach reduction matches its reference for the given ``ord``."""
    # Only the out-of-place pair is used; the in-place pair is discarded.
    foreach_fn, native_fn, _, _ = self._get_funcs(opinfo, n_expected_cudaLaunchKernels)
    expected = native_fn(inputs, ord=ord)
    actual = foreach_fn(inputs, self.is_cuda, is_fastpath, ord=ord)
    self.assertEqual(expected, actual)
|
|
|
|
@ops(foreach_reduce_op_db)
def test_reduce_fastpath(self, device, dtype, op):
    """Check the reduction on contiguous samples with ``is_fastpath=True``.

    Every ``N`` in ``N_values`` is combined with the orders 0, 1, 2, -1 and -2.
    The expected kernel-launch count is 3 for orders 1 and 2 on
    float/half/bfloat16 dtypes, and ``N`` otherwise.
    """
    orders = (0, 1, 2, -1, -2)
    for num_tensors, norm_order in itertools.product(N_values, orders):
        fused = norm_order in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16)
        expected_launches = 3 if fused else num_tensors
        # The reduce helper takes a tuple of tensor lists.
        inputs = (op.sample_inputs(device, dtype, num_tensors, noncontiguous=False),)
        self._reduce_test(op, inputs, norm_order, True, expected_launches)
|
|
|
|
@ops(foreach_reduce_op_db)
def test_reduce_slowpath(self, device, dtype, op):
    """Check the reduction on non-contiguous samples with ``is_fastpath=False``.

    Every ``N`` in ``N_values`` is combined with the orders 0, 1, 2, -1 and -2.
    """
    orders = (0, 1, 2, -1, -2)
    for num_tensors, norm_order in itertools.product(N_values, orders):
        # The reduce helper takes a tuple of tensor lists.
        inputs = (op.sample_inputs(device, dtype, num_tensors, noncontiguous=True),)
        self._reduce_test(op, inputs, norm_order, False, 1)
|
|
|
|
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
    """Check that adding a scalar to a list holding one empty tensor leaves it unchanged."""
    # NOTE(review): `device` and `dtype` are not used when building the tensor
    # below, so every parametrization runs on the default dtype/device.
    # TODO: enable empty list case
    cases = [[torch.randn([0])]]
    for tensors in cases:
        out_of_place = torch._foreach_add(tensors, 1)
        self.assertEqual(out_of_place, tensors)

        torch._foreach_add_(tensors, 1)
        self.assertEqual(out_of_place, tensors)
|
|
|
|
@ops(foreach_binary_op_db, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
    """Check the foreach op with scalar 1 on an expanded tensor.

    For ``torch.sub`` on ``torch.bool`` both the reference and the foreach
    function must raise the bool-subtraction error; otherwise their results
    must be equal.
    """
    foreach_op, ref = op.method_variant, op.ref
    base = torch.ones(1, 1, device=device, dtype=dtype)
    tensors = [base.expand(2, 1, 3)]

    def reference():
        return [ref(t, 1) for t in tensors]

    if ref == torch.sub and dtype == torch.bool:
        with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
            reference()
        with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
            foreach_op(tensors, 1)
    else:
        expected = reference()
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)
|
|
|
|
# note(mkozuki): this test case fails with Meta at least in my local environment.
|
|
# The message was
|
|
# `AssertionError: NotImplementedError("Could not run 'aten::_foreach_add.Scalar' with arguments from the 'Meta' backend.`
|
|
@skipMeta
@ops(foreach_binary_op_db, allowed_dtypes=[torch.float])
def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
    """Check that the foreach op accepts a float and a long tensor in one list without ``RuntimeError``."""
    foreach_op = op.method_variant
    float_tensor = torch.tensor([1.1], dtype=torch.float, device=device)
    long_tensor = torch.tensor([1], dtype=torch.long, device=device)
    tensors = [float_tensor, long_tensor]
    try:
        foreach_op(tensors, 1)
    except RuntimeError as err:
        # Keep the exception so it is reported as an assertion failure below.
        runtime_error = err
    else:
        runtime_error = None
    self.assertIsNone(runtime_error)
|
|
|
|
@ops(foreach_binary_op_db, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_binary_op_list_error_cases(self, device, dtype, op):
    """Check the errors raised by the foreach binary op for malformed tensor lists.

    Covered cases: two empty lists, one empty list, lists of different length,
    pairwise sizes that cannot broadcast, and (only when more than one CUDA
    device is present) tensors living on different devices.
    """
    foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace

    def expect_error_from_both(pattern, lhs, rhs):
        # The out-of-place variant is checked first, then the in-place one.
        for fn in (foreach_op, foreach_op_):
            with self.assertRaisesRegex(RuntimeError, pattern):
                fn(lhs, rhs)

    tensors1 = []
    tensors2 = []

    # Empty lists
    expect_error_from_both("There were no tensor arguments to this function", tensors1, tensors2)

    # One empty list
    tensors1.append(torch.tensor([1], device=device, dtype=dtype))
    expect_error_from_both(
        "Tensor list must have same number of elements as scalar list.", tensors1, tensors2)

    # Lists have different amount of tensors
    for _ in range(2):
        tensors2.append(torch.tensor([1], device=device))
    expect_error_from_both(
        "Tensor lists must have the same number of tensors, got 1 and 2", tensors1, tensors2)

    # Corresponding tensors with different sizes that aren't compatible with broadcast
    # If sizes are different then foreach chooses slow path, thus error messages are expected
    # to be the same as torch regular function.
    tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
    tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
    for fn, reference in ((foreach_op, ref), (foreach_op_, ref_)):
        try:
            fn(tensors1, tensors2)
        except RuntimeError as err:
            with self.assertRaisesRegex(type(err), re.escape(str(err))):
                [reference(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

    # different devices
    if self.device_type != "cuda" or torch.cuda.device_count() <= 1:
        return
    tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
    tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
    if dtype == torch.bool and foreach_op == torch._foreach_sub:
        expect_error_from_both(re.escape(_BOOL_SUB_ERR_MSG), [tensor1], [tensor2])
        return

    device_mismatch = "Expected all tensors to be on the same device"
    with self.assertRaisesRegex(RuntimeError, device_mismatch):
        foreach_op([tensor1], [tensor2])
    if dtype in integral_types_and(torch.bool) and foreach_op == torch._foreach_div:
        inplace_pattern = "result type"
    else:
        inplace_pattern = device_mismatch
    with self.assertRaisesRegex(RuntimeError, inplace_pattern):
        foreach_op_([tensor1], [tensor2])
|
|
|
|
@skipMeta
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
@ops(foreach_binary_op_db, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_binary_op_list_slow_path(self, device, dtype, op):
    """Check the binary op on single-tensor lists with ``is_fastpath=False``.

    Four layouts of the second operand are covered: expanded (0-stride),
    transposed, non-contiguous, and sliced with a step.
    """
    # note(mkozuki): why `n_expected_cudaLaunchKernels=0`?
    # In this test, foreach functions don't go through fast path,
    # but as there is only one tensor in each list of tensors,
    # `cudaLaunchKernel` is 1 so ForeachFuncWrapper internal assert fails.
    foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op, n_expected_cudaLaunchKernels=0)

    def check_both_variants(lhs, rhs):
        inputs = ([lhs], [rhs])
        self._binary_test(dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False)
        self._binary_test(dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True)

    # 0-strides
    tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
    tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
    check_both_variants(tensor1, tensor2)

    # different strides
    tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
    tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
    check_both_variants(tensor1, tensor2.t())

    # non contiguous
    tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
    tensor2 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
    for candidate in (tensor1, tensor2):
        self.assertFalse(candidate.is_contiguous())
    check_both_variants(tensor1, tensor2)

    # sliced tensor
    tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
    tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[:, :, :, ::7]
    check_both_variants(tensor1, tensor2)
|
|
|
|
# note: Below three tests (postfixed with `_tensors_on_different_devices`)
|
|
# checks whether foreach works with lists of tensors on different devices
|
|
# but tensors of the same index are on the same device, e.g., ['cuda', 'cpu'].
|
|
@onlyCUDA
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
    """Check a foreach unary op on a list whose tensors live on different devices.

    Two sample tensors are created on ``device`` and the second one is moved to
    the CPU. For both the out-of-place and the in-place variant, a
    ``RuntimeError`` from the foreach function must be reproduced by the
    reference (same type and message); otherwise the results are compared.
    """
    method, ref, inplace_method, ref_inplace = self._get_funcs(op, 1)
    # tensors: ['cuda', 'cpu']
    tensors = op.sample_inputs(device, dtype, 2)
    tensors[1] = tensors[1].to('cpu')
    try:
        actual = method((tensors,), False, False)
    except RuntimeError as e:
        # Escape the message so regex metacharacters in it are matched literally.
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            ref((tensors,))
    else:
        expected = ref((tensors,))
        self.assertEqual(expected, actual)

    try:
        inplace_method((tensors,), False, False)
    except RuntimeError as e:
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            ref_inplace((tensors,))
    else:
        # NOTE(review): `expected` is only bound when the out-of-place call above
        # succeeded; if that call raised but this one did not, this line would
        # fail with NameError. Confirm whether that combination can occur.
        self.assertEqual(expected, tensors)
|
|
|
|
@onlyCUDA
@ops(foreach_binary_op_db)
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
    """Check a foreach binary op on lists whose tensors live on different devices.

    Each operand list holds one tensor sampled on ``device`` and one sampled on
    the CPU, so tensors at the same index share a device. For both variants, a
    ``RuntimeError`` from the foreach function must be reproduced by the
    reference (same type and message); otherwise the results are compared.
    """
    # `tensors1`: ['cuda', 'cpu']
    # `tensors2`: ['cuda', 'cpu']
    _cuda_tensors = op.sample_inputs(device, dtype, 2, same_size=True)
    _cpu_tensors = op.sample_inputs('cpu', dtype, 2, same_size=True)
    tensors1, tensors2 = zip(_cuda_tensors, _cpu_tensors)

    foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
    native_op, native_op_ = op.ref, op.ref_inplace

    def apply_pairwise(fn):
        return [fn(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

    try:
        actual = foreach_op(tensors1, tensors2)
    except RuntimeError as err:
        with self.assertRaisesRegex(type(err), re.escape(str(err))):
            apply_pairwise(native_op)
    else:
        expected = apply_pairwise(native_op)
        self.assertEqual(expected, actual)

    try:
        foreach_op_(tensors1, tensors2)
    except RuntimeError as err:
        with self.assertRaisesRegex(type(err), re.escape(str(err))):
            apply_pairwise(native_op_)
    else:
        # The in-place result is compared with the out-of-place result from above.
        self.assertEqual(actual, tensors1)
|
|
|
|
@onlyCUDA
@ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
    """Check a foreach pointwise op on lists whose tensors live on different devices.

    Each of the three operand lists holds one tensor sampled on ``device`` and
    one sampled on the CPU, so tensors at the same index share a device.
    """
    # tensors1: ['cuda', 'cpu']
    # tensors2: ['cuda', 'cpu']
    # tensors3: ['cuda', 'cpu']
    _cuda_tensors = op.sample_inputs(device, dtype, 3, same_size=True)
    _cpu_tensors = op.sample_inputs('cpu', dtype, 3, same_size=True)
    tensors1, tensors2, tensors3 = zip(_cuda_tensors, _cpu_tensors)

    foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
    actual = foreach_op(tensors1, tensors2, tensors3)
    # The reference is applied once per device group.
    expected = [native_op(*group) for group in (_cuda_tensors, _cpu_tensors)]
    self.assertEqual(expected, actual)

    # note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
    foreach_op_(tensors1, tensors2, tensors3)
    self.assertEqual(expected, tensors1)
|
|
|
|
|
|
# Hand TestForeach and this module's namespace to the device-type test machinery.
instantiate_device_type_tests(TestForeach, globals())


# Run the tests only when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()
|