mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51267 Original commit changeset: b70185916502 Test Plan: test locally, oss ci-all, fbcode incl deferred Reviewed By: suo Differential Revision: D26121251 fbshipit-source-id: 4315b7fd5476914c8e5d6f547e1cfbcf0c227781
21 lines
549 B
Python
21 lines
549 B
Python
import argparse
|
|
import torch
|
|
|
|
class MyModule(torch.nn.Module):
    """Minimal module holding one learnable (N, M) weight added to its input.

    ``forward`` is simple enough to be compiled with ``torch.jit.script``.
    """

    def __init__(self, N: int, M: int) -> None:
        """Create the module with a randomly initialized weight.

        Args:
            N: Number of rows of ``weight``.
            M: Number of columns of ``weight``.
        """
        # Zero-argument super() is the idiomatic Python 3 spelling; it is
        # equivalent to the legacy ``super(MyModule, self)`` form.
        super().__init__()
        # torch.rand samples uniformly from [0, 1); wrapping in Parameter
        # registers the tensor as a trainable parameter of this module.
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``self.weight + input``.

        ``input`` must be broadcast-compatible with the ``(N, M)`` weight.
        The parameter keeps the name ``input`` (shadowing the builtin) so
        keyword callers and the scripted signature stay unchanged.
        """
        return self.weight + input
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: a single positional argument giving the path
    # the serialized model is written to.
    cli = argparse.ArgumentParser()
    cli.add_argument("save_file", help="Where to save the model")
    options = cli.parse_args()

    # Build a 10x20 MyModule, compile it with TorchScript, and serialize the
    # scripted module to the requested path.
    scripted = torch.jit.script(MyModule(10, 20))
    scripted.save(options.save_file)
|