[functorch] move batch_norm_replacement to torch.func (#91412)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91412
Approved by: https://github.com/zou3519
samdow 2023-01-11 17:45:40 -05:00 committed by PyTorch MergeBot
parent 7bdcf6d4f0
commit 515dff7811
5 changed files with 6 additions and 2 deletions

View File

@@ -66,6 +66,7 @@ Here's how we would compute the Jacobian over the parameters
 functional_call
 stack_module_state
+replace_all_batch_norm_modules_
 If you're looking for information on fixing Batch Norm modules, please follow the
 guidance here

View File

@@ -69,7 +69,7 @@ have a net where you want the BatchNorm to not use running stats, you can run
 .. code-block:: python
 
-    from functorch.experimental import replace_all_batch_norm_modules_
+    from torch.func import replace_all_batch_norm_modules_
     replace_all_batch_norm_modules_(net)
 
 Option 4: eval mode
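
For context on the documentation change above, here is a minimal sketch of the updated usage; the `net` below is an illustrative model invented for the example, not something from this PR:

    import torch.nn as nn
    from torch.func import replace_all_batch_norm_modules_  # new import path

    # Illustrative module tree containing a BatchNorm layer
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    replace_all_batch_norm_modules_(net)  # updates net in place

    bn = net[1]
    print(bn.running_mean, bn.running_var, bn.track_running_stats)
    # Per the docstring later in this commit: None None False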

View File

@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
 from torch._functorch.vmap import chunk_vmap
-from .batch_norm_replacement import replace_all_batch_norm_modules_
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from functorch import functionalize

View File

@@ -1,4 +1,5 @@
 import torch.nn as nn
+from torch._functorch.utils import exposed_in
 
 def batch_norm_without_running_stats(module: nn.Module):
@@ -9,6 +10,7 @@ def batch_norm_without_running_stats(module: nn.Module):
         module.track_running_stats = False
 
+@exposed_in("torch.func")
 def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
     """
     In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
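
As a reading aid for the function above, the following is a minimal sketch of the behavior the docstring describes; it is not the shipped implementation, and the helper name is hypothetical:

    import torch.nn as nn

    def sketch_replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
        # Walk the module tree and stop every BatchNorm layer from using
        # running statistics, as the docstring above describes.
        for module in root.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.running_mean = None
                module.running_var = None
                module.track_running_stats = False
        return root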

View File

@@ -9,4 +9,5 @@ from torch._functorch.eager_transforms import (
     functionalize,
 )
 from torch._functorch.functional_call import functional_call, stack_module_state
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.vmap import vmap
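
Taken together, these changes make the function importable from its new public home while the experimental alias keeps working, since both `functorch.experimental` and `torch.func` now re-export the same symbol from `torch._functorch.batch_norm_replacement`. A small check of that assumption:

    from functorch.experimental import replace_all_batch_norm_modules_ as experimental_fn
    from torch.func import replace_all_batch_norm_modules_ as func_fn

    # Both names re-export the same function from torch._functorch.batch_norm_replacement,
    # so they should resolve to the same object.
    assert func_fn is experimental_fn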