Commit Graph

4 Commits

Author SHA1 Message Date
Peter Bell
cd9e158007 Accept non-standard bools in more CUDA kernels
This fixes all remaining CUDA kernels, except those using `cub` or
`thrust`, to accept boolean tensors with values other than 1 or 0.

I do this by using `c10::load` in more places, and also adding a
`load_vector` helper into `MemoryAccess.cuh` that does the same thing
for vectorized loads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78957

Approved by: https://github.com/mruberry
2022-06-09 08:31:28 +00:00
Peter Bell
c936396af2 Always convert truthy booleans to 1
Ref #54789

A `bool` has only two valid values, 1 or 0. Any in-memory value
outside of those leads to undefined behavior. So, instead of
`reinterpret_cast`-ing to `bool*` I introduce `c10::load<scalar_t>`
which will read as `unsigned char` and convert to a valid `bool`.

This gets >90% of operators working, but the remaining operators where
skips and xfails have been added will require individual attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77122

Approved by: https://github.com/mruberry
2022-06-07 16:00:30 +00:00
PyTorch MergeBot
78824a7d54 Revert "Always convert truthy booleans to 1"
This reverts commit 3c3c6cd982.

Reverted https://github.com/pytorch/pytorch/pull/77122 on behalf of https://github.com/mruberry due to broke some jobs, like https://github.com/pytorch/pytorch/runs/6706333043?check_suite_focus=true
2022-06-02 13:45:54 +00:00
Peter Bell
3c3c6cd982 Always convert truthy booleans to 1
Ref #54789

A `bool` has only two valid values, 1 or 0. Any in-memory value
outside of those leads to undefined behavior. So, instead of
`reinterpret_cast`-ing to `bool*` I introduce `c10::load<scalar_t>`
which will read as `unsigned char` and convert to a valid `bool`.

This gets >90% of operators working, but the remaining operators where
skips and xfails have been added will require individual attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77122

Approved by: https://github.com/mruberry
2022-06-02 04:18:34 +00:00