pytorch/torch/_inductor/optimize_indexing.py
Aaron Gokaslan c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be a RuntimeError, ValueError, TypeError, or some other more specific error type. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually fix the existing noqas so they can be removed over time and our exception typing can be improved.
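
For example, under this rule a bare raise of the base Exception class is flagged, and the fix is either a more specific exception type or an explicit noqa (as in the file below). A minimal sketch, with the message chosen for illustration:

    raise Exception(f"Unexpected value {val}")  # flagged by TRY002
    raise ValueError(f"Unexpected value {val}")  # preferred: a specific type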

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00

import math
import sympy
import torch
from torch.utils._sympy.value_ranges import ValueRanges
from .ir import LoopBody
from .utils import dominated_nodes


def val_expressable_in_32_bits(val):
    # True iff `val` fits exactly in 32 bits: within int32 bounds for
    # integers, within the float32 exact-integer window for floats
if getattr(val, "is_Boolean", False):
return True
if isinstance(val, sympy.Expr):
assert val.is_number
if val.is_Integer or val.is_Boolean:
val = int(val)
else:
val = float(val)
    # bound within the float32 mantissa's exact-integer range
    if isinstance(val, float):
        return -(2**24) <= val <= 2**24
    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return iinfo.min <= val <= iinfo.max
raise Exception(f"Unexpected value {val}") # noqa: TRY002
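

# Illustrative note on the 2**24 bound above: float32 stores a 24-bit mantissa,
# so 2**24 is the largest N for which every integer in [-N, N] is exactly
# representable; e.g. torch.tensor(2.0**24, dtype=torch.float32) + 1 compares
# equal to 2.0**24, silently losing the increment.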


def range_expressable_in_32_bits(range):
    return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
        range.upper
    )


def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    # if a downstream use of a node explicitly converts to int32 or to
    # float32/float64, then its precision is fixed for that chain of uses,
    # and we don't need to consider those dominated values
def skip_filter(node):
return node.target == "to_dtype" and node.args[2] in (
torch.int32,
torch.float32,
torch.float64,
)
    # TODO - there are dominated uses whose dtype does not depend on whether we
    # reduce the precision here, e.g. in add(int64, int64) one of the args can
    # be reduced to int32 without changing the output precision of the node;
    # this case hasn't shown up in practice
for dominated in dominated_nodes([node], skip_filter):
if dominated.target in ["store", "output"]:
continue
if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
idx = int(dominated.target[len("set_indirect") :])
indirect_var = indirect_vars[idx]
# We check that we can compute all the indices it's involved in with int32
for index, expr in indices.items():
if indirect_var in expr.free_symbols:
index_val = replacement_vals[index]
if math.isinf(index_val.lower) or math.isinf(index_val.upper):
return
# all indices are integers, so make sure that we
# use the bounds of integers instead of floats.
# TODO - not sure if we should be doing int/float casts while tracing,
# might interfere with sympy.
index_val_int = ValueRanges[sympy.Expr](
int(index_val.lower), int(index_val.upper)
)
if not range_expressable_in_32_bits(index_val_int):
return
if not range_expressable_in_32_bits(bounds[dominated]):
return
    # every dominated use checked out, so narrow this node's dtype argument
    # (args[2]) from int64 to int32
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)


def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Performs value range analysis on LoopBody's fx graph to reduce the
    precision of intermediate computations from int64 to int32 where it
    is provably safe.
    """
bv = loop_body.bounds()
int64_dtype_nodes = [
node
for node in loop_body.get_nodes()
if (
node.target == "to_dtype"
and node.args[2] == torch.int64
and node not in bv.unbounded_vars
)
]
if not int64_dtype_nodes:
return
bounds = bv.get_bounds()
    # TODO - if a dominated node of one to_dtype is not expressible in int32,
    # we could short-circuit any other to_dtype node that also dominates it
for node in int64_dtype_nodes:
try_to_reduce_precision(
node,
bounds,
loop_body.indirect_vars,
loop_body.indexing_exprs,
bv.replacement_vals,
)
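

# A hypothetical usage sketch (the `body` variable is assumed, not defined in
# this file): given a LoopBody produced during Inductor lowering, the pass
# mutates its fx graph in place:
#
#     indexing_dtype_strength_reduction(body)
#
# Afterwards, each "to_dtype" node whose value range provably fits in int32 has
# had its dtype argument narrowed from torch.int64 to torch.int32.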