[functorch] move batch_norm_replacement to torch.func (#91412)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91412 Approved by: https://github.com/zou3519
This commit is contained in:
parent 7bdcf6d4f0
commit 515dff7811
@@ -66,6 +66,7 @@ Here's how we would compute the Jacobian over the parameters
     functional_call
     stack_module_state
+    replace_all_batch_norm_modules_

 If you're looking for information on fixing Batch Norm modules, please follow the
 guidance here
@@ -69,7 +69,7 @@ have a net where you want the BatchNorm to not use running stats, you can run
 .. code-block:: python

-    from functorch.experimental import replace_all_batch_norm_modules_
+    from torch.func import replace_all_batch_norm_modules_
     replace_all_batch_norm_modules_(net)

 Option 4: eval mode
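For readers following the docs change above, here is a minimal sketch of what the relocated helper does under the new import path; the toy `net` and the asserts are illustrative assumptions, not part of this diff.

import torch
import torch.nn as nn
from torch.func import replace_all_batch_norm_modules_

# Toy model containing a BatchNorm layer (assumption, for illustration only).
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
replace_all_batch_norm_modules_(net)

bn = net[1]
assert bn.running_mean is None and bn.running_var is None
assert bn.track_running_stats is False

# With running stats removed, BatchNorm always normalizes with batch statistics,
# so a forward pass no longer mutates buffers in place.
out = net(torch.randn(16, 4))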
@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
 from torch._functorch.vmap import chunk_vmap
-from .batch_norm_replacement import replace_all_batch_norm_modules_
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from functorch import functionalize
@@ -1,4 +1,5 @@
 import torch.nn as nn
+from torch._functorch.utils import exposed_in


 def batch_norm_without_running_stats(module: nn.Module):
@@ -9,6 +10,7 @@ def batch_norm_without_running_stats(module: nn.Module):
         module.track_running_stats = False


+@exposed_in("torch.func")
 def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
     """
     In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
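Neither function body appears in the hunks above, so the following is only a rough sketch of the pattern: `exposed_in` plausibly re-tags a function's `__module__` so it shows up under the public `torch.func` namespace, and the decorated helper, per its docstring, walks the module tree dropping running stats from every BatchNorm it finds.

import torch.nn as nn

def exposed_in(module: str):
    # Assumption: the decorator likely just rewrites __module__ so docs and repr
    # report e.g. "torch.func" instead of "torch._functorch.batch_norm_replacement".
    def wrapper(fn):
        fn.__module__ = module
        return fn
    return wrapper

def batch_norm_without_running_stats(module: nn.Module) -> None:
    # Sketch of the helper defined earlier in this file: strip running stats
    # from any BatchNorm module that currently tracks them.
    if isinstance(module, nn.modules.batchnorm._BatchNorm) and module.track_running_stats:
        module.running_mean = None
        module.running_var = None
        module.num_batches_tracked = None
        module.track_running_stats = False

@exposed_in("torch.func")
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
    # Sketch consistent with the docstring: update root and every submodule in place.
    for submodule in root.modules():  # modules() includes root itself
        batch_norm_without_running_stats(submodule)
    return root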
@@ -9,4 +9,5 @@ from torch._functorch.eager_transforms import (
     functionalize,
 )
 from torch._functorch.functional_call import functional_call, stack_module_state
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.vmap import vmap
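Since this final hunk gathers functional_call, stack_module_state, replace_all_batch_norm_modules_, and vmap behind torch.func, one workflow they enable together is vmap-based model ensembling over networks that contain BatchNorm. The sketch below is an assumed usage pattern (the SimpleNet model, data shapes, and the meta-device skeleton are illustrative), not code from this PR.

import copy
import torch
import torch.nn as nn
from torch.func import functional_call, replace_all_batch_norm_modules_, stack_module_state, vmap

class SimpleNet(nn.Module):
    # Hypothetical ensemble member containing BatchNorm.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.relu(self.bn(self.fc(x))))

models = [SimpleNet() for _ in range(3)]
# Remove running stats so vmap does not hit BatchNorm's in-place buffer updates.
for m in models:
    replace_all_batch_norm_modules_(m)

# Stack each member's parameters/buffers along a new leading "ensemble" dimension.
params, buffers = stack_module_state(models)

# A structure-only skeleton; its own weights are never used by functional_call.
base_model = copy.deepcopy(models[0]).to("meta")

def run_one(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

x = torch.randn(16, 4)
# One forward pass per ensemble member, all sharing the same input batch.
ensemble_out = vmap(run_one, in_dims=(0, 0, None))(params, buffers, x)
print(ensemble_out.shape)  # torch.Size([3, 16, 2])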