`torch.autocast` with the `xla` backend has so far been restricted to `torch.bfloat16`. This restriction is no longer necessary: casting to `float16` already works via `xla::cast(..., type=f16)`, as the IR below shows.
```
IR {
%0 = f32[] prim::Constant(), xla_shape=f32[], value=1
%1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
%2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
%3 = f32[] prim::Constant(), xla_shape=f32[], value=1
%4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
%5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
%6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```
This change allows PyTorch/XLA to extend its autocast implementation to use the `xla` backend for the `float16` type as well.
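
For context, a minimal usage sketch, assuming a PyTorch/XLA environment that provides `torch_xla` and an attached XLA device; with the dtype restriction lifted, an autocast region on the `xla` device type can request `torch.float16` directly (the actual op coverage depends on PyTorch/XLA's autocast rules):

```
import torch
import torch_xla.core.xla_model as xm  # assumes a torch_xla installation

device = xm.xla_device()
a = torch.ones(3, 2, device=device)
b = torch.ones(2, 3, device=device)

# Previously, dtype=torch.float16 with device_type="xla" was rejected;
# after this change the cast is emitted as xla::cast(..., type=f16),
# matching the IR dump above.
with torch.autocast(device_type="xla", dtype=torch.float16):
    out = b.mm(a)  # matmul expected to run in float16 under autocast

print(out.dtype)  # expected: torch.float16
```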
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh