Jason Ansel
04b26ee1e8
Fix false positives from f-strings in set_linter (#143628)
...
This linter was producing a flood of false positives under Python 3.12 (where f-strings are tokenized into multiple tokens containing brace characters), for example:
```py
$ python3 tools/linter/adapters/set_linter.py torch/_inductor/runtime/triton_heuristics.py
torch/_inductor/runtime/triton_heuristics.py:192:25: Builtin `set` is deprecated
190 | args_str += ", ".join(call_args)
191 | for k, v in call_kwargs.items():
192 | args_str += f", {k}={v}"
^
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
torch/_inductor/runtime/triton_heuristics.py:192:27: Builtin `set` is deprecated
190 | args_str += ", ".join(call_args)
191 | for k, v in call_kwargs.items():
192 | args_str += f", {k}={v}"
^
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
torch/_inductor/runtime/triton_heuristics.py:192:29: Builtin `set` is deprecated
190 | args_str += ", ".join(call_args)
191 | for k, v in call_kwargs.items():
192 | args_str += f", {k}={v}"
^
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
torch/_inductor/runtime/triton_heuristics.py:192:31: Builtin `set` is deprecated
190 | args_str += ", ".join(call_args)
191 | for k, v in call_kwargs.items():
192 | args_str += f", {k}={v}"
^
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
torch/_inductor/runtime/triton_heuristics.py:195:17: Builtin `set` is deprecated
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
^
196 | f.write(f"{kernel_name} | {args_str}\n")
197 |
torch/_inductor/runtime/triton_heuristics.py:195:26: Builtin `set` is deprecated
193 |
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
^
196 | f.write(f"{kernel_name} | {args_str}\n")
197 |
torch/_inductor/runtime/triton_heuristics.py:196:19: Builtin `set` is deprecated
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
196 | f.write(f"{kernel_name} | {args_str}\n")
^
197 |
198 |
torch/_inductor/runtime/triton_heuristics.py:196:31: Builtin `set` is deprecated
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
196 | f.write(f"{kernel_name} | {args_str}\n")
^
197 |
198 |
torch/_inductor/runtime/triton_heuristics.py:196:35: Builtin `set` is deprecated
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
196 | f.write(f"{kernel_name} | {args_str}\n")
^
197 |
198 |
torch/_inductor/runtime/triton_heuristics.py:196:44: Builtin `set` is deprecated
194 | abs_path = os.path.abspath(sys.argv[0])
195 | with open(f"{abs_path}.launch_params", "a") as f:
196 | f.write(f"{kernel_name} | {args_str}\n")
^
197 |
198 |
torch/_inductor/runtime/triton_heuristics.py:729:26: Builtin `set` is deprecated
727 | exec(
728 | f"""
729 | def launcher({', '.join(def_args)}, grid, stream):
^
730 | if callable(grid):
731 | grid_0, grid_1, grid_2 = grid(grid_meta)
torch/_inductor/runtime/triton_heuristics.py:729:46: Builtin `set` is deprecated
727 | exec(
728 | f"""
729 | def launcher({', '.join(def_args)}, grid, stream):
^
730 | if callable(grid):
731 | grid_0, grid_1, grid_2 = grid(grid_meta)
torch/_inductor/runtime/triton_heuristics.py:735:24: Builtin `set` is deprecated
733 | grid_0, grid_1, grid_2 = grid
734 |
735 | args = {', '.join(call_args)},
^
736 | launch_args = get_launch_args(
737 | grid, grid_0, grid_1, grid_2, stream, function,
torch/_inductor/runtime/triton_heuristics.py:735:45: Builtin `set` is deprecated
733 | grid_0, grid_1, grid_2 = grid
734 |
735 | args = {', '.join(call_args)},
^
736 | launch_args = get_launch_args(
737 | grid, grid_0, grid_1, grid_2, stream, function,
torch/_inductor/runtime/triton_heuristics.py:1144:20: Builtin `set` is deprecated
1142 | cur_file = inspect.stack()[1].filename
1143 | summary_str = (
1144 | f"SUMMARY ({cur_file})\n"
^
1145 | f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
1146 | )
torch/_inductor/runtime/triton_heuristics.py:1144:29: Builtin `set` is deprecated
1142 | cur_file = inspect.stack()[1].filename
1143 | summary_str = (
1144 | f"SUMMARY ({cur_file})\n"
^
1145 | f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s"
1146 | )
torch/_inductor/runtime/triton_heuristics.py:1162:61: Builtin `set` is deprecated
1160 | )
1161 | file.write("====================\n")
1162 | file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
^
1163 | for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
1164 | # also display the runtime percentage for each kernel
torch/_inductor/runtime/triton_heuristics.py:1162:70: Builtin `set` is deprecated
1160 | )
1161 | file.write("====================\n")
1162 | file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
^
1163 | for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
1164 | # also display the runtime percentage for each kernel
torch/_inductor/runtime/triton_heuristics.py:1166:36: Builtin `set` is deprecated
1164 | # also display the runtime percentage for each kernel
1165 | percentage = f"{ms / overall_time * 100:.2f}%"
1166 | suffix = f" \t {percentage} \t {kernel_name}"
^
1167 | bw_info_str = create_bandwidth_info_str(
1168 | ms,
torch/_inductor/runtime/triton_heuristics.py:1166:47: Builtin `set` is deprecated
1164 | # also display the runtime percentage for each kernel
1165 | percentage = f"{ms / overall_time * 100:.2f}%"
1166 | suffix = f" \t {percentage} \t {kernel_name}"
^
1167 | bw_info_str = create_bandwidth_info_str(
1168 | ms,
torch/_inductor/runtime/triton_heuristics.py:1166:52: Builtin `set` is deprecated
1164 | # also display the runtime percentage for each kernel
1165 | percentage = f"{ms / overall_time * 100:.2f}%"
1166 | suffix = f" \t {percentage} \t {kernel_name}"
^
1167 | bw_info_str = create_bandwidth_info_str(
1168 | ms,
torch/_inductor/runtime/triton_heuristics.py:1166:64: Builtin `set` is deprecated
1164 | # also display the runtime percentage for each kernel
1165 | percentage = f"{ms / overall_time * 100:.2f}%"
1166 | suffix = f" \t {percentage} \t {kernel_name}"
^
1167 | bw_info_str = create_bandwidth_info_str(
1168 | ms,
torch/_inductor/runtime/triton_heuristics.py:1175:30: Builtin `set` is deprecated
1173 | )
1174 | file.write(bw_info_str + "\n")
1175 | file.write(f"{summary_str}\n\n")
^
1176 | except Exception as e:
1177 | log.warning(
torch/_inductor/runtime/triton_heuristics.py:1175:42: Builtin `set` is deprecated
1173 | )
1174 | file.write(bw_info_str + "\n")
1175 | file.write(f"{summary_str}\n\n")
^
1176 | except Exception as e:
1177 | log.warning(
torch/_inductor/runtime/triton_heuristics.py:1205:29: Builtin `set` is deprecated
1203 | else:
1204 | possible_names = _find_names(self)
1205 | kernel_name = f"{max(possible_names, key=len)}"
^
1206 | if not re.match(self.regex_filter, kernel_name):
1207 | return
torch/_inductor/runtime/triton_heuristics.py:1205:58: Builtin `set` is deprecated
1203 | else:
1204 | possible_names = _find_names(self)
1205 | kernel_name = f"{max(possible_names, key=len)}"
^
1206 | if not re.match(self.regex_filter, kernel_name):
1207 | return
torch/_inductor/runtime/triton_heuristics.py:1241:60: Builtin `set` is deprecated
1239 | "%s",
1240 | create_bandwidth_info_str(
1241 | ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
^
1242 | ),
1243 | )
torch/_inductor/runtime/triton_heuristics.py:1241:72: Builtin `set` is deprecated
1239 | "%s",
1240 | create_bandwidth_info_str(
1241 | ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
^
1242 | ),
1243 | )
torch/_inductor/runtime/triton_heuristics.py:1256:15: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1256:42: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1256:44: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1256:58: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1256:60: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1256:75: Builtin `set` is deprecated
1254 | for cfg in configs:
1255 | hasher.update(
1256 | f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
^
1257 | )
1258 | return hasher.hexdigest()
torch/_inductor/runtime/triton_heuristics.py:1377:23: Builtin `set` is deprecated
1375 | if numel is None:
1376 | continue
1377 | block = cfg[f"{label}BLOCK"]
^
1378 | if numel == 1:
1379 | assert block == 1, (
torch/_inductor/runtime/triton_heuristics.py:1377:29: Builtin `set` is deprecated
1375 | if numel is None:
1376 | continue
1377 | block = cfg[f"{label}BLOCK"]
^
1378 | if numel == 1:
1379 | assert block == 1, (
torch/_inductor/runtime/triton_heuristics.py:1381:24: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:38: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:46: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:52: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:58: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:64: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:71: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:77: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:84: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1381:88: Builtin `set` is deprecated
1379 | assert block == 1, (
1380 | f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1381 | f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
^
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
torch/_inductor/runtime/triton_heuristics.py:1384:52: Builtin `set` is deprecated
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
1384 | max_block_str = f'config.triton.max_block["{label}"]'
^
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
torch/_inductor/runtime/triton_heuristics.py:1384:58: Builtin `set` is deprecated
1382 | )
1383 | max_block = TRITON_MAX_BLOCK[label]
1384 | max_block_str = f'config.triton.max_block["{label}"]'
^
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
torch/_inductor/runtime/triton_heuristics.py:1386:45: Builtin `set` is deprecated
1384 | max_block_str = f'config.triton.max_block["{label}"]'
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
^
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
1388 | )
torch/_inductor/runtime/triton_heuristics.py:1386:51: Builtin `set` is deprecated
1384 | max_block_str = f'config.triton.max_block["{label}"]'
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
^
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
1388 | )
torch/_inductor/runtime/triton_heuristics.py:1386:66: Builtin `set` is deprecated
1384 | max_block_str = f'config.triton.max_block["{label}"]'
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
^
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
1388 | )
torch/_inductor/runtime/triton_heuristics.py:1386:80: Builtin `set` is deprecated
1384 | max_block_str = f'config.triton.max_block["{label}"]'
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
^
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
1388 | )
torch/_inductor/runtime/triton_heuristics.py:1387:20: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:26: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:33: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:39: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:45: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:59: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:61: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:71: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:78: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1387:82: Builtin `set` is deprecated
1385 | assert max_block % block == 0, (
1386 | f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1387 | f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
^
1388 | )
1389 |
torch/_inductor/runtime/triton_heuristics.py:1402:19: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1402:23: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1402:46: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1402:56: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1402:67: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1402:71: Builtin `set` is deprecated
1400 | assert (
1401 | val <= max_block
1402 | ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
^
1403 |
1404 |
torch/_inductor/runtime/triton_heuristics.py:1551:21: Builtin `set` is deprecated
1549 | rnumels = {}
1550 | for idx in range(num_reduction_dims - 1, -1, -1):
1551 | prefix = f"r{idx}_"
^
1552 | max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
1553 | dim = min(max_size, remaining)
torch/_inductor/runtime/triton_heuristics.py:1551:25: Builtin `set` is deprecated
1549 | rnumels = {}
1550 | for idx in range(num_reduction_dims - 1, -1, -1):
1551 | prefix = f"r{idx}_"
^
1552 | max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
1553 | dim = min(max_size, remaining)
torch/_inductor/runtime/triton_heuristics.py:1556:34: Builtin `set` is deprecated
1554 | assert (
1555 | remaining % dim == 0
1556 | ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
^
1557 | rnumels[prefix] = dim
1558 | remaining //= dim
torch/_inductor/runtime/triton_heuristics.py:1556:38: Builtin `set` is deprecated
1554 | assert (
1555 | remaining % dim == 0
1556 | ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
^
1557 | rnumels[prefix] = dim
1558 | remaining //= dim
torch/_inductor/runtime/triton_heuristics.py:1556:67: Builtin `set` is deprecated
1554 | assert (
1555 | remaining % dim == 0
1556 | ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
^
1557 | rnumels[prefix] = dim
1558 | remaining //= dim
torch/_inductor/runtime/triton_heuristics.py:1556:77: Builtin `set` is deprecated
1554 | assert (
1555 | remaining % dim == 0
1556 | ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
^
1557 | rnumels[prefix] = dim
1558 | remaining //= dim
torch/_inductor/runtime/triton_heuristics.py:1564:38: Builtin `set` is deprecated
1562 | assert (
1563 | r == final_numel
1564 | ), f"Expected ND reduction size ({rnumels}) to have {r} elements."
^
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
torch/_inductor/runtime/triton_heuristics.py:1564:46: Builtin `set` is deprecated
1562 | assert (
1563 | r == final_numel
1564 | ), f"Expected ND reduction size ({rnumels}) to have {r} elements."
^
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
torch/_inductor/runtime/triton_heuristics.py:1564:57: Builtin `set` is deprecated
1562 | assert (
1563 | r == final_numel
1564 | ), f"Expected ND reduction size ({rnumels}) to have {r} elements."
^
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
torch/_inductor/runtime/triton_heuristics.py:1564:59: Builtin `set` is deprecated
1562 | assert (
1563 | r == final_numel
1564 | ), f"Expected ND reduction size ({rnumels}) to have {r} elements."
^
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
torch/_inductor/runtime/triton_heuristics.py:1567:37: Builtin `set` is deprecated
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
1567 | ), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
^
1568 |
1569 | return rnumels
torch/_inductor/runtime/triton_heuristics.py:1567:45: Builtin `set` is deprecated
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
1567 | ), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
^
1568 |
1569 | return rnumels
torch/_inductor/runtime/triton_heuristics.py:1567:49: Builtin `set` is deprecated
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
1567 | ), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
^
1568 |
1569 | return rnumels
torch/_inductor/runtime/triton_heuristics.py:1567:60: Builtin `set` is deprecated
1565 | assert all(
1566 | rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
1567 | ), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
^
1568 |
1569 | return rnumels
torch/_inductor/runtime/triton_heuristics.py:1746:49: Builtin `set` is deprecated
1744 |
1745 | if not configs:
1746 | raise NotImplementedError(f"size_hints: {size_hints}")
^
1747 | return cached_autotune(
1748 | size_hints,
torch/_inductor/runtime/triton_heuristics.py:1746:60: Builtin `set` is deprecated
1744 |
1745 | if not configs:
1746 | raise NotImplementedError(f"size_hints: {size_hints}")
^
1747 | return cached_autotune(
1748 | size_hints,
torch/_inductor/runtime/triton_heuristics.py:1928:32: Builtin `set` is deprecated
1926 | for prefix in size_hints:
1927 | if prefix_is_reduction(prefix):
1928 | c.kwargs.pop(f"{prefix.upper()}BLOCK")
^
1929 |
1930 | if disable_pointwise_autotuning(inductor_meta):
torch/_inductor/runtime/triton_heuristics.py:1928:47: Builtin `set` is deprecated
1926 | for prefix in size_hints:
1927 | if prefix_is_reduction(prefix):
1928 | c.kwargs.pop(f"{prefix.upper()}BLOCK")
^
1929 |
1930 | if disable_pointwise_autotuning(inductor_meta):
torch/_inductor/runtime/triton_heuristics.py:1975:49: Builtin `set` is deprecated
1973 | assert triton_meta is not None
1974 | if len(size_hints) != 2:
1975 | raise NotImplementedError(f"size_hints: {size_hints}")
^
1976 |
1977 | configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
torch/_inductor/runtime/triton_heuristics.py:1975:60: Builtin `set` is deprecated
1973 | assert triton_meta is not None
1974 | if len(size_hints) != 2:
1975 | raise NotImplementedError(f"size_hints: {size_hints}")
^
1976 |
1977 | configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
torch/_inductor/runtime/triton_heuristics.py:2082:56: Builtin `set` is deprecated
2080 | xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
2081 | else:
2082 | raise AssertionError(f"invalid size for numels {len(numels)}")
^
2083 |
2084 | def get_grid_dim(numel, block):
torch/_inductor/runtime/triton_heuristics.py:2082:68: Builtin `set` is deprecated
2080 | xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
2081 | else:
2082 | raise AssertionError(f"invalid size for numels {len(numels)}")
^
2083 |
2084 | def get_grid_dim(numel, block):
torch/_inductor/runtime/triton_heuristics.py:2104:57: Builtin `set` is deprecated
2102 | torch._check(
2103 | y_grid <= max_y_grid,
2104 | lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
^
2105 | )
2106 |
torch/_inductor/runtime/triton_heuristics.py:2104:64: Builtin `set` is deprecated
2102 | torch._check(
2103 | y_grid <= max_y_grid,
2104 | lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
^
2105 | )
2106 |
torch/_inductor/runtime/triton_heuristics.py:2113:43: Builtin `set` is deprecated
2111 | )
2112 |
2113 | setattr(grid_fn, "grid_fn_str", f"grid{numels}") # noqa: B010
^
2114 |
2115 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2113:50: Builtin `set` is deprecated
2111 | )
2112 |
2113 | setattr(grid_fn, "grid_fn_str", f"grid{numels}") # noqa: B010
^
2114 |
2115 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2122:48: Builtin `set` is deprecated
2120 | return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)
2121 |
2122 | grid_fn_str = f"cooperative_reduction_grid({xnumel})"
^
2123 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2124 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2122:55: Builtin `set` is deprecated
2120 | return (meta["RSPLIT"], ceildiv(xnumel, meta.get("XBLOCK", 1)), 1)
2121 |
2122 | grid_fn_str = f"cooperative_reduction_grid({xnumel})"
^
2123 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2124 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2135:54: Builtin `set` is deprecated
2133 | coop_grid = cooperative_reduction_grid(xnumel)
2134 | normal_grid = grid(xnumel)
2135 | grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
^
2136 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2137 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2135:61: Builtin `set` is deprecated
2133 | coop_grid = cooperative_reduction_grid(xnumel)
2134 | normal_grid = grid(xnumel)
2135 | grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
^
2136 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2137 | return grid_fn
torch/_inductor/runtime/triton_heuristics.py:2145:37: Builtin `set` is deprecated
2143 | return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
2144 |
2145 | grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
^
2146 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2147 |
torch/_inductor/runtime/triton_heuristics.py:2145:44: Builtin `set` is deprecated
2143 | return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
2144 |
2145 | grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
^
2146 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2147 |
torch/_inductor/runtime/triton_heuristics.py:2145:47: Builtin `set` is deprecated
2143 | return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
2144 |
2145 | grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
^
2146 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2147 |
torch/_inductor/runtime/triton_heuristics.py:2145:54: Builtin `set` is deprecated
2143 | return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
2144 |
2145 | grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
^
2146 | setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
2147 |
torch/_inductor/runtime/triton_heuristics.py:2173:42: Builtin `set` is deprecated
2171 | assert (
2172 | min_blocks_d is None or min_blocks == min_blocks_d
2173 | ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
^
2174 | else:
2175 | # sequential dispatch
torch/_inductor/runtime/triton_heuristics.py:2173:53: Builtin `set` is deprecated
2171 | assert (
2172 | min_blocks_d is None or min_blocks == min_blocks_d
2173 | ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
^
2174 | else:
2175 | # sequential dispatch
torch/_inductor/runtime/triton_heuristics.py:2173:66: Builtin `set` is deprecated
2171 | assert (
2172 | min_blocks_d is None or min_blocks == min_blocks_d
2173 | ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
^
2174 | else:
2175 | # sequential dispatch
torch/_inductor/runtime/triton_heuristics.py:2173:77: Builtin `set` is deprecated
2171 | assert (
2172 | min_blocks_d is None or min_blocks == min_blocks_d
2173 | ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
^
2174 | else:
2175 | # sequential dispatch
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143628
Approved by: https://github.com/yanboliang, https://github.com/rec
2024-12-20 11:45:26 +00:00
zeshengzong
217a4ddb04
Add range check (input index >= 0) to embedding_bag on CUDA devices (#140791)
...
Fixes #89362
**Test Result**
**Before**
```
>>> import torch
>>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda()
>>> weight = torch.rand([2, 3], dtype=torch.float32).cuda()
>>> print(torch.nn.functional.embedding_bag(input, weight))
tensor([[0., 0., 0.]], device='cuda:0')
```
**After**
```python
>>> import torch
>>> input = torch.randint(-5, 1, [1, 2], dtype=torch.int64).cuda()
>>> weight = torch.rand([2, 3], dtype=torch.float32).cuda()
>>> print(torch.nn.functional.embedding_bag(input, weight))
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [0,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [1,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
/home/zong/code/pytorch/aten/src/ATen/native/cuda/EmbeddingBag.cu:141: EmbeddingBag_updateOutputKernel_sum_mean: block: [0,0,0], thread: [2,0,0] Assertion `0 <= input_idx && input_idx < numRows` failed.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/zong/code/pytorch/torch/_tensor.py", line 568, in __repr__
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/_tensor_str.py", line 708, in _str
return _str_intern(self, tensor_contents=tensor_contents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/_tensor_str.py", line 625, in _str_intern
tensor_str = _tensor_str(self, indent)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/_tensor_str.py", line 357, in _tensor_str
formatter = _Formatter(get_summarized_data(self) if summarize else self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zong/code/pytorch/torch/_tensor_str.py", line 146, in __init__
tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
```bash
$ pytest test/nn/test_embedding.py
```

```bash
$ lintrunner
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140791
Approved by: https://github.com/eqy
2024-12-20 05:47:26 +00:00