mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44679 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D23696766 Pulled By: jamesr66a fbshipit-source-id: fe18b7b579c1728d00589bd5fd5e54c917cc61fe
37 lines
1.3 KiB
Python
37 lines
1.3 KiB
Python
# Nodes represent a definition of a value in our graph of operators.
|
|
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict
|
|
import torch
|
|
|
|
if TYPE_CHECKING:
|
|
from .graph import Graph
|
|
|
|
|
|
# Leaf (non-container) values that may appear directly as a node argument.
BaseArgumentTypes = Union[str, int, float, bool, torch.dtype, torch.Tensor]
# The same leaf types as a plain tuple of classes
# (presumably for isinstance() checks elsewhere — confirm at call sites).
base_types = BaseArgumentTypes.__args__  # type: ignore

# What a node invokes: either a callable object (e.g. torch.add) or a string
# naming a method/module/attribute (e.g. "add", "layer1").
Target = Union[Callable[..., Any], str]

# Everything that can be passed as a node argument, including None.
# The containers are parameterized with Any only because mypy cannot express
# the recursive type; their elements are really Arguments themselves. `slice`
# would ideally be Slice[Argument, Argument, Argument], but typing offers no
# generic slice type.
Argument = Union[
    Tuple[Any, ...],
    List[Any],
    Dict[str, Any],
    slice,
    'Node',
    BaseArgumentTypes,
    None,
]
|
|
|
|
class Node:
    """A single value-defining operation in a graph of operators.

    Each Node records the graph it is associated with, the unique name of the
    value it defines, the kind of operation (``op``), what that operation
    invokes (``target``), and the positional/keyword arguments it receives.
    """

    def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
                 args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None:
        """Store the node's identity, operation, and inputs.

        Args:
            graph: the Graph this node is associated with.
            name: unique name of the value being created.
            op: kind of operation; the original notes list
                placeholder | call_method | call_module | call_function | getattr.
            target: for method/module/function ops, the name of the
                method/module/function/attr being invoked, e.g. ``add``,
                ``layer1``, or ``torch.add``.
            args: positional arguments for the operation.
            kwargs: keyword arguments for the operation.
        """
        # Identity and operation description (assigned left to right).
        self.graph, self.name, self.op, self.target = graph, name, op, target
        # Inputs to the operation.
        self.args, self.kwargs = args, kwargs
        # Use counter, starting at zero. NOTE(review): presumably incremented
        # elsewhere as other nodes reference this one — confirm.
        self.uses = 0

    def __repr__(self) -> str:
        # Returns only the bare value name, not a descriptive ``Node(...)`` form.
        # NOTE(review): other code may rely on this (e.g. when printing the
        # graph) — confirm before changing.
        return self.name
|