pytorch/torch/testing/_internal/opinfo
George Wigley 0742b2366e Update torch.masked.mean to upcast dtype for bool tensors (#139999)
When `torch.masked.mean(...)` is called with a boolean tensor, the dtype is inferred to be bool. The mean is computed with the sum operator, and when sum is called with dtype=torch.bool the result is clamped to True (1), so the calculated mean is incorrect.

The example below shows how the incorrect result occurs:
```python
import torch

a = torch.tensor([True, True])
count = torch.sum(torch.ones(a.shape, dtype=torch.int64))  # 2
total = torch.sum(a, dtype=torch.bool)  # True (1): the bool sum saturates
mean = total / count  # 0.5, but the correct mean is 1.0
```

This PR upcasts the dtype used for the summation to int32 for bool tensors, so the correct result is computed.
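The upcasting idea can be sketched outside the library. The helper `masked_mean_sketch` below is hypothetical and is not the actual `torch.masked.mean` implementation. It only mirrors the PR's approach of summing bool inputs in int32 instead of bool, so that each True counts as 1 rather than the total saturating at True.

```python
import torch


def masked_mean_sketch(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical masked mean; illustrates the bool -> int32 upcast only."""
    # For bool inputs, sum in int32 so Trues are counted instead of clamped to True.
    sum_dtype = torch.int32 if x.dtype == torch.bool else x.dtype
    # Zero out masked-off elements (False for bool inputs) before summing.
    masked = torch.where(mask, x, torch.zeros_like(x))
    total = torch.sum(masked, dtype=sum_dtype)
    # Number of elements that participate in the mean.
    count = torch.sum(mask, dtype=torch.int64)
    # Integer / integer tensor division is true division and yields a float tensor.
    return total / count


a = torch.tensor([True, True])
print(masked_mean_sketch(a, torch.ones_like(a)))  # tensor(1.)
```

With the upcast, the example from the commit message yields 1.0 instead of 0.5. Non-bool inputs keep their own dtype for the summation in this sketch.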
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139999
Approved by: https://github.com/cpuhrsch
2025-01-07 00:26:59 +00:00
definitions Update torch.masked.mean to upcast dtype for bool tensors (#139999) 2025-01-07 00:26:59 +00:00
__init__.py
core.py Update core.py to fix typo (#144201) 2025-01-05 18:20:52 +00:00
refs.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
utils.py Remove unused Python variables in torch/[b-z]* (#136963) 2024-10-19 16:45:22 +00:00