mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Fix torch.bernoulli decomposition return type (#115699)
Strangely enough, `torch.bernoulli` doesn't return a boolean and instead it matches the output type of the inplace bernoulli. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115699 Approved by: https://github.com/lezcano ghstack dependencies: #115677
This commit is contained in:
parent
0e0dd8f985
commit
9cdc80d581
|
|
@ -203,7 +203,6 @@ inductor_expected_failures_single_sample["cpu"] = {
|
||||||
f16
|
f16
|
||||||
}, # half_to_float is only valid for the CUDA implementation
|
}, # half_to_float is only valid for the CUDA implementation
|
||||||
"_upsample_bilinear2d_aa": {f32, f64},
|
"_upsample_bilinear2d_aa": {f32, f64},
|
||||||
"bernoulli": {f16, f32, f64},
|
|
||||||
"cholesky": {f32, f64},
|
"cholesky": {f32, f64},
|
||||||
"complex": {f16},
|
"complex": {f16},
|
||||||
"cross": {f16},
|
"cross": {f16},
|
||||||
|
|
@ -229,7 +228,6 @@ inductor_expected_failures_single_sample["cpu"] = {
|
||||||
inductor_expected_failures_single_sample["cuda"] = {
|
inductor_expected_failures_single_sample["cuda"] = {
|
||||||
"_upsample_bilinear2d_aa": {f16, f32, f64},
|
"_upsample_bilinear2d_aa": {f16, f32, f64},
|
||||||
"atanh": {f32},
|
"atanh": {f32},
|
||||||
"bernoulli": {f16, f32, f64},
|
|
||||||
"cholesky": {f32, f64},
|
"cholesky": {f32, f64},
|
||||||
"multinomial": {f16, f32, f64},
|
"multinomial": {f16, f32, f64},
|
||||||
"nn.functional.normalize": {f16},
|
"nn.functional.normalize": {f16},
|
||||||
|
|
|
||||||
|
|
@ -297,7 +297,7 @@ def lift(self):
|
||||||
@register_decomposition([aten.bernoulli.default])
|
@register_decomposition([aten.bernoulli.default])
|
||||||
def bernoulli(self, *, generator=None):
|
def bernoulli(self, *, generator=None):
|
||||||
assert generator is None
|
assert generator is None
|
||||||
return torch.rand_like(self, dtype=torch.float32) < self
|
return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition([aten.fmin, prims.fmin])
|
@register_decomposition([aten.fmin, prims.fmin])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user