mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Ports over the step closure functionality from PyTorch/XLA to Lazy Tensor Core. References: 205ae574c0/torch_xla/core/xla_model.py (L852-L900) and 205ae574c0/torch_xla/utils/closures.py (L7-L83). CC: @wconstab @JackCaoG @Krovatkin. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84300. Approved by: https://github.com/JackCaoG, https://github.com/wconstab
28 lines
968 B
Python
28 lines
968 B
Python
from typing import List
|
|
from torch import Tensor
|
|
|
|
# defined in torch/csrc/lazy/python/init.cpp
|
|
# Lazy-tensor step binding (implemented in torch/csrc/lazy/python/init.cpp).
# Returns nothing; annotated explicitly so the stub does not leak implicit Any.
# NOTE(review): `wait` presumably makes the call block until the step's
# device work has finished — confirm against init.cpp.
def _mark_step(device: str, devices: List[str], wait: bool) -> None: ...
|
|
# Wait on outstanding operations for the given `devices` (presumably blocking —
# confirm in torch/csrc/lazy/python/init.cpp). Returns nothing.
def _wait_device_ops(devices: List[str]) -> None: ...
|
|
# Reset the lazy-tensor metrics (NOTE(review): presumably counters too —
# confirm in torch/csrc/lazy/python/init.cpp). Returns nothing.
def _reset_metrics() -> None: ...
|
|
# Names of the lazy-tensor counters, as a list of strings.
# NOTE(review): presumably each name is a valid argument to `_counter_value`.
def _counter_names() -> List[str]:
    ...
|
|
# Integer value of the counter identified by `name`.
# NOTE(review): names presumably come from `_counter_names`.
def _counter_value(name: str) -> int:
    ...
|
|
# Lazy-tensor metrics report, rendered as a single string.
def _metrics_report() -> str:
    ...
|
|
# Hash, as a string, of the graph behind `tensors`.
# NOTE(review): presumably the lazy IR graph producing them — confirm in init.cpp.
def _get_graph_hash(tensors: List[Tensor]) -> str:
    ...
|
|
# Multi-tensor sync binding (implemented in torch/csrc/lazy/python/init.cpp).
# `wait` and `sync_ltc_data` both default to True. Returns nothing.
# NOTE(review): presumably syncs `tensors` on `devices`, with `wait` blocking
# until completion — confirm against init.cpp.
def _sync_multi(
    tensors: List[Tensor],
    devices: List[str],
    wait: bool = True,
    sync_ltc_data: bool = True,
) -> None: ...
|
|
# Integer identifier for `tensor`.
# NOTE(review): uniqueness/lifetime guarantees are not visible here — confirm in init.cpp.
def _get_tensor_id(tensor: Tensor) -> int:
    ...
|
|
# Text rendering of the computation behind `tensors`.
# NOTE(review): presumably a textual IR dump — confirm in init.cpp.
def _get_tensors_text(tensors: List[Tensor]) -> str:
    ...
|
|
# String rendering of the computation behind `tensors`.
# NOTE(review): judging by the name, Graphviz DOT format — confirm in init.cpp.
def _get_tensors_dot(tensors: List[Tensor]) -> str:
    ...
|
|
# Backend-level string rendering of the computation behind `tensors`.
# NOTE(review): exact format is backend-defined and not visible here — confirm in init.cpp.
def _get_tensors_backend(tensors: List[Tensor]) -> str:
    ...
|
|
# Current force-fallback setting, as a string; counterpart of `_set_force_fallback`.
def _get_force_fallback() -> str:
    ...
|
|
# Set the force-fallback value (string); counterpart of `_get_force_fallback`.
# Returns nothing.
# NOTE(review): accepted values/format are not visible here — confirm in init.cpp.
def _set_force_fallback(newval: str) -> None: ...
|
|
# Clear the lazy IR cache. Returns nothing.
def _clear_ir_cache() -> None: ...
|
|
# Dump the lazy IR cache to the file path `filename`. Returns nothing.
# NOTE(review): output format is not visible here — confirm in init.cpp.
def _dump_ir_cache(filename: str) -> None: ...
|
|
# Enable (`val=True`) or disable IR reuse. Returns nothing.
# NOTE(review): presumably toggles a process-wide flag — confirm in init.cpp.
def _set_reuse_ir(val: bool) -> None: ...
|
|
# Name of the default lazy device type, as a string.
# Previously unannotated (implicit Any).
# NOTE(review): `str` matches the init.cpp binding returning std::string — confirm.
def _get_default_device_type() -> str: ...
|