mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fixed compilations on xla tensor print (#71147)
Summary:
Fixes multiple compilation on xla tensor print. Please check the conversation here: https://github.com/pytorch/xla/pull/3253
This is done to avoid compilations during tensor printing. Torch performs some tensor operations like slicing to make the tensor readable. These operations result in compilations. Hence to avoid the compilations, copying the tensor to cpu before printing.
example:
```
dev = xm.xla_device()
def test_linear(input_shape=(8, 1024)):
import pdb
pdb.set_trace()
linear = torch.nn.Linear(in_features=1024, out_features=4096, bias=True).to(dev)
inp = torch.randn(*input_shape).to(dev)
output = linear(inp)
xm.mark_step()
return output
```
Returning from this function would have resulted in 63 compiles, since PDB prints the value of the return output. In this case it is a xla tensor.
Now with the current change, there is no compilation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71147
Reviewed By: shunting314
Differential Revision: D33795177
Pulled By: wconstab
fbshipit-source-id: 74b53d9a1cb7ef67f9d8b0a32064f3896be449b5
(cherry picked from commit a9e0687fc5)
This commit is contained in:
parent
76a2c22341
commit
027c0d7f8e
|
|
@ -318,6 +318,12 @@ def _str_intern(inp):
|
|||
or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index):
|
||||
suffixes.append('device=\'' + str(self.device) + '\'')
|
||||
|
||||
# Tensor printing performs tensor operations like slice, indexing, etc to make it in a
|
||||
# representable format. These operations on xla/lazy tensor results in compilations. Hence,
|
||||
# to avoid compilations, copying the tensor to cpu before printing.
|
||||
if self.device.type == 'xla' or self.device.type == 'lazy':
|
||||
self = self.to('cpu')
|
||||
|
||||
# TODO: add an API to map real -> complex dtypes
|
||||
_default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
|
||||
has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user