Break graph on manual_seed. (#107594)

Fix: #107187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107594
Approved by: https://github.com/eellison
This commit is contained in:
Yukio Siraichi 2023-08-25 16:48:35 -03:00 committed by PyTorch MergeBot
parent b7624fc91e
commit 6c28de2437
6 changed files with 34 additions and 24 deletions

View File

@ -1133,6 +1133,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
x = np.random.randn(2, 2)
return x - x
def test_manual_seed(self):
    """Seeding inside a compiled function must take effect on every call.

    `torch.manual_seed` graph-breaks, so each invocation re-seeds the RNG
    and the random draw is reproducible across calls.
    """

    @torch.compile
    def seeded_randint():
        torch.manual_seed(3)
        return torch.randint(0, 5, (5,))

    # Four calls, compared pairwise: identical seeds -> identical draws.
    first, second = seeded_randint(), seeded_randint()
    self.assertEqual(first, second)
    self.assertEqual(seeded_randint(), seeded_randint())
def global_func_with_default_tensor_args(
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))

View File

@ -1772,14 +1772,6 @@ class MiscTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x, obj)
self.assertTrue(same(ref, res))
def test_manual_seed(self):
    """A `torch.manual_seed` call mid-trace should not drop the traced ops."""

    def fn(a, b):
        # add, seed (side effect), add-one: three ops expected in the graph.
        partial_sum = a + b
        torch.manual_seed(9000)
        return partial_sum + 1

    torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
def test_usr_cls_staticmethod(self):
class Foo:
@staticmethod
@ -2226,13 +2218,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
torch.manual_seed(attention_seed)
return (x,)
x = torch.randn(100, requires_grad=True)
x = torch.randn(10, requires_grad=True)
ref = fn(x)
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
# Python code is needed here, since torch.manual_seed graph-breaks.
# Refs: https://github.com/pytorch/pytorch/issues/107187
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.op_count, 1)
self.assertEqual(cnts.frame_count, 1)
def test_is_tensor_like(self):
cnts = torch._dynamo.testing.CompileCounter()

View File

@ -264,7 +264,6 @@ inductor_expected_failures_single_sample["cuda"] = {
"nn.functional.instance_norm": {f16},
"nn.functional.local_response_norm": {f16},
"nn.functional.normalize": {f16},
"nn.functional.rrelu": {f16, f32, f64},
"nn.functional.soft_margin_loss": {f16},
"nn.functional.softsign": {f16},
"nn.functional.triplet_margin_loss": {f16},
@ -344,15 +343,13 @@ test_skips_or_fails = (
)
def wrapper_set_seed(op, *args, **kwargs):
"""Wrapper to set seed manually for some functions like dropout
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
"""
torch.manual_seed(42)
def wrapper_noop_set_seed(op, *args, **kwargs):
    """Pass-through wrapper: invoke ``op`` while deliberately NOT seeding the RNG.

    Replacement for the seeding wrapper in contexts where re-seeding would
    graph-break; simply forwards all positional and keyword arguments.
    """
    result = op(*args, **kwargs)
    return result
torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
wrapper_noop_set_seed
)
# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
torch._dynamo.variables.torch.tensor_dunder_fns.append(

View File

@ -1303,6 +1303,17 @@ _storage_classes = {
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes: Set[Type] = set()
################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################
# needs to be before from .functional import * to avoid circular dependencies
from ._compile import _disable_dynamo
################################################################################
# Import miscellaneous torch functions
################################################################################
# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from .serialization import save, load
@ -1367,13 +1378,6 @@ for name in dir(_C._VariableFunctions):
################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################
# needs to be before from .functional import * to avoid circular dependencies
from ._compile import _disable_dynamo
################################################################################
# Import interface functions defined in Python
################################################################################

View File

@ -452,6 +452,9 @@ class TorchVariable(VariableTracker):
elif self.value is torch.nn.Parameter:
# https://github.com/pytorch/pytorch/issues/99569
unimplemented("torch.nn.Parameter not supported")
elif self.value is torch.manual_seed:
# https://github.com/pytorch/pytorch/issues/107187
unimplemented("torch.manual_seed not supported")
if (
self.value.__name__ == "get_state"
and hasattr(self.value, "__self__")

View File

@ -23,6 +23,7 @@ def get_rng_state() -> torch.Tensor:
return default_generator.get_state()
@torch._disable_dynamo
def manual_seed(seed) -> torch._C.Generator:
r"""Sets the seed for generating random numbers. Returns a
`torch.Generator` object.