mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DTensor] Fix torch.all() using incorrect reduction operator (#165924)
Fixes #165923 Corrects the reduction operation to be product. Enables "all" in the boolean tensor tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165924 Approved by: https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
parent
b33762bd2f
commit
56a809aa07
|
|
@ -60,9 +60,9 @@ class DistMathOpsTest(DTensorTestBase):
|
||||||
shard_spec = [Shard(0)]
|
shard_spec = [Shard(0)]
|
||||||
|
|
||||||
tensor = torch.randn(12, 8, 8)
|
tensor = torch.randn(12, 8, 8)
|
||||||
# TODO: check `all` correctness and test `all` on a bool tensor
|
if op_str in ("any", "all"):
|
||||||
if op_str in ("any"):
|
# Test bool tensor for any() and all() reduction ops
|
||||||
# test out a bool tensor for any
|
# Previously all() had a bug using sum reduction instead of product
|
||||||
tensor = tensor < 0
|
tensor = tensor < 0
|
||||||
dtensor = distribute_tensor(tensor, device_mesh, shard_spec)
|
dtensor = distribute_tensor(tensor, device_mesh, shard_spec)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -335,8 +335,8 @@ def common_reduction_strategy(
|
||||||
|
|
||||||
|
|
||||||
LINEAR_REDUCTION_OP_MAP = {
|
LINEAR_REDUCTION_OP_MAP = {
|
||||||
aten.all.default: "sum",
|
aten.all.default: "product",
|
||||||
aten.all.dim: "sum",
|
aten.all.dim: "product",
|
||||||
aten.sum.default: "sum",
|
aten.sum.default: "sum",
|
||||||
aten.sum.dim_IntList: "sum",
|
aten.sum.dim_IntList: "sum",
|
||||||
aten.any.default: "sum",
|
aten.any.default: "sum",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user