mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[DTensor] Fix torch.all() using incorrect reduction operator (#165924)
Fixes #165923 Corrects the reduction operation to be product. Enables "all" in the boolean tensor tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165924 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
parent
b33762bd2f
commit
56a809aa07
|
|
@ -60,9 +60,9 @@ class DistMathOpsTest(DTensorTestBase):
|
|||
shard_spec = [Shard(0)]
|
||||
|
||||
tensor = torch.randn(12, 8, 8)
|
||||
# TODO: check `all` correctness and test `all` on a bool tensor
|
||||
if op_str in ("any"):
|
||||
# test out a bool tensor for any
|
||||
if op_str in ("any", "all"):
|
||||
# Test bool tensor for any() and all() reduction ops
|
||||
# Previously all() had a bug using sum reduction instead of product
|
||||
tensor = tensor < 0
|
||||
dtensor = distribute_tensor(tensor, device_mesh, shard_spec)
|
||||
|
||||
|
|
|
|||
|
|
@ -335,8 +335,8 @@ def common_reduction_strategy(
|
|||
|
||||
|
||||
LINEAR_REDUCTION_OP_MAP = {
|
||||
aten.all.default: "sum",
|
||||
aten.all.dim: "sum",
|
||||
aten.all.default: "product",
|
||||
aten.all.dim: "product",
|
||||
aten.sum.default: "sum",
|
||||
aten.sum.dim_IntList: "sum",
|
||||
aten.any.default: "sum",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user