mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
Make the `fsdp` folder public.
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
25 lines · 743 B · Python
"""Helpers for applying functions to tensors nested inside other Python container types."""

# NOTE: the docstring must precede the imports; a string placed after them is
# an ordinary no-op expression and does not become the module's ``__doc__``.
from typing import Dict, List, Tuple, Union, Any, Callable, Set

import torch
|
|
|
|
|
|
def _apply_to_tensors(
|
|
fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
|
|
) -> Any:
|
|
"""Recursively apply to all tensor in different kinds of container types."""
|
|
|
|
def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
|
|
if torch.is_tensor(x):
|
|
return fn(x)
|
|
elif isinstance(x, dict):
|
|
return {key: apply(value) for key, value in x.items()}
|
|
elif isinstance(x, (list, tuple, set)):
|
|
return type(x)(apply(el) for el in x)
|
|
else:
|
|
return x
|
|
|
|
return apply(container)
|