pytorch/test/cpp/jit/tests_setup.py
Will Feng a40e8ce7c5 Add train() / eval() / is_training() to C++ ScriptModule API (#16044)
Summary:
This PR aims to fix https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330, by adding `train()` / `eval()` / `is_training()` to C++ ScriptModule API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16044

Differential Revision: D13857724

Pulled By: yf225

fbshipit-source-id: 16d3969fb5840ff7e66c7f72e800e6c75db8d2ff
2019-02-01 13:07:38 -08:00

42 lines
920 B
Python

import sys
import os
import torch
# Relative path of the serialized dropout model. It is written by
# testEvalModeForLoadedModule_setup() and deleted by
# testEvalModeForLoadedModule_shutdown(). Presumably the C++ test that loads
# the model uses this same file name -- keep the two in sync (TODO confirm).
testEvalModeForLoadedModule_module_path = 'dropout_model.pt'
def testEvalModeForLoadedModule_setup():
    """Serialize a training-mode dropout ScriptModule for the eval-mode test.

    Builds a ScriptModule whose only submodule is ``dropout``
    (``torch.nn.Dropout(0.1)``), switches it to training mode and saves it to
    the path held in ``testEvalModeForLoadedModule_module_path``.
    """
    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super(Model, self).__init__()
            self.dropout = torch.nn.Dropout(0.1)

        # With the legacy ScriptModule subclassing API, only methods marked
        # with script_method are compiled; without it ``forward`` would stay a
        # plain Python method and would not be part of the saved module.
        @torch.jit.script_method
        def forward(self, x):
            x = self.dropout(x)
            return x

    model = Model()
    # Save in training mode, so a loader can observe the flag change when it
    # later calls eval()/train() (presumably what the C++ test checks; confirm).
    model = model.train()
    model.save(testEvalModeForLoadedModule_module_path)
def testEvalModeForLoadedModule_shutdown():
    """Delete the model file written by the setup step; do nothing if absent."""
    target = testEvalModeForLoadedModule_module_path
    if not os.path.exists(target):
        return
    os.remove(target)
def setup():
    """Run every test-artifact setup step (currently only the eval-mode one)."""
    for setup_step in (testEvalModeForLoadedModule_setup,):
        setup_step()
def shutdown():
    """Run every test-artifact cleanup step (currently only the eval-mode one)."""
    for cleanup_step in (testEvalModeForLoadedModule_shutdown,):
        cleanup_step()
if __name__ == "__main__":
    # Dispatch the single CLI argument ("setup" or "shutdown") to its handler.
    _commands = {"setup": setup, "shutdown": shutdown}
    # Reject a missing or unrecognized command with a usage message and a
    # non-zero exit status, instead of raising IndexError or silently doing
    # nothing. Extra trailing arguments are still ignored, as before.
    if len(sys.argv) < 2 or sys.argv[1] not in _commands:
        sys.exit("usage: {} {{setup|shutdown}}".format(sys.argv[0]))
    _commands[sys.argv[1]]()