mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Apply parts of pyupgrade to torch (starting with the safest changes). This PR only does two things: removes the need to inherit from object and removes unused future imports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308 Approved by: https://github.com/ezyang, https://github.com/albanD
19 lines
616 B
Python
19 lines
616 B
Python
import torch
|
|
from typing import Union
|
|
|
|
class _InsertPoint:
|
|
def __init__(self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]):
|
|
self.insert_point = insert_point
|
|
self.g = insert_point_graph
|
|
self.guard = None
|
|
|
|
def __enter__(self):
|
|
self.prev_insert_point = self.g.insertPoint()
|
|
self.g.setInsertPoint(self.insert_point)
|
|
|
|
def __exit__(self, *args):
|
|
self.g.setInsertPoint(self.prev_insert_point)
|
|
|
|
def insert_point_guard(
    self, insert_point: Union[torch._C.Node, torch._C.Block]
) -> "_InsertPoint":
    """Build a context manager that temporarily sets ``self``'s insert point.

    This is a module-level function whose first parameter is named ``self``.
    Presumably it is attached to ``torch._C.Graph`` as a method elsewhere
    (not visible here), making ``self`` the graph — TODO confirm.

    Args:
        self: graph whose insert point is switched inside the ``with`` block.
        insert_point: node or block to use as the insert point.

    Returns:
        An ``_InsertPoint`` wrapping ``self`` and ``insert_point``.
    """
    guard = _InsertPoint(insert_point_graph=self, insert_point=insert_point)
    return guard
|