[FSDP2] Add cache for FSDP wrapper class (#134135)

Currently, `fully_shard` will create a new `FSDPMyModuleClass` class for each `MyModuleClass` module **object**, which causes Dynamo guard failures on every module object's type check. This PR fixes the issue by caching and reusing the previously created FSDP wrapper class for each module class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134135
Approved by: https://github.com/awgu
This commit is contained in:
Will Feng 2024-08-21 12:02:21 -07:00 committed by PyTorch MergeBot
parent 2a73ba298c
commit 6bddfb9546

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union
import torch
import torch.nn as nn
@ -23,6 +23,9 @@ from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState
cls_to_fsdp_cls: Dict[Type, Type] = {}
# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState) # type: ignore[operator]
@ -144,8 +147,11 @@ def fully_shard(
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
cls = module.__class__
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
new_cls = cls_to_fsdp_cls.get(cls, None)
if not new_cls:
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
cls_to_fsdp_cls[cls] = new_cls
module.__class__ = new_cls
return arg_module