[spmd] quick fix on batch input view issue (#98813)

This is a quick fix/hack to get around with the issue that some
"global" tensor view operation is invalid, but somehow it get
triggered by some models as mini-batch input itself won't have this
issue.

Since ultimately we should remove the dtensor expand and use the new
expansion, this hack is only temporary to unblock
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
This commit is contained in:
Wanchao Liang 2023-04-11 04:22:30 +00:00 committed by PyTorch MergeBot
parent 760967a284
commit 15686950b7
2 changed files with 8 additions and 0 deletions

View File

@ -169,6 +169,13 @@ def _get_dtensor_dispatch_graph(
op_overload = cast(torch._ops.OpOverload, node.target)
if node.target == torch.ops.aten.view.default:
# HACK: this is a hack to get around with the fact that some
# view operations on a "global" tensor is invalid usage
# but somehow the view operation on the batch input might hit it
# so we convert the view op to reshape before calling DTensor
op_overload = torch.ops.aten.reshape.default
# run dispatch once to get the real DTensor output.
out, op_schema, output_sharding = _operator_dispatch(
op_overload,

View File

@ -667,6 +667,7 @@ def register_prop_rule_map(
register_prop_rule_map(aten.squeeze.default, torch.squeeze)
register_prop_rule_map(aten.squeeze.dim, torch.squeeze)
register_prop_rule_map(aten.view.default, Tensor.view)
register_prop_rule_map(aten.reshape.default, torch.reshape)
register_prop_rule_map(aten._unsafe_view.default, Tensor.view)
register_prop_rule_map(aten.unsqueeze.default, torch.unsqueeze)
register_prop_rule_map(aten.expand.default, Tensor.expand)