diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py
index 51a8186bac5..f031085b23b 100644
--- a/test/distributed/tensor/test_math_ops.py
+++ b/test/distributed/tensor/test_math_ops.py
@@ -60,9 +60,9 @@ class DistMathOpsTest(DTensorTestBase):
         shard_spec = [Shard(0)]

         tensor = torch.randn(12, 8, 8)
-        # TODO: check `all` correctness and test `all` on a bool tensor
-        if op_str in ("any"):
-            # test out a bool tensor for any
+        if op_str in ("any", "all"):
+            # Test bool tensor for any() and all() reduction ops
+            # Previously all() had a bug using sum reduction instead of product
             tensor = tensor < 0

         dtensor = distribute_tensor(tensor, device_mesh, shard_spec)
diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py
index 93030c7142b..45a786b9058 100644
--- a/torch/distributed/tensor/_ops/_math_ops.py
+++ b/torch/distributed/tensor/_ops/_math_ops.py
@@ -335,8 +335,8 @@ def common_reduction_strategy(


 LINEAR_REDUCTION_OP_MAP = {
-    aten.all.default: "sum",
-    aten.all.dim: "sum",
+    aten.all.default: "product",
+    aten.all.dim: "product",
     aten.sum.default: "sum",
     aten.sum.dim_IntList: "sum",
     aten.any.default: "sum",
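
For context, a minimal standalone sketch of why `all()` must be combined across shards with a product rather than a sum. It uses plain tensors, not DTensor; the two-shard split and variable names are hypothetical and only illustrate the reduction semantics: the product of the per-shard `all()` partials reproduces the logical AND over the full tensor, while a sum becomes truthy as soon as any single shard happens to be all-True.

```python
import torch

# Hypothetical illustration (not part of the patch): simulate a bool tensor
# sharded on dim 0 across two "ranks", where one shard is all-True and the
# other contains a False value.
shards = [
    torch.ones(2, 4, dtype=torch.bool),              # local all() -> True
    torch.tensor([[True, True, True, False]] * 2),   # local all() -> False
]
full = torch.cat(shards, dim=0)

# Per-shard partial results, as each rank would compute them locally.
partial_all = torch.stack([s.all() for s in shards])  # tensor([True, False])

# Old mapping reduced the all() partials with "sum": 1 + 0 = 1 -> truthy,
# even though the full tensor contains a False.
sum_reduced = bool(partial_all.sum())
# New mapping uses "product": 1 * 0 = 0 -> falsy, matching torch.all().
prod_reduced = bool(partial_all.long().prod())

assert not full.all()
assert prod_reduced == bool(full.all())   # product reduction is correct
assert sum_reduced != bool(full.all())    # sum reduction reproduces the bug
```

`any()` keeps its `"sum"` mapping because a logical OR across shards is exactly a nonzero sum of the per-shard `any()` partials.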