mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime exceptions, value errors, type errors or some other errors. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR. I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570 Approved by: https://github.com/ezyang
119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
import math
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from torch.utils._sympy.value_ranges import ValueRanges
|
|
from .ir import LoopBody
|
|
from .utils import dominated_nodes
|
|
|
|
|
|
def val_expressable_in_32_bits(val):
    """Check whether ``val`` lies within the range treated as safe for 32-bit types.

    * Anything with a truthy ``is_Boolean`` attribute (e.g. sympy's ``true`` /
      ``false``) is always accepted.
    * A numeric sympy expression is first converted to a Python ``int`` (when
      it is a sympy integer) or a Python ``float`` (otherwise, e.g. ``oo``).
    * A float is accepted when it lies within [-2**24, 2**24]. This is the
      "bound within mantissa" check: float32 has a 24-bit significand.
    * An int is accepted when it lies within the ``torch.int32`` range.

    Args:
        val: a boolean-like object, a Python int/float, or a numeric sympy
            expression.

    Returns:
        bool: whether ``val`` is within the accepted range.

    Raises:
        AssertionError: if ``val`` is a sympy expression that is not a number.
        TypeError: if ``val`` is none of the supported kinds.
    """
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        # Booleans already returned above, so only integer-vs-other matters.
        val = int(val) if val.is_Integer else float(val)

    # bound within mantissa
    if isinstance(val, float):
        return -(2**24) <= val <= 2**24

    if isinstance(val, int):
        iinfo = torch.iinfo(torch.int32)
        return iinfo.min <= val <= iinfo.max

    # TypeError (a subclass of Exception) so existing broad handlers still work.
    raise TypeError(f"Unexpected value {val}")
|
|
|
|
|
|
def range_expressable_in_32_bits(range):
    """Return True when both endpoints of ``range`` pass ``val_expressable_in_32_bits``.

    ``range`` is any object exposing ``lower`` and ``upper`` attributes (for
    example a ``ValueRanges``). The parameter name shadows the builtin, but it
    is kept for keyword-argument compatibility. The upper endpoint is only
    examined once the lower one has been accepted.
    """
    if not val_expressable_in_32_bits(range.lower):
        return False
    return val_expressable_in_32_bits(range.upper)
|
|
|
|
|
|
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
    """Rewrite the ``to_dtype`` node ``node`` to cast to int32 when that is safe.

    The nodes returned by ``dominated_nodes([node], skip_filter)`` are walked,
    ignoring ``store`` and ``output`` uses. The node is left unchanged (early
    ``return``) if any of the following holds:

    * A dominated node's entry in ``bounds`` does not fit in 32 bits, as
      judged by ``range_expressable_in_32_bits``.
    * A dominated ``set_indirect<N>`` node uses an indirect variable that
      appears in an indexing expression whose replacement range is infinite
      or does not fit in int32.

    Otherwise ``node.args[2]`` is replaced with ``torch.int32``; ``node.args``
    is rebuilt as a new tuple.

    Args:
        node: fx node whose ``target`` is ``"to_dtype"``. The dtype is read
            from ``args[2]``.
        bounds: maps each dominated node to a range with ``lower`` / ``upper``.
        indirect_vars: indexed by the integer ``N`` in a ``set_indirect<N>``
            target.
        indices: maps an index name to an expression with ``free_symbols``
            (presumably a sympy expression).
        replacement_vals: maps the same index names to ranges with ``lower``
            / ``upper``.
    """

    # If a downstream use of a node explicitly converts to int32, float32 or
    # float64, then its precision is set for that chain of uses, and we don't
    # need to consider those dominated values. Note that float16 is not in
    # this list, so chains through a float16 cast are still checked.
    def skip_filter(node):
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    indirect_prefix = "set_indirect"

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. for add(int64, int64) one of the args
    # can be reduced to int32 without changing the output precision of the
    # node. This case hasn't shown up yet.
    for dominated in dominated_nodes([node], skip_filter):
        target = dominated.target
        if target in ("store", "output"):
            continue

        # The suffix after the prefix is parsed as an int below, so match the
        # prefix rather than any occurrence of the substring.
        if isinstance(target, str) and target.startswith(indirect_prefix):
            idx = int(target[len(indirect_prefix) :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in
            # with int32.
            for index, expr in indices.items():
                if indirect_var not in expr.free_symbols:
                    continue

                index_val = replacement_vals[index]
                if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                    return

                # All indices are integers, so make sure that we use the
                # bounds of integers instead of floats.
                # TODO - not sure if we should be doing int/float casts while
                # tracing, might interfere with sympy.
                index_val_int = ValueRanges[sympy.Expr](
                    int(index_val.lower), int(index_val.upper)
                )
                if not range_expressable_in_32_bits(index_val_int):
                    return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
|
|
|
|
|
|
def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Use value range analysis on the fx graph of ``loop_body`` to narrow
    intermediate ``to_dtype`` casts from int64 to int32 where possible.

    The candidates are the ``to_dtype`` nodes casting to ``torch.int64`` that
    are not listed in the bound analysis's ``unbounded_vars``. Each candidate
    is passed to ``try_to_reduce_precision``, which decides whether to rewrite
    it. If there are no candidates, the bounds are never computed.
    """
    bound_vars = loop_body.bounds()

    def is_int64_cast_candidate(fx_node):
        return (
            fx_node.target == "to_dtype"
            and fx_node.args[2] == torch.int64
            and fx_node not in bound_vars.unbounded_vars
        )

    candidates = list(filter(is_int64_cast_candidate, loop_body.get_nodes()))
    if not candidates:
        return

    node_bounds = bound_vars.get_bounds()

    # TODO - if dominated node of one to_dtype is not expressible in int32,
    # we should short circuit another to_dtype node if that node also dominates
    for candidate in candidates:
        try_to_reduce_precision(
            candidate,
            node_bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bound_vars.replacement_vals,
        )
|