[spmd] quick fix on batch input view issue (#98813)
This is a quick fix/hack to work around the issue that some view operations on a "global" tensor are invalid, yet they still get triggered by some models, even though the same view on the mini-batch input itself would be fine. Since we ultimately plan to remove the DTensor expand and use the new expansion, this hack is only a temporary unblock.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
This commit is contained in:
parent 760967a284
commit 15686950b7
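To illustrate the distinction the fix relies on, here is a minimal sketch with plain tensors rather than DTensor: view() can only reinterpret the existing strides and raises on layouts it cannot express, while reshape() falls back to materializing a copy.

import torch

# view() must reinterpret existing strides; on a non-contiguous
# tensor it raises, while reshape() falls back to copying the data.
x = torch.randn(3, 4)
y = x.t()  # shape (4, 3), non-contiguous after the transpose
try:
    y.view(12)
except RuntimeError as err:
    print("view failed:", err)
z = y.reshape(12)  # succeeds by making a contiguous copy
print(z.shape)     # torch.Size([12])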
@@ -169,6 +169,13 @@ def _get_dtensor_dispatch_graph(
     op_overload = cast(torch._ops.OpOverload, node.target)
 
+    if node.target == torch.ops.aten.view.default:
+        # HACK: some view operations on a "global" tensor are invalid
+        # usage, but the same view op on the batch input can still hit
+        # this path, so we convert the view op to reshape before
+        # calling DTensor
+        op_overload = torch.ops.aten.reshape.default
+
     # run dispatch once to get the real DTensor output.
     out, op_schema, output_sharding = _operator_dispatch(
         op_overload,
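A note on why the one-line substitution above is well-formed (an observation, not something stated in the diff): aten.view.default and aten.reshape.default share the (self, size) calling convention, so the rest of the dispatch path can pass the same arguments unchanged.

import torch

# Both overloads accept (self, size), so swapping the OpOverload in
# place leaves the downstream call site untouched.
x = torch.randn(2, 6)
a = torch.ops.aten.view.default(x, [3, 4])
b = torch.ops.aten.reshape.default(x, [3, 4])
assert torch.equal(a, b)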
@@ -667,6 +667,7 @@ def register_prop_rule_map(
 register_prop_rule_map(aten.squeeze.default, torch.squeeze)
 register_prop_rule_map(aten.squeeze.dim, torch.squeeze)
 register_prop_rule_map(aten.view.default, Tensor.view)
+register_prop_rule_map(aten.reshape.default, torch.reshape)
 register_prop_rule_map(aten._unsafe_view.default, Tensor.view)
 register_prop_rule_map(aten.unsqueeze.default, torch.unsqueeze)
 register_prop_rule_map(aten.expand.default, Tensor.expand)
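For readers outside this file, a hypothetical sketch of the registration pattern in the hunk above; the registry dict and function body below are illustrative stand-ins, not the actual SPMD implementation. Each aten overload is paired with the eager callable whose reshaping semantics it mirrors, so a sharding propagation rule can be derived from the eager op.

from typing import Callable, Dict

import torch
from torch import Tensor

aten = torch.ops.aten

# Hypothetical registry: maps an aten overload to the eager callable
# used to derive its sharding propagation rule (illustrative only).
_prop_rules: Dict[torch._ops.OpOverload, Callable] = {}

def register_prop_rule_map(op_overload: torch._ops.OpOverload,
                           eager_fn: Callable) -> None:
    # Pair the aten overload with its eager counterpart; the real code
    # builds a propagation rule from this pairing.
    _prop_rules[op_overload] = eager_fn

register_prop_rule_map(aten.reshape.default, torch.reshape)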