[FSDP2] Add cache for FSDP wrapper class (#134135)
Currently, `fully_shard` creates a new `FSDPMyModuleClass` class for each `MyModuleClass` module **object**, which causes Dynamo to guard-fail on every module object's type check. This PR fixes the issue by caching and reusing the previously created FSDP wrapper class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134135
Approved by: https://github.com/awgu
commit 6bddfb9546
parent 2a73ba298c
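To make the cached-wrapper pattern easier to see outside the diff, here is a minimal, self-contained sketch. `wrap` is a hypothetical stand-in for the class-swapping loop inside `fully_shard`, and `FSDPModule` / `unimplemented_deepcopy` are simplified stand-ins for the real definitions:

```python
# Minimal sketch of the wrapper-class cache. `FSDPModule` and
# `unimplemented_deepcopy` are simplified stand-ins for the real ones,
# and `wrap` is a hypothetical helper mirroring the loop in `fully_shard`.
from typing import Dict, Type

import torch.nn as nn


class FSDPModule:  # stand-in mixin; the real one adds FSDP methods
    pass


def unimplemented_deepcopy(*args, **kwargs):
    raise AssertionError("FSDP does not support deepcopy")


cls_to_fsdp_cls: Dict[Type, Type] = {}


def wrap(module: nn.Module) -> nn.Module:
    cls = module.__class__
    new_cls = cls_to_fsdp_cls.get(cls, None)
    if not new_cls:
        # First time seeing this class: create the FSDP subclass once...
        dct = {"__deepcopy__": unimplemented_deepcopy}
        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
        # ...and cache it so every later instance reuses the same class.
        cls_to_fsdp_cls[cls] = new_cls
    module.__class__ = new_cls
    return module
```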
```diff
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import functools
-from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union
 
 import torch
 import torch.nn as nn
@@ -23,6 +23,9 @@ from ._fsdp_param_group import FSDPParamGroup
 from ._fsdp_state import _get_module_fsdp_state, FSDPState
 
 
+cls_to_fsdp_cls: Dict[Type, Type] = {}
+
+
 # The decorator adds a state object to `module` that can be accessed via
 # `fully_shard.state(module)`. The state object and module are 1:1.
 @contract(state_cls=FSDPState)  # type: ignore[operator]
@@ -144,8 +147,11 @@ def fully_shard(
     # Place FSDP leftmost for highest priority in the method resolution order
     for module in modules:
         cls = module.__class__
-        dct = {"__deepcopy__": unimplemented_deepcopy}
-        new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
+        new_cls = cls_to_fsdp_cls.get(cls, None)
+        if not new_cls:
+            dct = {"__deepcopy__": unimplemented_deepcopy}
+            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
+            cls_to_fsdp_cls[cls] = new_cls
         module.__class__ = new_cls
     return arg_module
 
```
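With the cache in place, two modules of the same class share one wrapper class, so a `type(...)` guard compiled for the first instance also passes for the second. A quick check, using the hypothetical `wrap` helper from the sketch above:

```python
# Two instances of the same class now get the same cached wrapper class.
a, b = nn.Linear(4, 4), nn.Linear(4, 4)
wrap(a)
wrap(b)
assert type(a) is type(b)  # shared cached class -> stable Dynamo type guards
assert type(a).__name__ == "FSDPLinear"
```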