mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[inductor] Make {output_graph,pad_mm}.py pass follow_imports typechecking (#113413)
I changed OutputGraph.nn_modules' type to `Dict[str, Any]` because it seems that `register_attr_or_module` can populate it with essentially any type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113413 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
8d41a5c605
commit
a8cf04fd2a
|
|
@ -2054,3 +2054,4 @@ def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]
|
||||||
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
|
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
|
||||||
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
|
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
|
||||||
def _get_module_info_from_flatbuffer(data: bytes): ...
|
def _get_module_info_from_flatbuffer(data: bytes): ...
|
||||||
|
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ from torch import fx
|
||||||
from torch._guards import (
|
from torch._guards import (
|
||||||
Checkpointable,
|
Checkpointable,
|
||||||
GlobalContextCheckpointState,
|
GlobalContextCheckpointState,
|
||||||
Guard,
|
|
||||||
GuardsCheckpointState,
|
GuardsCheckpointState,
|
||||||
Source,
|
Source,
|
||||||
TracingContext,
|
TracingContext,
|
||||||
|
|
@ -457,11 +456,11 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||||
return self.tracing_context.fake_mode.shape_env
|
return self.tracing_context.fake_mode.shape_env
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guards(self) -> Set[Guard]:
|
def guards(self) -> torch._guards.GuardsSet:
|
||||||
return self.tracing_context.guards_context.dynamo_guards
|
return self.tracing_context.guards_context.dynamo_guards
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nn_modules(self) -> Dict[str, torch.nn.Module]:
|
def nn_modules(self) -> Dict[str, Any]:
|
||||||
return self.tracing_context.module_context.nn_modules
|
return self.tracing_context.module_context.nn_modules
|
||||||
|
|
||||||
def save_global_state(self, out=None):
|
def save_global_state(self, out=None):
|
||||||
|
|
@ -616,7 +615,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||||
|
|
||||||
def get_submodule(self, keys):
|
def get_submodule(self, keys):
|
||||||
assert keys
|
assert keys
|
||||||
obj = self.nn_modules
|
obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
|
||||||
for k in keys.split("."):
|
for k in keys.split("."):
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
obj = obj[k]
|
obj = obj[k]
|
||||||
|
|
|
||||||
|
|
@ -388,7 +388,7 @@ class ModuleContextCheckpointState:
|
||||||
|
|
||||||
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.nn_modules: Dict[str, torch.nn.Module] = {}
|
self.nn_modules: Dict[str, Any] = {}
|
||||||
|
|
||||||
def copy_graphstate(self):
|
def copy_graphstate(self):
|
||||||
return ModuleContextCheckpointState(dict(self.nn_modules))
|
return ModuleContextCheckpointState(dict(self.nn_modules))
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ def should_pad_common(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_padded_length(x: Tensor, alignment_size) -> int:
|
def get_padded_length(x: int, alignment_size) -> int:
|
||||||
if alignment_size == 0 or x % alignment_size == 0:
|
if alignment_size == 0 or x % alignment_size == 0:
|
||||||
return 0
|
return 0
|
||||||
return int((x // alignment_size + 1) * alignment_size) - x
|
return int((x // alignment_size + 1) * alignment_size) - x
|
||||||
|
|
@ -94,7 +94,7 @@ def should_pad_addmm(match: Match) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def addmm_replace(
|
def addmm_replace(
|
||||||
input: Tensor, mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
|
input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
|
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
|
||||||
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
|
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
|
||||||
|
|
@ -116,7 +116,7 @@ def addmm_replace(
|
||||||
|
|
||||||
|
|
||||||
def pad_addmm(
|
def pad_addmm(
|
||||||
input: Tensor,
|
input: Optional[Tensor],
|
||||||
mat1: Tensor,
|
mat1: Tensor,
|
||||||
mat2: Tensor,
|
mat2: Tensor,
|
||||||
m_padded_length: int,
|
m_padded_length: int,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user