mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51124 Original commit changeset: 1c7133627da2 Test Plan: Test locally with interpreter_test and on CI Reviewed By: suo Differential Revision: D26077905 fbshipit-source-id: fae83bf9822d79e9a9b5641bc5191a7f3fdea78d
21 lines
549 B
Python
21 lines
549 B
Python
import argparse
|
|
import torch
|
|
|
|
class MyModule(torch.nn.Module):
    """Minimal module holding one learnable weight that is added to its input.

    The weight is an ``(N, M)`` ``torch.nn.Parameter`` initialised with
    ``torch.rand``. ``forward`` is written so the module can be compiled with
    ``torch.jit.script``.
    """

    def __init__(self, N: int, M: int) -> None:
        """Create the module.

        Args:
            N: Number of rows of the weight tensor.
            M: Number of columns of the weight tensor.
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``self.weight + input`` (standard tensor broadcasting applies)."""
        # ``input`` shadows the builtin, but the name is kept so keyword callers
        # such as ``module(input=x)`` keep working.
        output = self.weight + input
        return output
|
|
|
|
if __name__ == "__main__":
    # Command line: a single positional argument naming the output file.
    cli = argparse.ArgumentParser()
    cli.add_argument("save_file", help="Where to save the model")
    target_path = cli.parse_args().save_file

    # Compile a 10x20 MyModule with TorchScript and serialize it to that path.
    scripted = torch.jit.script(MyModule(10, 20))
    scripted.save(target_path)
|