Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69871 Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D33515232 Pulled By: eellison fbshipit-source-id: d48da7b398a3f1a8862789484a4035d874196763 (cherry picked from commit e5976b8b7a4995be25a93601bbae5c52d6d3fca8)
19 lines, 624 B, Python
import torch
|
|
from typing import Union
|
|
|
|
class _InsertPoint(object):
|
|
def __init__(self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]):
|
|
self.insert_point = insert_point
|
|
self.g = insert_point_graph
|
|
self.guard = None
|
|
|
|
def __enter__(self):
|
|
self.prev_insert_point = self.g.insertPoint()
|
|
self.g.setInsertPoint(self.insert_point)
|
|
|
|
def __exit__(self, *args):
|
|
self.g.setInsertPoint(self.prev_insert_point)
|
|
|
|
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    """Build a context manager that moves ``self``'s insertion point.

    ``self`` is the graph whose insertion point is temporarily set to
    ``insert_point`` for the duration of a ``with`` block and restored
    afterwards.

    NOTE(review): the ``self`` parameter name suggests this module-level
    function is meant to be attached to the graph type as a method.
    Confirm where that binding happens.
    """
    guard = _InsertPoint(insert_point_graph=self, insert_point=insert_point)
    return guard