[reland] switch mean to use reduction linear (#97996)

mean is actually a reduction-linear formula if the final reduction
is a partial sum (which it currently is), so switch to use that instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97996
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
This commit is contained in:
Wanchao Liang 2023-04-01 18:46:06 +00:00 committed by PyTorch MergeBot
parent d9e5ab4606
commit 7fcff01b50
2 changed files with 13 additions and 3 deletions

View File

@ -3,6 +3,7 @@ from typing import cast, Optional, Sequence
import torch
import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
from torch.distributed._tensor.ops.utils import (
@ -10,7 +11,7 @@ from torch.distributed._tensor.ops.utils import (
normalize_dims,
register_prop_rule,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial
aten = torch.ops.aten
@ -87,9 +88,17 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)
keep_dim = len(args_schema) > 2 and bool(args_schema[2])
return reduction_rule(
op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False
output_sharding = reduction_rule(
op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=True
)
if output_sharding.output_spec is not None:
assert isinstance(output_sharding.output_spec, DTensorSpec)
for placement in output_sharding.output_spec.placements:
if placement.is_partial():
partial_placement = cast(_Partial, placement)
partial_placement.reduce_op = c10d.ReduceOp.AVG # type: ignore[attr-defined]
return output_sharding
@register_prop_rule(

View File

@ -52,6 +52,7 @@ def bitwise_reduce(tensors, op):
_reduce_ops = {
ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
ReduceOp.MIN: partial(binop_reduce, op=torch.min),
ReduceOp.MAX: partial(binop_reduce, op=torch.max),