mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[spmd expansion] support torch.ops.aten.sym_numel (#98229)
The current logic assumes non-overload ops takes two arguments however torch.ops.aten.sym_numel takes one. Differential Revision: [D44615037](https://our.internmc.facebook.com/intern/diff/D44615037/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/98229 Approved by: https://github.com/mrshenli
This commit is contained in:
parent
a6bd21d935
commit
4d13fcddef
|
|
@ -492,11 +492,14 @@ def _convert_to_distributed(
|
|||
return arg
|
||||
|
||||
args = tree_map(_remap_arg, node.args)
|
||||
assert (
|
||||
len(args) >= 2
|
||||
), f"Expected number of args for call function to be at least 2, found {len(args)}"
|
||||
# TODO(anj): Why do we assume this is only 2?
|
||||
node_to_obj[node] = node.target(args[0], args[1])
|
||||
if node.target == torch.ops.aten.sym_numel:
|
||||
node_to_obj[node] = args[0].numel()
|
||||
else:
|
||||
assert (
|
||||
len(args) >= 2
|
||||
), f"Expected number of args for call function to be at least 2, found {len(args)} {node}"
|
||||
# TODO(anj): Why do we assume this is only 2?
|
||||
node_to_obj[node] = node.target(args[0], args[1])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized node.op type {node.op}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user