[DTensor] Fix torch.all() using incorrect reduction operator (#165924)

Fixes #165923
Corrects the reduction operation to be product.

Enables "all" in the boolean tensor tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165924
Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
Sean McGovern 2025-10-29 20:58:31 +00:00 committed by PyTorch MergeBot
parent b33762bd2f
commit 56a809aa07
2 changed files with 5 additions and 5 deletions

View File

@ -60,9 +60,9 @@ class DistMathOpsTest(DTensorTestBase):
shard_spec = [Shard(0)]
tensor = torch.randn(12, 8, 8)
# TODO: check `all` correctness and test `all` on a bool tensor
if op_str in ("any"):
# test out a bool tensor for any
if op_str in ("any", "all"):
# Test bool tensor for any() and all() reduction ops
# Previously all() had a bug using sum reduction instead of product
tensor = tensor < 0
dtensor = distribute_tensor(tensor, device_mesh, shard_spec)

View File

@ -335,8 +335,8 @@ def common_reduction_strategy(
LINEAR_REDUCTION_OP_MAP = {
aten.all.default: "sum",
aten.all.dim: "sum",
aten.all.default: "product",
aten.all.dim: "product",
aten.sum.default: "sum",
aten.sum.dim_IntList: "sum",
aten.any.default: "sum",