[spmd expansion] support torch.ops.aten.sym_numel (#98229)

The current logic assumes that non-overload ops take two arguments; however, torch.ops.aten.sym_numel takes only one.
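For context, a minimal sketch (not part of this patch) of why the two-argument assumption breaks for this op; the tensor `t` below is purely illustrative:

    # Minimal sketch, assuming a plain (non-symbolic) tensor:
    # torch.ops.aten.sym_numel takes a single tensor argument, so an expansion
    # path that unconditionally reads args[0] and args[1] (or asserts len(args) >= 2)
    # cannot handle it, whereas a typical binary op fits that assumption.
    import torch

    t = torch.randn(4, 8)
    print(torch.ops.aten.sym_numel(t))     # one argument; should print 32
    print(torch.ops.aten.mul(t, t).shape)  # two arguments; torch.Size([4, 8])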

Differential Revision: [D44615037](https://our.internmc.facebook.com/intern/diff/D44615037/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98229
Approved by: https://github.com/mrshenli
Yifu Wang 2023-04-03 13:26:30 -07:00 committed by PyTorch MergeBot
parent a6bd21d935
commit 4d13fcddef


@@ -492,11 +492,14 @@ def _convert_to_distributed(
                 return arg

             args = tree_map(_remap_arg, node.args)
-            assert (
-                len(args) >= 2
-            ), f"Expected number of args for call function to be at least 2, found {len(args)}"
-            # TODO(anj): Why do we assume this is only 2?
-            node_to_obj[node] = node.target(args[0], args[1])
+            if node.target == torch.ops.aten.sym_numel:
+                node_to_obj[node] = args[0].numel()
+            else:
+                assert (
+                    len(args) >= 2
+                ), f"Expected number of args for call function to be at least 2, found {len(args)} {node}"
+                # TODO(anj): Why do we assume this is only 2?
+                node_to_obj[node] = node.target(args[0], args[1])
         else:
             raise ValueError(f"Unrecognized node.op type {node.op}")