# Owner(s): ["oncall: jit"]
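# Tests TorchScript support for the torch.Generator type: tracing and
# scripting functions that construct and seed generators, passing a
# Generator as an argument, the `generator` kwargs on torch.nn.init
# functions, and save/load of scripted modules that use generators
# (see https://github.com/pytorch/pytorch/pull/110413).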

import io
import math
import unittest

import torch
from torch.nn import init
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestGenerator(JitTestCase):
    # torch.jit.trace does not properly capture the generator manual seed
    # and thus is non-deterministic even if the generator is manually seeded
    @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
    @unittest.expectedFailure
    def test_trace(self):
        def f():
            generator = torch.Generator()
            generator.seed()
            generator.manual_seed(2023)
            generator.initial_seed()
            tensor = torch.empty(2, 2)
            tensor.uniform_(0, 1, generator=generator)
            return tensor

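        # f takes no arguments, so the example inputs for tracing are an
        # empty tuple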
        traced_f = torch.jit.trace(f, ())

        # Run this 3 times to ensure that the generator is being manually seeded
        # each time the traced function is run
        for i in range(3):
            torch.manual_seed(1)

            eager_tensor = f()

            # Change the seed of the default generator to check that
            # we're using the generator from the trace
            torch.manual_seed(2)
            traced_tensor = traced_f()

            self.assertEqual(eager_tensor, traced_tensor)

    def test_script(self):
        def f():
            generator = torch.Generator()
            generator.seed()
            generator.manual_seed(2023)
            generator.initial_seed()
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0, generator=generator)
            return tensor

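        # Scripting compiles f from source, so Generator construction and its
        # seed methods must be supported directly by TorchScript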
        script_f = torch.jit.script(f)

        # Run this 3 times to ensure that the generator is being manually seeded
        # each time the scripted function is run
        for i in range(3):
            torch.manual_seed(1)

            eager_tensor = f()

            # Change the seed of the default generator to check that
            # we're using the generator from the script
            torch.manual_seed(2)

            script_tensor = script_f()

            self.assertEqual(eager_tensor, script_tensor)

    def test_default_generator(self):
        def f():
            # check that calling manual seed for the default generator works
            torch.manual_seed(2023)
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0)
            return tensor

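        # f reseeds the default generator itself, so the different outer
        # seeds below should not change the result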
        torch.manual_seed(1)

        eager_tensor = f()

        torch.manual_seed(2)

        script_f = torch.jit.script(f)
        script_tensor = script_f()

        self.assertEqual(eager_tensor, script_tensor)

    def test_generator_arg(self):
        def f(generator: torch.Generator):
            tensor = torch.empty(2, 2)
            tensor.normal_(-1.0, 1.0, generator=generator)
            return tensor

        generator = torch.Generator()
        generator.manual_seed(2023)

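        # The torch.Generator annotation on f is what allows the scripted
        # function to accept a Generator argument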
        script_f = torch.jit.script(f)

        for i in range(3):
            generator = torch.Generator()
            generator.manual_seed(2023 + i)

            torch.manual_seed(1 + i)

            eager_tensor = f(generator)

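            # Recreate an identically seeded generator for the scripted call so
            # both runs start from the same RNG state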
            generator = torch.Generator()
            generator.manual_seed(2023 + i)

            torch.manual_seed(1 + i)

            script_tensor = script_f(generator)

            self.assertEqual(eager_tensor, script_tensor)

    def test_save_load(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2, bias=False)
                self.bar = torch.nn.Linear(2, 2, bias=False)

                self.reset_parameters()

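            # Exercises the generator kwarg added to the torch.nn.init functions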
            def reset_linear(self, module, generator):
                init.kaiming_uniform_(
                    module.weight, a=math.sqrt(5), generator=generator
                )

            def reset_parameters(self):
                generator = torch.Generator()
                generator.manual_seed(1)
                self.reset_linear(self.foo, generator)

                generator = torch.Generator()
                generator.manual_seed(2)
                self.reset_linear(self.bar, generator)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)

                generator = torch.Generator()
                generator.manual_seed(3)
                r = torch.empty_like(x)
                r.normal_(0.0, 1.0, generator=generator)

                return x, r

        eager_foo = Foo()

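        # Round-trip the scripted module through an in-memory buffer to check
        # that the serialized code still constructs and seeds generators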
        script_module = torch.jit.script(Foo())
        saved_module = io.BytesIO()
        torch.jit.save(script_module, saved_module)
        saved_module.seek(0)

        loaded_module = torch.jit.load(saved_module)

        self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
        self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)

        try:
            # Run this 3 times to make sure that the generator seed is being set
            # every time forward is called
            for i in range(3):
                x = torch.ones(2, 2)
                out1, r1 = eager_foo(x)
                out2, r2 = loaded_module(x)

                try:
                    self.assertEqual(out1, out2)
                except:  # noqa: B001, E722
                    print(f"Iteration {i}:\n{out1=}\n{out2=}")
                    raise

                try:
                    self.assertEqual(r1, r2)
                except:  # noqa: B001, E722
                    print(f"Iteration {i}:\n{r1=}\n{r2=}")
                    raise
        except:  # noqa: B001, E722
            print(loaded_module.forward.code)
            raise