mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make hashing a SymInt raise an error again (#130548)
See https://github.com/pytorch/pytorch/issues/130547 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/130548 Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/lezcano
This commit is contained in:
parent
1d8baa4df2
commit
408c921d96
|
|
@ -7,6 +7,7 @@ import itertools
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
import re
|
import re
|
||||||
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -1262,11 +1263,15 @@ class TestSymNumberMagicMethods(TestCase):
|
||||||
def get_constant_bool(self, val):
|
def get_constant_bool(self, val):
|
||||||
return SymBool(torch._C._get_constant_bool_symnode(val))
|
return SymBool(torch._C._get_constant_bool_symnode(val))
|
||||||
|
|
||||||
|
@unittest.expectedFailure
|
||||||
|
def test_symint_hashing(self):
|
||||||
|
shape_env = ShapeEnv()
|
||||||
|
hash(create_symint(shape_env, 3))
|
||||||
|
|
||||||
def test_symnode_hashing(self):
|
def test_symnode_hashing(self):
|
||||||
shape_env = ShapeEnv()
|
shape_env = ShapeEnv()
|
||||||
|
|
||||||
# These all trigger specialization when hashed
|
# These all trigger specialization when hashed
|
||||||
hash(create_symint(shape_env, 3))
|
|
||||||
hash(create_symbool(shape_env, True))
|
hash(create_symbool(shape_env, True))
|
||||||
# We should be passing in float here, but create_symbol currently
|
# We should be passing in float here, but create_symbol currently
|
||||||
# only supports int
|
# only supports int
|
||||||
|
|
|
||||||
|
|
@ -520,24 +520,31 @@ class SymInt:
|
||||||
return self.node.expr
|
return self.node.expr
|
||||||
|
|
||||||
def __hash__(self) -> builtins.int:
|
def __hash__(self) -> builtins.int:
|
||||||
return hash(self._get_int())
|
if self.node.is_nested_int():
|
||||||
|
return hash(self.node.nested_int())
|
||||||
|
else:
|
||||||
|
# We could support constant SymInts as well, but not doing it for now
|
||||||
|
raise TypeError("unhashable type: non-nested SymInt")
|
||||||
|
# TODO: Force specialization
|
||||||
|
# This can't be done because the TypeError here is load bearing
|
||||||
|
# for einops
|
||||||
|
# https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
|
||||||
|
# return hash(builtins.int(self))
|
||||||
|
|
||||||
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
|
def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]:
|
||||||
"""Represent this int as an exact integer ratio"""
|
"""Represent this int as an exact integer ratio"""
|
||||||
return self._get_int(), 1
|
return self, 1
|
||||||
|
|
||||||
def bit_length(self) -> "SymInt":
|
def bit_length(self) -> builtins.int:
|
||||||
return SymInt(self.node.wrap_int(self._get_int().bit_length()))
|
# TODO: A more relaxed guard is possible here, where you guard to
|
||||||
|
# allow all integer quantities which would result in the same bit
|
||||||
|
# length. We can also just make a dedicated Sympy function for
|
||||||
|
# computing this quantity and represent it symbolically.
|
||||||
|
return builtins.int(self).bit_length()
|
||||||
|
|
||||||
def conjugate(self) -> "SymInt":
|
def conjugate(self) -> "SymInt":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _get_int(self) -> builtins.int:
|
|
||||||
if self.node.is_nested_int():
|
|
||||||
return self.node.nested_int()
|
|
||||||
else:
|
|
||||||
return builtins.int(self)
|
|
||||||
|
|
||||||
|
|
||||||
class SymFloat:
|
class SymFloat:
|
||||||
"""
|
"""
|
||||||
|
|
@ -638,7 +645,7 @@ class SymFloat:
|
||||||
|
|
||||||
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
|
def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
|
||||||
"""Represent this float as an exact integer ratio"""
|
"""Represent this float as an exact integer ratio"""
|
||||||
return self._get_float().as_integer_ratio()
|
return builtins.float(self).as_integer_ratio()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.node._graph_repr()
|
return self.node._graph_repr()
|
||||||
|
|
@ -647,10 +654,7 @@ class SymFloat:
|
||||||
return self.node.expr
|
return self.node.expr
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self._get_float())
|
return hash(builtins.float(self))
|
||||||
|
|
||||||
def _get_float(self) -> builtins.float:
|
|
||||||
return self.node.float_() if self.node.is_constant() else builtins.float(self)
|
|
||||||
|
|
||||||
|
|
||||||
class SymBool:
|
class SymBool:
|
||||||
|
|
|
||||||
|
|
@ -2122,7 +2122,9 @@ def wrap_fx_proxy_cls(
|
||||||
):
|
):
|
||||||
set_example_value(proxy.node, example_value)
|
set_example_value(proxy.node, example_value)
|
||||||
return EventVariable(proxy, example_value, **options)
|
return EventVariable(proxy, example_value, **options)
|
||||||
elif isinstance(example_value, int) and proxy.node.target in [
|
elif isinstance(example_value, int) and (
|
||||||
|
proxy.node.target
|
||||||
|
in [
|
||||||
torch.sym_int,
|
torch.sym_int,
|
||||||
getattr,
|
getattr,
|
||||||
operator.getitem,
|
operator.getitem,
|
||||||
|
|
@ -2136,7 +2138,13 @@ def wrap_fx_proxy_cls(
|
||||||
# This always wants to be in the graph, even if the constraint
|
# This always wants to be in the graph, even if the constraint
|
||||||
# results in a constant int
|
# results in a constant int
|
||||||
torch._constrain_as_size,
|
torch._constrain_as_size,
|
||||||
]:
|
]
|
||||||
|
or (
|
||||||
|
# TODO: this is a little sus, because we didn't check what the self is
|
||||||
|
proxy.node.op == "call_method"
|
||||||
|
and proxy.node.target in ["bit_length"]
|
||||||
|
)
|
||||||
|
):
|
||||||
set_example_value(proxy.node, example_value)
|
set_example_value(proxy.node, example_value)
|
||||||
return ConstantVariable.create(example_value, **options)
|
return ConstantVariable.create(example_value, **options)
|
||||||
elif isinstance(example_value, torch.backends.cuda.SDPAParams):
|
elif isinstance(example_value, torch.backends.cuda.SDPAParams):
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import typing
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
|
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
|
|
@ -932,7 +933,14 @@ class TritonHOPifier:
|
||||||
self.raise_unsupported("Grid can have at most rank 3")
|
self.raise_unsupported("Grid can have at most rank 3")
|
||||||
|
|
||||||
assert len(grids) != 0
|
assert len(grids) != 0
|
||||||
if len(set(grids)) == 1:
|
|
||||||
|
def intify(x):
|
||||||
|
if isinstance(x, torch.SymInt):
|
||||||
|
return int(x)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if len(set(pytree.tree_map(intify, grids))) == 1:
|
||||||
# If there's only one unique grid, lets simplify
|
# If there's only one unique grid, lets simplify
|
||||||
grids = [grids[0]]
|
grids = [grids[0]]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
|
||||||
pass
|
pass
|
||||||
# Going via an iterator directly is slower than via list comprehension.
|
# Going via an iterator directly is slower than via list comprehension.
|
||||||
axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
|
axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
|
||||||
if not allow_duplicate and len(set(axis)) != len(axis):
|
if not allow_duplicate and len(set(map(int, axis))) != len(axis):
|
||||||
if argname:
|
if argname:
|
||||||
raise ValueError(f"repeated axis in `{argname}` argument")
|
raise ValueError(f"repeated axis in `{argname}` argument")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user