Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00.
When `torch.masked.mean(...)` is called with a boolean tensor, the dtype is inferred to be `bool`. The mean is computed with the sum operator. When the sum runs with `dtype=torch.bool`, the result saturates at `True` (1), so the computed mean is wrong. The example below shows how the incorrect result arises:

```python
import torch

a = torch.tensor([True, True])
count = torch.sum(torch.ones(a.shape, dtype=torch.int64))  # 2
total = torch.sum(a, dtype=torch.bool)                      # True (1)
mean = total / count                                        # 0.5, but the correct mean is 1.0
```

This PR upcasts the dtype used for the summation to `int32` for bool tensors, so the correct result is computed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139999
Approved by: https://github.com/cpuhrsch
| File |
|---|
| `__init__.py` |
| `_masked.py` |
| `fft.py` |
| `linalg.py` |
| `nested.py` |
| `signal.py` |
| `sparse.py` |
| `special.py` |
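The accumulation bug described in the commit message can be modeled without PyTorch. The sketch below is a dependency-free illustration, not PyTorch's implementation. The helper names `bool_sum_saturating`, `int_sum`, and `masked_mean` are hypothetical. It assumes a bool accumulator behaves like a logical OR, where any `True` saturates the result. It compares this against upcasting each element to an integer before adding, which is the idea behind the `int32` upcast in this PR.

```python
from typing import Sequence


def bool_sum_saturating(values: Sequence[bool]) -> bool:
    """Hypothetical model of summing into a bool accumulator.

    Adding True to True cannot exceed True, so the running total
    behaves like a logical OR over the inputs.
    """
    acc = False
    for v in values:
        acc = acc or v
    return acc


def int_sum(values: Sequence[bool]) -> int:
    """Hypothetical model of upcasting each bool to an integer before adding."""
    return sum(int(v) for v in values)


def masked_mean(values: Sequence[bool], mask: Sequence[bool], upcast: bool = True) -> float:
    """Mean over the elements where mask is True.

    With upcast=False the total is accumulated as a bool (the buggy path);
    with upcast=True it is accumulated as an integer (the fixed path).
    """
    selected = [v for v, m in zip(values, mask) if m]
    count = len(selected)
    if count == 0:
        raise ValueError("mask selects no elements")
    total = int_sum(selected) if upcast else int(bool_sum_saturating(selected))
    return total / count


if __name__ == "__main__":
    a = [True, True]
    mask = [True, True]
    print(masked_mean(a, mask, upcast=False))  # 0.5, matching the buggy result above
    print(masked_mean(a, mask, upcast=True))   # 1.0, the correct mean
```

Upcasting the accumulator, rather than post-processing the saturated total, is what makes the fix correct: once the bool accumulator has saturated, the information about how many elements were `True` is already lost.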