mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: This PR aims to fix https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330, by adding `train()` / `eval()` / `is_training()` to C++ ScriptModule API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16044 Differential Revision: D13857724 Pulled By: yf225 fbshipit-source-id: 16d3969fb5840ff7e66c7f72e800e6c75db8d2ff
42 lines
920 B
Python
42 lines
920 B
Python
import sys
|
|
import os
|
|
import torch
|
|
|
|
# Relative file name (resolved against the current working directory) where the
# setup step writes the serialized dropout module and the shutdown step deletes it.
testEvalModeForLoadedModule_module_path = "dropout_model.pt"
|
|
|
|
|
|
def testEvalModeForLoadedModule_setup():
|
|
class Model(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Model, self).__init__()
|
|
self.dropout = torch.nn.Dropout(0.1)
|
|
|
|
def forward(self, x):
|
|
x = self.dropout(x)
|
|
return x
|
|
|
|
model = Model()
|
|
model = model.train()
|
|
model.save(testEvalModeForLoadedModule_module_path)
|
|
|
|
|
|
def testEvalModeForLoadedModule_shutdown():
|
|
if os.path.exists(testEvalModeForLoadedModule_module_path):
|
|
os.remove(testEvalModeForLoadedModule_module_path)
|
|
|
|
|
|
def setup():
    """Run each registered fixture-setup step in order.

    Currently the only registered step is the eval-mode-for-loaded-module one.
    """
    for prepare in (testEvalModeForLoadedModule_setup,):
        prepare()
|
|
|
|
|
|
def shutdown():
    """Run each registered fixture-teardown step in order.

    Currently the only registered step is the eval-mode-for-loaded-module one.
    """
    for teardown in (testEvalModeForLoadedModule_shutdown,):
        teardown()
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: ``python <script> setup`` or ``python <script> shutdown``.
    # A missing or unrecognized command now exits non-zero with a usage message
    # instead of raising IndexError or silently doing nothing.
    commands = {"setup": setup, "shutdown": shutdown}
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        sys.exit("usage: %s setup|shutdown" % sys.argv[0])
    commands[sys.argv[1]]()
|