address remaining straight forward gso in meta_registrations (#156902)

Those are all straight forward generalization of existing checks,
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156902
Approved by: https://github.com/ColinPeppler
This commit is contained in:
Laith Sakka 2025-06-25 17:28:46 -07:00 committed by PyTorch MergeBot
parent 640703d95f
commit cbcffce48a

View File

@ -3593,9 +3593,9 @@ def meta_index_Tensor(self, indices):
return self.as_strided(shape, strides)
out = self.new_empty(before_shape + replacement_shape + after_shape)
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
if guard_size_oblivious(self.numel() == 0):
if guard_or_false(self.numel() == 0):
# No need to worry about the output strides if self is empty.
return out
@ -5606,10 +5606,10 @@ def gather_shape_check(self, dim, index):
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
wrapped_dim = maybe_wrap_dim(dim, self.dim())
is_index_empty = guard_size_oblivious(index.numel() == 0)
is_index_empty = guard_or_false(index.numel() == 0)
if not is_index_empty:
torch._check(
index.dtype == torch.long or index.dtype == torch.int,
@ -5648,9 +5648,9 @@ def get_operator_enum(reduce_, use_new_options=False):
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_true
if guard_size_oblivious(index.numel() != 0):
if guard_or_true(index.numel() != 0):
torch._check(
index.dtype == torch.long or index.dtype == torch.int,
lambda: f"{method_name}(): Expected dtype int32/int64 for index",
@ -5669,9 +5669,9 @@ def ensure_nonempty_dim(dim):
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
if guard_size_oblivious(index.numel() == 0):
if guard_or_false(index.numel() == 0):
return
torch._check(
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),