pytorch/torch/_dynamo/backends/torchxla.py
JackCaoG e38d60bc07 Remove some stale xla dynamo backend (#122128)
`torchxla_trace_once` and `aot_torchxla_trivial` should be removed.

In our internal torchbench daily runs (hopefully the dashboard can be open-sourced soon), the `openxla` backend has a much higher passing rate and similar performance to the `openxla_eval` (non-AOTAutograd) backend. We still use `openxla_eval` in the llama2 example, but I think we should move users to the `openxla` backend going forward.
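
For reference, switching between the two is just the `backend=` argument to `torch.compile`; a minimal sketch (assumes `torch_xla` is installed and an XLA device is available):

```python
import torch
import torch_xla.core.xla_model as xm  # requires torch_xla

def fn(x):
    return torch.sin(x) + torch.cos(x)

# `openxla` routes the graph through AOTAutograd first, so it supports
# training; `openxla_eval` hands the graph to torch_xla directly and is
# meant for inference.
fn_openxla = torch.compile(fn, backend="openxla")
fn_eval = torch.compile(fn, backend="openxla_eval")

x = torch.randn(8, device=xm.xla_device())
print(fn_openxla(x))
```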

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122128
Approved by: https://github.com/alanwaketan, https://github.com/jansel
2024-03-21 01:13:50 +00:00

47 lines
1.2 KiB
Python

# mypy: ignore-errors

import logging

from functorch.compile import make_boxed_func

from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend

log = logging.getLogger(__name__)


@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
    return xla_backend_helper(model, fake_tensor_inputs, boxed=False)


def openxla_eval_boxed(model, fake_tensor_inputs):
    return xla_backend_helper(model, fake_tensor_inputs, boxed=True)


def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
    try:
        import torch_xla.core.dynamo_bridge as bridge
    except ImportError as e:
        raise ImportError(
            "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
        ) from e

    compiled_graph = None

    def fwd(*args):
        nonlocal model
        nonlocal compiled_graph
        if compiled_graph is None:
            # Compile lazily on the first call, then drop the reference to
            # the original model so it can be garbage-collected.
            compiled_graph = bridge.extract_compiled_graph(model, args)
            del model
        return compiled_graph(*args)

    return make_boxed_func(fwd) if boxed else fwd


# `openxla` wraps the boxed helper in AOTAutograd, so it also handles backward.
openxla = aot_autograd(
    fw_compiler=openxla_eval_boxed,
)
register_backend(name="openxla", compiler_fn=openxla)
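
A note on the boxed calling convention used above: `aot_autograd`'s `fw_compiler` hook is expected to return a callable that takes its inputs as a single list, which is what `make_boxed_func` arranges and why the `fw_compiler` path goes through `openxla_eval_boxed` rather than the unboxed `openxla_eval`. A minimal sketch of the convention with a hypothetical pass-through compiler (`passthrough_compiler` is illustrative, not part of this file):

```python
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

def passthrough_compiler(gm, example_inputs):
    # gm is the FX GraphModule that AOTAutograd hands to fw_compiler.
    def fwd(*args):
        return gm(*args)
    # make_boxed_func wraps fwd so it is invoked as fwd(list_of_args) and
    # tags it with _boxed_call = True, letting the AOT runtime release
    # inputs from the list as soon as they are consumed.
    return make_boxed_func(fwd)

# Same wiring as `openxla` above, minus the torch_xla bridge.
passthrough = aot_autograd(fw_compiler=passthrough_compiler)
```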