mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Make {output_graph,pad_mm}.py pass follow_imports typechecking (#113413)
I changed OutputGraph.nn_modules' type to `Dict[str, Any]` because it seems that `register_attr_or_module` can populate it with essentially any type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113413 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
8d41a5c605
commit
a8cf04fd2a
|
|
@ -2054,3 +2054,4 @@ def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]
|
|||
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
|
||||
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
|
||||
def _get_module_info_from_flatbuffer(data: bytes): ...
|
||||
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ from torch import fx
|
|||
from torch._guards import (
|
||||
Checkpointable,
|
||||
GlobalContextCheckpointState,
|
||||
Guard,
|
||||
GuardsCheckpointState,
|
||||
Source,
|
||||
TracingContext,
|
||||
|
|
@ -457,11 +456,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
return self.tracing_context.fake_mode.shape_env
|
||||
|
||||
@property
|
||||
def guards(self) -> Set[Guard]:
|
||||
def guards(self) -> torch._guards.GuardsSet:
|
||||
return self.tracing_context.guards_context.dynamo_guards
|
||||
|
||||
@property
|
||||
def nn_modules(self) -> Dict[str, torch.nn.Module]:
|
||||
def nn_modules(self) -> Dict[str, Any]:
|
||||
return self.tracing_context.module_context.nn_modules
|
||||
|
||||
def save_global_state(self, out=None):
|
||||
|
|
@ -616,7 +615,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
|||
|
||||
def get_submodule(self, keys):
|
||||
assert keys
|
||||
obj = self.nn_modules
|
||||
obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
|
||||
for k in keys.split("."):
|
||||
if isinstance(obj, dict):
|
||||
obj = obj[k]
|
||||
|
|
|
|||
|
|
@ -388,7 +388,7 @@ class ModuleContextCheckpointState:
|
|||
|
||||
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||
def __init__(self):
|
||||
self.nn_modules: Dict[str, torch.nn.Module] = {}
|
||||
self.nn_modules: Dict[str, Any] = {}
|
||||
|
||||
def copy_graphstate(self):
|
||||
return ModuleContextCheckpointState(dict(self.nn_modules))
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ def should_pad_common(
|
|||
)
|
||||
|
||||
|
||||
def get_padded_length(x: Tensor, alignment_size) -> int:
|
||||
def get_padded_length(x: int, alignment_size) -> int:
|
||||
if alignment_size == 0 or x % alignment_size == 0:
|
||||
return 0
|
||||
return int((x // alignment_size + 1) * alignment_size) - x
|
||||
|
|
@ -94,7 +94,7 @@ def should_pad_addmm(match: Match) -> bool:
|
|||
|
||||
|
||||
def addmm_replace(
|
||||
input: Tensor, mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
|
||||
input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
|
||||
) -> Tensor:
|
||||
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
|
||||
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
|
||||
|
|
@ -116,7 +116,7 @@ def addmm_replace(
|
|||
|
||||
|
||||
def pad_addmm(
|
||||
input: Tensor,
|
||||
input: Optional[Tensor],
|
||||
mat1: Tensor,
|
||||
mat2: Tensor,
|
||||
m_padded_length: int,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user