mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844 Approved by: https://github.com/oulgen ghstack dependencies: #127842, #127843
27 lines
677 B
Python
27 lines
677 B
Python
# mypy: allow-untyped-defs
|
|
from typing import Union
|
|
|
|
import torch
|
|
|
|
|
|
class _InsertPoint:
|
|
def __init__(
|
|
self,
|
|
insert_point_graph: torch._C.Graph,
|
|
insert_point: Union[torch._C.Node, torch._C.Block],
|
|
):
|
|
self.insert_point = insert_point
|
|
self.g = insert_point_graph
|
|
self.guard = None
|
|
|
|
def __enter__(self):
|
|
self.prev_insert_point = self.g.insertPoint()
|
|
self.g.setInsertPoint(self.insert_point)
|
|
|
|
def __exit__(self, *args):
|
|
self.g.setInsertPoint(self.prev_insert_point)
|
|
|
|
|
|
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    """Build an ``_InsertPoint`` context manager for the graph ``self``.

    The first parameter is named ``self`` so this function can be attached
    as a method on a graph type. NOTE(review): the attachment is not visible
    here; presumably it is bound onto ``torch._C.Graph`` elsewhere. Confirm.

    Args:
        insert_point: The Node or Block to use as the graph's insertion point
            while the returned context manager is active.

    Returns:
        A new ``_InsertPoint`` wrapping ``self`` and ``insert_point``.
    """
    guard = _InsertPoint(self, insert_point)
    return guard