mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
The error for `test_error_stack_module`:
```
Traceback (most recent call last):
File "../test.py", line 35, in <module>
scripted = torch.jit.script(M())
File "/home/davidriazati/other/pytorch/torch/jit/__init__.py", line 1119, in script
return _convert_to_script_module(obj)
File "/home/davidriazati/other/pytorch/torch/jit/__init__.py", line 1825, in _convert_to_script_module
raise e
RuntimeError:
d(int x) -> int:
Expected a value of type 'int' for argument 'x' but instead found type 'str'.
:
at ../test.py:11:12
def c(x):
return d("hello") + d(x)
~ <--- HERE
'c' is being compiled since it was called from 'b'
at ../test.py:14:12
def b(x):
return c(x)
~~~ <--- HERE
'b' is being compiled since it was called from 'forward'
at ../test.py:22:16
def forward(self, x):
return b(x)
~~~ <--- HERE
'forward' is being compiled since it was called from 'forward'
at ../test.py:31:20
def forward(self, x):
return x + self.submodule(x)
~~~~~~~~~~~~~~~~ <--- HERE
```
This also unifies our error reporting in the front end with `ErrorReport`
TODO
* Include module names in message, #22207 should make this easy
Internal diff: https://our.intern.facebook.com/intern/diff/16060781/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22280
Pulled By: driazati
Differential Revision: D16060781
fbshipit-source-id: c42968b53aaddb774ac69d5abbf7e60c23df8eed
43 lines · 953 B · Python
import sys
|
|
import os
|
|
import torch
|
|
|
|
# Archive written by testEvalModeForLoadedModule_setup and deleted by
# testEvalModeForLoadedModule_shutdown. It is a bare filename, so it resolves
# against the current working directory of whoever runs this script.
testEvalModeForLoadedModule_module_path = "dropout_model.pt"
|
|
|
|
|
|
def testEvalModeForLoadedModule_setup():
    """Serialize a dropout-only ScriptModule, left in training mode, to disk.

    The archive is written to ``testEvalModeForLoadedModule_module_path``.
    NOTE(review): presumably a consumer loads this archive and checks that
    switching it to eval mode disables dropout -- confirm against that test.
    """

    class Model(torch.jit.ScriptModule):
        # A ScriptModule wrapping a single Dropout(p=0.1) layer.
        def __init__(self):
            super(Model, self).__init__()
            self.dropout = torch.nn.Dropout(0.1)

        @torch.jit.script_method
        def forward(self, x):
            # Apply dropout and hand the result straight back.
            return self.dropout(x)

    # train() returns the module itself, so construction, the explicit switch
    # to training mode, and serialization can be chained.
    Model().train().save(testEvalModeForLoadedModule_module_path)
|
|
|
|
|
|
def testEvalModeForLoadedModule_shutdown():
    """Delete the archive written by the setup step, if it is present."""
    if not os.path.exists(testEvalModeForLoadedModule_module_path):
        # Nothing to clean up: setup never ran or the file is already gone.
        return
    os.remove(testEvalModeForLoadedModule_module_path)
|
|
|
|
|
|
def setup():
    """Run each registered per-test setup hook, in order."""
    for hook in (testEvalModeForLoadedModule_setup,):
        hook()
|
|
|
|
|
|
def shutdown():
    """Run each registered per-test shutdown hook, in order."""
    for hook in (testEvalModeForLoadedModule_shutdown,):
        hook()
|
|
|
|
|
|
def main(argv=None):
    """Dispatch the ``setup`` or ``shutdown`` command named in *argv*.

    Args:
        argv: Full argument vector with the program name first; defaults to
            ``sys.argv``. Arguments after the command are ignored.

    Returns:
        0 after running the requested command; 2 when the command is missing
        or unrecognized (a message is written to stderr in that case).
    """
    if argv is None:
        argv = sys.argv
    # A missing command used to raise a bare IndexError, and an unknown one
    # was silently ignored with exit status 0; report both as usage errors.
    if len(argv) < 2:
        sys.stderr.write("expected one command: 'setup' or 'shutdown'\n")
        return 2
    command = argv[1]
    if command == "setup":
        setup()
    elif command == "shutdown":
        shutdown()
    else:
        sys.stderr.write(
            "unknown command %r; expected 'setup' or 'shutdown'\n" % (command,)
        )
        return 2
    return 0


if __name__ == "__main__":
    sys.exit(main())
|