[PT2] - Guard oblivious on meta registrations (#122216)

Summary:
```
[trainer0|0]:Potential framework code culprit (scroll up for full backtrace):
[trainer0|0]:  File "/mnt/xarfuse/uid-539346/56d4bb3d-seed-nspid4026531836_cgpid183208940-ns-4026531840/torch/_meta_registrations.py", line 5043, in scatter_gather_dtype_check
[trainer0|0]:    if index.numel() != 0:
```

Test Plan: General CI.

Reviewed By: ezyang

Differential Revision: D54689183

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122216
Approved by: https://github.com/ezyang
This commit is contained in:
Flavio Sales Truzzi 2024-03-22 01:36:03 +00:00 committed by PyTorch MergeBot
parent 4f93b3d958
commit bde22835c6

View File

@@ -5040,8 +5040,10 @@ def gather_shape_check(self, dim, index):
 @register_meta(aten.gather.default)
 def meta_gather(self, dim, index, sparse_grad=False):
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
     wrapped_dim = maybe_wrap_dim(dim, self.dim())
-    is_index_empty = index.numel() == 0
+    is_index_empty = guard_size_oblivious(index.numel() == 0)
     if not is_index_empty:
         torch._check(
             index.dtype == torch.long,
@@ -5080,7 +5082,9 @@ def get_operator_enum(reduce_, use_new_options=False):
 # From aten/src/ATen/native/ScatterGatherChecks.h
 def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
-    if index.numel() != 0:
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    if guard_size_oblivious(index.numel() != 0):
         torch._check(
             index.dtype == torch.long,
             lambda: f"{method_name}(): Expected dtype int64 for index",
@@ -5099,7 +5103,9 @@ def ensure_nonempty_dim(dim):
 # From aten/src/ATen/native/ScatterGatherChecks.h
 def scatter_shape_check(self, dim, index, src_opt=None):
-    if index.numel() == 0:
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    if guard_size_oblivious(index.numel() == 0):
         return
     torch._check(
         ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),