pytorch/torch/_dynamo/backends/torchxla.py
Xuehai Pan e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00

48 lines
1.2 KiB
Python

# mypy: ignore-errors
import logging
from functorch.compile import make_boxed_func
from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
    """Experimental backend entry point that returns the unboxed XLA callable.

    Registered through ``register_experimental_backend``. All work is
    delegated to ``xla_backend_helper`` with ``boxed=False``.

    Args:
        model: The object to compile; forwarded untouched to the helper.
        fake_tensor_inputs: Forwarded untouched to the helper.

    Returns:
        Whatever ``xla_backend_helper`` returns for ``boxed=False``.

    Raises:
        ImportError: Propagated from the helper when ``torch_xla`` cannot
            be imported.
    """
    unboxed_callable = xla_backend_helper(model, fake_tensor_inputs, boxed=False)
    return unboxed_callable
def openxla_eval_boxed(model, fake_tensor_inputs):
    """Return the boxed variant of the XLA callable for ``model``.

    Same delegation as ``openxla_eval`` but with ``boxed=True``, so the
    helper wraps its result with ``make_boxed_func``. This function is the
    ``fw_compiler`` handed to ``aot_autograd`` further down in this module.

    Args:
        model: The object to compile; forwarded untouched to the helper.
        fake_tensor_inputs: Forwarded untouched to the helper.

    Returns:
        Whatever ``xla_backend_helper`` returns for ``boxed=True``.

    Raises:
        ImportError: Propagated from the helper when ``torch_xla`` cannot
            be imported.
    """
    boxed_callable = xla_backend_helper(model, fake_tensor_inputs, boxed=True)
    return boxed_callable
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
    """Build a callable that compiles ``model`` with torch_xla on first use.

    ``torch_xla`` is imported here, at call time, rather than at module
    import, so this module can be loaded without it installed.

    Args:
        model: The object handed to ``bridge.extract_compiled_graph`` on the
            first invocation of the returned callable.
        fake_tensor_inputs: Accepted for signature compatibility; this
            function does not read it.
        boxed: When true, the returned callable is wrapped with
            ``make_boxed_func``; otherwise the plain closure is returned.

    Returns:
        A callable taking ``*args``. Its first call compiles the graph with
        those arguments and runs it; later calls reuse the compiled graph.

    Raises:
        ImportError: If ``torch_xla.core.dynamo_bridge`` cannot be imported;
            the original import failure is chained as the cause.
    """
    try:
        import torch_xla.core.dynamo_bridge as bridge
    except ImportError as err:
        raise ImportError(
            "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
        ) from err

    # Filled in lazily by the first call to ``fwd``.
    xla_graph = None

    def fwd(*args):
        nonlocal model, xla_graph
        if xla_graph is not None:
            return xla_graph(*args)
        # First call: compile against the real arguments, then drop this
        # closure's reference to ``model`` (presumably so it can be freed
        # once the compiled graph exists).
        xla_graph = bridge.extract_compiled_graph(model, args)
        del model
        return xla_graph(*args)

    if boxed:
        return make_boxed_func(fwd)
    return fwd
# Backend built by ``aot_autograd`` with the boxed XLA helper as its forward
# compiler.
openxla = aot_autograd(fw_compiler=openxla_eval_boxed)

# Register that backend under the name "openxla".
register_backend(name="openxla", compiler_fn=openxla)