mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[reland] switch mean to use reduction linear (#97996)
mean is actually a reduction linear formula if the final reduction is partial sum (which currently is), so switching to use that instead Pull Request resolved: https://github.com/pytorch/pytorch/pull/97996 Approved by: https://github.com/XilunWu, https://github.com/yifuwang
This commit is contained in:
parent
d9e5ab4606
commit
7fcff01b50
|
|
@ -3,6 +3,7 @@ from typing import cast, Optional, Sequence
|
|||
|
||||
import torch
|
||||
|
||||
import torch.distributed.distributed_c10d as c10d
|
||||
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
|
||||
from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
|
||||
from torch.distributed._tensor.ops.utils import (
|
||||
|
|
@ -10,7 +11,7 @@ from torch.distributed._tensor.ops.utils import (
|
|||
normalize_dims,
|
||||
register_prop_rule,
|
||||
)
|
||||
from torch.distributed._tensor.placement_types import DTensorSpec
|
||||
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
|
@ -87,9 +88,17 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
|
|||
dims = _infer_reduction_dims(args_schema[1], input_spec.ndim)
|
||||
|
||||
keep_dim = len(args_schema) > 2 and bool(args_schema[2])
|
||||
return reduction_rule(
|
||||
op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False
|
||||
output_sharding = reduction_rule(
|
||||
op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=True
|
||||
)
|
||||
if output_sharding.output_spec is not None:
|
||||
assert isinstance(output_sharding.output_spec, DTensorSpec)
|
||||
for placement in output_sharding.output_spec.placements:
|
||||
if placement.is_partial():
|
||||
partial_placement = cast(_Partial, placement)
|
||||
partial_placement.reduce_op = c10d.ReduceOp.AVG # type: ignore[attr-defined]
|
||||
|
||||
return output_sharding
|
||||
|
||||
|
||||
@register_prop_rule(
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ def bitwise_reduce(tensors, op):
|
|||
|
||||
_reduce_ops = {
|
||||
ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
|
||||
ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
|
||||
ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
|
||||
ReduceOp.MIN: partial(binop_reduce, op=torch.min),
|
||||
ReduceOp.MAX: partial(binop_reduce, op=torch.max),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user