Update planner.py (#107998)

Fixes #107997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107998
Approved by: https://github.com/wz337
parent 86e6bd3e53
commit ab99a95470
@@ -283,7 +283,17 @@ class LoadPlanner:
 >>> class RenamePlanner(DefaultLoadPlanner):
 >>>     def set_up_planner(self, state_dict, metadata, is_coordinator):
 >>>         self.original_state_dict = state_dict
->>>         super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
+>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
+>>>
+>>>         if self.flatten_sharded_tensors:
+>>>             state_dict = _flatten_sharded_tensors(state_dict)
+>>>
+>>>         if self.flatten_state_dict:
+>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
+>>>
+>>>         self.state_dict = state_dict
+>>>         self.metadata = metadata
+>>>         self.is_coordinator = is_coordinator
 >>>
 >>>     def load_bytes(self, read_item, value):
 >>>         # Remove the "foo_" prefix
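The change above replaces the docstring example's call to super().set_up_planner(self, ...) (which forwarded self twice) with the equivalent inline setup logic. Below is a minimal, self-contained sketch of the updated example, assuming the torch.distributed.checkpoint internals of this era (DefaultLoadPlanner, flatten_state_dict, _flatten_sharded_tensors); the import paths, the body of load_bytes, and the usage note are illustrative assumptions not shown in the diff itself.

    # Sketch of the updated RenamePlanner docstring example, made runnable.
    # Assumptions: import locations and the load_bytes body are illustrative;
    # the diff above ends at the "# Remove the 'foo_' prefix" comment.
    import torch
    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
    from torch.distributed.checkpoint._nested_dict import flatten_state_dict
    from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors


    class RenamePlanner(DefaultLoadPlanner):
        """Load a checkpoint whose keys were saved with a "foo_" prefix."""

        def set_up_planner(self, state_dict, metadata, is_coordinator):
            # Keep a handle on the caller's dict so load_bytes can write the
            # loaded values back under the original (un-prefixed) keys.
            self.original_state_dict = state_dict
            state_dict = {"foo_" + k: v for k, v in state_dict.items()}

            if self.flatten_sharded_tensors:
                state_dict = _flatten_sharded_tensors(state_dict)

            if self.flatten_state_dict:
                state_dict, self.mappings = flatten_state_dict(state_dict)

            self.state_dict = state_dict
            self.metadata = metadata
            self.is_coordinator = is_coordinator

        def load_bytes(self, read_item, value):
            # Remove the "foo_" prefix (4 characters) before storing the
            # deserialized object back into the original state_dict.
            # (Hypothetical body; not part of the diff hunk above.)
            self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)

A caller would then pass an instance of this planner when loading, e.g. planner=RenamePlanner(), to the checkpoint load entry point (load_state_dict in this version of torch.distributed.checkpoint), assuming that keyword is accepted there.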