Fix multiple compilations on XLA tensor print (#71147)

Summary:
Fixes multiple compilations on XLA tensor print. Please check the conversation here: https://github.com/pytorch/xla/pull/3253

Torch performs tensor operations such as slicing and indexing during printing to render the tensor in a readable format. On an XLA tensor these operations are traced and compiled, so printing alone triggers compilations. To avoid this, the tensor is copied to the CPU before printing.
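Conceptually, the workaround is just a CPU copy before formatting; a minimal sketch of the idea (assuming a working `torch_xla` install, with `xm` as the usual `torch_xla.core.xla_model` alias):

```
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(8, 1024, device=xm.xla_device())

# Formatting `t` directly would run slice/summarization ops on the XLA
# tensor, each of which can trigger a fresh compilation.
print(t.cpu())  # the formatting ops now run eagerly on the CPU copy
```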

Example:

```
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

def test_linear(input_shape=(8, 1024)):
    import pdb
    pdb.set_trace()
    linear = torch.nn.Linear(in_features=1024, out_features=4096, bias=True).to(dev)
    inp = torch.randn(*input_shape).to(dev)
    output = linear(inp)
    xm.mark_step()
    return output
```
Returning from this function would previously have resulted in 63 compilations, since PDB prints the value of the returned output, which in this case is an XLA tensor.

With this change, there is no compilation.
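One way to verify the behavior (a sketch using `torch_xla.debug.metrics`; the exact metric names and report format can vary across torch_xla versions):

```
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

dev = xm.xla_device()
out = torch.nn.Linear(1024, 4096).to(dev)(torch.randn(8, 1024).to(dev))
xm.mark_step()

report_before = met.metrics_report()
print(out)  # with the fix, printing copies to CPU first
report_after = met.metrics_report()
# Compare the 'CompileTime' entries in the two reports: the compile
# count should not grow just because the tensor was printed.
```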

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71147

Reviewed By: shunting314

Differential Revision: D33795177

Pulled By: wconstab

fbshipit-source-id: 74b53d9a1cb7ef67f9d8b0a32064f3896be449b5
(cherry picked from commit a9e0687fc5)

The change lands in `_str_intern` (PyTorch's tensor string formatting helper):

```diff
@@ -318,6 +318,12 @@ def _str_intern(inp):
             or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index):
         suffixes.append('device=\'' + str(self.device) + '\'')
 
+    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
+    # representable format. These operations on xla/lazy tensor results in compilations. Hence,
+    # to avoid compilations, copying the tensor to cpu before printing.
+    if self.device.type == 'xla' or self.device.type == 'lazy':
+        self = self.to('cpu')
+
     # TODO: add an API to map real -> complex dtypes
     _default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
     has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool)
```
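Note that `self.to('cpu')` still forces the lazy tensor to materialize (any pending graph is executed once), but the formatting operations that follow run eagerly on the CPU copy instead of being traced into fresh device graphs.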