mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
address remaining straightforward gso in meta_registrations (#156902)
These are all straightforward generalizations of existing checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156902 Approved by: https://github.com/ColinPeppler
This commit is contained in:
parent
640703d95f
commit
cbcffce48a
|
|
@ -3593,9 +3593,9 @@ def meta_index_Tensor(self, indices):
|
|||
return self.as_strided(shape, strides)
|
||||
|
||||
out = self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if guard_size_oblivious(self.numel() == 0):
|
||||
if guard_or_false(self.numel() == 0):
|
||||
# No need to worry about the output strides if self is empty.
|
||||
return out
|
||||
|
||||
|
|
@ -5606,10 +5606,10 @@ def gather_shape_check(self, dim, index):
|
|||
|
||||
@register_meta(aten.gather.default)
|
||||
def meta_gather(self, dim, index, sparse_grad=False):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
wrapped_dim = maybe_wrap_dim(dim, self.dim())
|
||||
is_index_empty = guard_size_oblivious(index.numel() == 0)
|
||||
is_index_empty = guard_or_false(index.numel() == 0)
|
||||
if not is_index_empty:
|
||||
torch._check(
|
||||
index.dtype == torch.long or index.dtype == torch.int,
|
||||
|
|
@ -5648,9 +5648,9 @@ def get_operator_enum(reduce_, use_new_options=False):
|
|||
|
||||
# From aten/src/ATen/native/ScatterGatherChecks.h
|
||||
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_true
|
||||
|
||||
if guard_size_oblivious(index.numel() != 0):
|
||||
if guard_or_true(index.numel() != 0):
|
||||
torch._check(
|
||||
index.dtype == torch.long or index.dtype == torch.int,
|
||||
lambda: f"{method_name}(): Expected dtype int32/int64 for index",
|
||||
|
|
@ -5669,9 +5669,9 @@ def ensure_nonempty_dim(dim):
|
|||
|
||||
# From aten/src/ATen/native/ScatterGatherChecks.h
|
||||
def scatter_shape_check(self, dim, index, src_opt=None):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if guard_size_oblivious(index.numel() == 0):
|
||||
if guard_or_false(index.numel() == 0):
|
||||
return
|
||||
torch._check(
|
||||
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user