[inductor] Make {output_graph,pad_mm}.py pass follow_imports typechecking (#113413)

I changed the type of `OutputGraph.nn_modules` to `Dict[str, Any]` because it
seems that `register_attr_or_module` can populate it with essentially
any type.
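
For illustration only (this snippet is not part of the change), here is why `Dict[str, torch.nn.Module]` is too narrow once arbitrary attributes can be registered alongside modules; the keys and values below are made up:

from typing import Any, Dict

import torch

# Hypothetical contents of the nn_modules mapping; under the old annotation
# Dict[str, torch.nn.Module], the second and third entries would be
# rejected by a type checker such as mypy.
nn_modules: Dict[str, Any] = {}
nn_modules["self_linear"] = torch.nn.Linear(2, 2)   # a real submodule
nn_modules["self_buffer"] = torch.zeros(3)           # a tensor attribute
nn_modules["self_training"] = True                    # a plain Python value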

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113413
Approved by: https://github.com/Skylion007
Jez Ng 2023-11-11 10:27:17 -08:00 committed by PyTorch MergeBot
parent 8d41a5c605
commit a8cf04fd2a
4 changed files with 8 additions and 8 deletions


@@ -2054,3 +2054,4 @@ def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]
 def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
 def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
 def _get_module_info_from_flatbuffer(data: bytes): ...
+def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...


@@ -24,7 +24,6 @@ from torch import fx
 from torch._guards import (
     Checkpointable,
     GlobalContextCheckpointState,
-    Guard,
     GuardsCheckpointState,
     Source,
     TracingContext,
@@ -457,11 +456,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
         return self.tracing_context.fake_mode.shape_env

     @property
-    def guards(self) -> Set[Guard]:
+    def guards(self) -> torch._guards.GuardsSet:
         return self.tracing_context.guards_context.dynamo_guards

     @property
-    def nn_modules(self) -> Dict[str, torch.nn.Module]:
+    def nn_modules(self) -> Dict[str, Any]:
         return self.tracing_context.module_context.nn_modules

     def save_global_state(self, out=None):
@@ -616,7 +615,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
     def get_submodule(self, keys):
         assert keys
-        obj = self.nn_modules
+        obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
         for k in keys.split("."):
             if isinstance(obj, dict):
                 obj = obj[k]
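
As an aside, the explicit `Union` annotation on `obj` is the standard way to let mypy follow `isinstance` narrowing through a loop; a self-contained sketch of the same pattern (names here are illustrative, not taken from output_graph.py):

from typing import Dict, Union

import torch

def lookup(root: Dict[str, torch.nn.Module], path: str) -> torch.nn.Module:
    # Start with the dict, then walk dotted keys; the Union annotation tells
    # mypy that both a dict and a module may appear along the way.
    obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = root
    for key in path.split("."):
        if isinstance(obj, dict):
            obj = obj[key]            # dict lookup on the first hop
        else:
            obj = getattr(obj, key)   # submodule lookup afterwards
    assert isinstance(obj, torch.nn.Module)
    return obj

# Example: resolve "encoder.0" in a nested container.
mods = {"encoder": torch.nn.Sequential(torch.nn.Linear(4, 4))}
print(lookup(mods, "encoder.0"))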


@@ -388,7 +388,7 @@ class ModuleContextCheckpointState:
 class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
     def __init__(self):
-        self.nn_modules: Dict[str, torch.nn.Module] = {}
+        self.nn_modules: Dict[str, Any] = {}

     def copy_graphstate(self):
         return ModuleContextCheckpointState(dict(self.nn_modules))


@@ -67,7 +67,7 @@ def should_pad_common(
     )


-def get_padded_length(x: Tensor, alignment_size) -> int:
+def get_padded_length(x: int, alignment_size) -> int:
     if alignment_size == 0 or x % alignment_size == 0:
         return 0
     return int((x // alignment_size + 1) * alignment_size) - x
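
A quick standalone check of the padding arithmetic above (the asserts are illustrative examples, not part of the diff):

def get_padded_length(x: int, alignment_size: int) -> int:
    # Number of elements needed to round x up to the next multiple of
    # alignment_size; 0 if already aligned or alignment is disabled.
    if alignment_size == 0 or x % alignment_size == 0:
        return 0
    return int((x // alignment_size + 1) * alignment_size) - x

assert get_padded_length(10, 8) == 6   # 10 rounds up to 16
assert get_padded_length(16, 8) == 0   # already a multiple of 8
assert get_padded_length(7, 0) == 0    # alignment disabled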
@@ -94,7 +94,7 @@ def should_pad_addmm(match: Match) -> bool:


 def addmm_replace(
-    input: Tensor, mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
+    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
 ) -> Tensor:
     m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
     k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
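
The `Optional[Tensor]` annotation reflects that the addmm-style bias `input` may be absent in the matched pattern; a minimal sketch of that kind of handling (hypothetical helper, not the pad_mm.py code):

from typing import Optional

import torch
from torch import Tensor

def addmm_like(input: Optional[Tensor], mat1: Tensor, mat2: Tensor) -> Tensor:
    # Multiply first, then add the bias only when one was supplied.
    out = mat1 @ mat2
    if input is not None:
        out = out + input
    return out

print(addmm_like(None, torch.ones(2, 3), torch.ones(3, 4)).shape)             # no bias
print(addmm_like(torch.zeros(4), torch.ones(2, 3), torch.ones(3, 4)).shape)   # broadcast bias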
@@ -116,7 +116,7 @@ def addmm_replace(


 def pad_addmm(
-    input: Tensor,
+    input: Optional[Tensor],
     mat1: Tensor,
     mat2: Tensor,
     m_padded_length: int,