Make hashing a SymInt raise an error again (#130548)

See https://github.com/pytorch/pytorch/issues/130547

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/lezcano
This commit is contained in:
Edward Z. Yang 2024-07-16 06:52:16 -07:00 committed by PyTorch MergeBot
parent 1d8baa4df2
commit 408c921d96
5 changed files with 59 additions and 34 deletions

View File

@@ -7,6 +7,7 @@ import itertools
import math import math
import operator import operator
import re import re
import unittest
import numpy as np import numpy as np
@@ -1262,11 +1263,15 @@ class TestSymNumberMagicMethods(TestCase):
def get_constant_bool(self, val): def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val)) return SymBool(torch._C._get_constant_bool_symnode(val))
@unittest.expectedFailure
def test_symint_hashing(self):
shape_env = ShapeEnv()
hash(create_symint(shape_env, 3))
def test_symnode_hashing(self): def test_symnode_hashing(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
# These all trigger specialization when hashed # These all trigger specialization when hashed
hash(create_symint(shape_env, 3))
hash(create_symbool(shape_env, True)) hash(create_symbool(shape_env, True))
# We should be passing in float here, but create_symbol currently # We should be passing in float here, but create_symbol currently
# only supports int # only supports int

View File

@@ -520,24 +520,31 @@ class SymInt:
return self.node.expr return self.node.expr
def __hash__(self) -> builtins.int: def __hash__(self) -> builtins.int:
return hash(self._get_int()) if self.node.is_nested_int():
return hash(self.node.nested_int())
else:
# We could support constant SymInts as well, but not doing it for now
raise TypeError("unhashable type: non-nested SymInt")
# TODO: Force specialization
# This can't be done because the TypeError here is load bearing
# for einops
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
# return hash(builtins.int(self))
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]: def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
"""Represent this int as an exact integer ratio""" """Represent this int as an exact integer ratio"""
return self._get_int(), 1 return self, 1
def bit_length(self) -> "SymInt": def bit_length(self) -> builtins.int:
return SymInt(self.node.wrap_int(self._get_int().bit_length())) # TODO: A more relaxed guard is possible here, where you guard to
# allow all integer quantities which would result in the same bit
# length. We can also just make a dedicated Sympy function for
# computing this quantity and represent it symbolically.
return builtins.int(self).bit_length()
def conjugate(self) -> "SymInt": def conjugate(self) -> "SymInt":
return self return self
def _get_int(self) -> builtins.int:
if self.node.is_nested_int():
return self.node.nested_int()
else:
return builtins.int(self)
class SymFloat: class SymFloat:
""" """
@@ -638,7 +645,7 @@ class SymFloat:
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]: def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
"""Represent this float as an exact integer ratio""" """Represent this float as an exact integer ratio"""
return self._get_float().as_integer_ratio() return builtins.float(self).as_integer_ratio()
def __repr__(self): def __repr__(self):
return self.node._graph_repr() return self.node._graph_repr()
@@ -647,10 +654,7 @@
return self.node.expr return self.node.expr
def __hash__(self): def __hash__(self):
return hash(self._get_float()) return hash(builtins.float(self))
def _get_float(self) -> builtins.float:
return self.node.float_() if self.node.is_constant() else builtins.float(self)
class SymBool: class SymBool:

View File

@@ -2122,7 +2122,9 @@ def wrap_fx_proxy_cls(
): ):
set_example_value(proxy.node, example_value) set_example_value(proxy.node, example_value)
return EventVariable(proxy, example_value, **options) return EventVariable(proxy, example_value, **options)
elif isinstance(example_value, int) and proxy.node.target in [ elif isinstance(example_value, int) and (
proxy.node.target
in [
torch.sym_int, torch.sym_int,
getattr, getattr,
operator.getitem, operator.getitem,
@@ -2136,7 +2138,13 @@ def wrap_fx_proxy_cls(
# This always wants to be in the graph, even if the constraint # This always wants to be in the graph, even if the constraint
# results in a constant int # results in a constant int
torch._constrain_as_size, torch._constrain_as_size,
]: ]
or (
# TODO: this is a little sus, because we didn't check what the self is
proxy.node.op == "call_method"
and proxy.node.target in ["bit_length"]
)
):
set_example_value(proxy.node, example_value) set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options) return ConstantVariable.create(example_value, **options)
elif isinstance(example_value, torch.backends.cuda.SDPAParams): elif isinstance(example_value, torch.backends.cuda.SDPAParams):

View File

@@ -9,6 +9,7 @@ import typing
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch
import torch.fx as fx import torch.fx as fx
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@@ -932,7 +933,14 @@ class TritonHOPifier:
self.raise_unsupported("Grid can have at most rank 3") self.raise_unsupported("Grid can have at most rank 3")
assert len(grids) != 0 assert len(grids) != 0
if len(set(grids)) == 1:
def intify(x):
if isinstance(x, torch.SymInt):
return int(x)
else:
return x
if len(set(pytree.tree_map(intify, grids))) == 1:
# If there's only one unique grid, lets simplify # If there's only one unique grid, lets simplify
grids = [grids[0]] grids = [grids[0]]

View File

@@ -90,7 +90,7 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
pass pass
# Going via an iterator directly is slower than via list comprehension. # Going via an iterator directly is slower than via list comprehension.
axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
if not allow_duplicate and len(set(axis)) != len(axis): if not allow_duplicate and len(set(map(int, axis))) != len(axis):
if argname: if argname:
raise ValueError(f"repeated axis in `{argname}` argument") raise ValueError(f"repeated axis in `{argname}` argument")
else: else: