mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[PT2] - Guard oblivious on meta registrations (#122216)
Summary: ``` [trainer0|0]:Potential framework code culprit (scroll up for full backtrace): [trainer0|0]: File "/mnt/xarfuse/uid-539346/56d4bb3d-seed-nspid4026531836_cgpid183208940-ns-4026531840/torch/_meta_registrations.py", line 5043, in scatter_gather_dtype_check [trainer0|0]: if index.numel() != 0: ``` Test Plan: General CI. Reviewed By: ezyang Differential Revision: D54689183 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122216 Approved by: https://github.com/ezyang
This commit is contained in:
parent
4f93b3d958
commit
bde22835c6
|
|
@ -5040,8 +5040,10 @@ def gather_shape_check(self, dim, index):
|
|||
|
||||
@register_meta(aten.gather.default)
|
||||
def meta_gather(self, dim, index, sparse_grad=False):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
wrapped_dim = maybe_wrap_dim(dim, self.dim())
|
||||
is_index_empty = index.numel() == 0
|
||||
is_index_empty = guard_size_oblivious(index.numel() == 0)
|
||||
if not is_index_empty:
|
||||
torch._check(
|
||||
index.dtype == torch.long,
|
||||
|
|
@ -5080,7 +5082,9 @@ def get_operator_enum(reduce_, use_new_options=False):
|
|||
|
||||
# From aten/src/ATen/native/ScatterGatherChecks.h
|
||||
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
|
||||
if index.numel() != 0:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
if guard_size_oblivious(index.numel() != 0):
|
||||
torch._check(
|
||||
index.dtype == torch.long,
|
||||
lambda: f"{method_name}(): Expected dtype int64 for index",
|
||||
|
|
@ -5099,7 +5103,9 @@ def ensure_nonempty_dim(dim):
|
|||
|
||||
# From aten/src/ATen/native/ScatterGatherChecks.h
|
||||
def scatter_shape_check(self, dim, index, src_opt=None):
|
||||
if index.numel() == 0:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
if guard_size_oblivious(index.numel() == 0):
|
||||
return
|
||||
torch._check(
|
||||
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user