pytorch/torch/distributed/fsdp/utils.py
Yanli Zhao 2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

make fsdp folder to be public
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
2022-02-02 15:50:14 +00:00


"""Utility functions for handling tensors inside other Python container types."""
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch


def _apply_to_tensors(
    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Any:
    """Recursively apply ``fn`` to every tensor in a nested container."""

    def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        elif isinstance(x, dict):
            # Keys are kept as-is; only the values are traversed.
            return {key: apply(value) for key, value in x.items()}
        elif isinstance(x, (list, tuple, set)):
            # Rebuild the same container type from the mapped elements.
            return type(x)(apply(el) for el in x)
        else:
            # Non-tensor leaves pass through unchanged.
            return x

    return apply(container)
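
The recursion in `_apply_to_tensors` can be tried without installing PyTorch. The block below is a minimal sketch, not part of the file: `apply_to_leaves` and its `is_leaf` predicate are hypothetical names standing in for `_apply_to_tensors` and `torch.is_tensor`. It also adds one assumption-driven extension, a namedtuple branch, because rebuilding a multi-field namedtuple with `type(x)(generator)` would not match its positional constructor.

```python
from collections import namedtuple
from typing import Any, Callable


def apply_to_leaves(
    fn: Callable[[Any], Any],
    is_leaf: Callable[[Any], bool],
    container: Any,
) -> Any:
    """Hypothetical torch-free analogue of _apply_to_tensors."""

    def apply(x: Any) -> Any:
        if is_leaf(x):
            return fn(x)
        if isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # Assumed extension: namedtuples take their fields positionally,
            # so the mapped elements are unpacked into the constructor.
            return type(x)(*(apply(el) for el in x))
        if isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        # Anything else is returned unchanged.
        return x

    return apply(container)


def is_float(value: Any) -> bool:
    """Stand-in for torch.is_tensor in this sketch."""
    return isinstance(value, float)


Pair = namedtuple("Pair", ["x", "y"])  # illustrative type, not from the source

nested = {"a": [1.0, 2.0], "b": (3.0, "label"), "c": {"d": 4.0}}
doubled = apply_to_leaves(lambda v: v * 2, is_float, nested)
shifted = apply_to_leaves(lambda v: v + 1, is_float, Pair(1.0, 2.0))

print(doubled)
print(shifted)
```

With real tensors the call shape is the same: pass a function such as `lambda t: t.detach()` as `fn`, and the nested structure of the container is preserved in the result.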