mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Break graph on manual_seed. (#107594)
Fix: #107187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107594 Approved by: https://github.com/eellison
This commit is contained in:
parent
b7624fc91e
commit
6c28de2437
|
|
@ -1133,6 +1133,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
x = np.random.randn(2, 2)
|
||||
return x - x
|
||||
|
||||
def test_manual_seed(self):
|
||||
@torch.compile
|
||||
def foo():
|
||||
torch.manual_seed(3)
|
||||
return torch.randint(0, 5, (5,))
|
||||
|
||||
self.assertEqual(foo(), foo())
|
||||
self.assertEqual(foo(), foo())
|
||||
|
||||
|
||||
def global_func_with_default_tensor_args(
|
||||
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
|
||||
|
|
|
|||
|
|
@ -1772,14 +1772,6 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x, obj)
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
def test_manual_seed(self):
|
||||
def fn(a, b):
|
||||
x = a + b
|
||||
torch.manual_seed(9000)
|
||||
return x + 1
|
||||
|
||||
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
||||
|
||||
def test_usr_cls_staticmethod(self):
|
||||
class Foo:
|
||||
@staticmethod
|
||||
|
|
@ -2226,13 +2218,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
torch.manual_seed(attention_seed)
|
||||
return (x,)
|
||||
|
||||
x = torch.randn(100, requires_grad=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
# Python code is needed here, since torch.manual_seed graph-breaks.
|
||||
# Refs: https://github.com/pytorch/pytorch/issues/107187
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
|
||||
res = opt_fn(x)
|
||||
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(cnts.op_count, 1)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_is_tensor_like(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
|
|
|||
|
|
@ -264,7 +264,6 @@ inductor_expected_failures_single_sample["cuda"] = {
|
|||
"nn.functional.instance_norm": {f16},
|
||||
"nn.functional.local_response_norm": {f16},
|
||||
"nn.functional.normalize": {f16},
|
||||
"nn.functional.rrelu": {f16, f32, f64},
|
||||
"nn.functional.soft_margin_loss": {f16},
|
||||
"nn.functional.softsign": {f16},
|
||||
"nn.functional.triplet_margin_loss": {f16},
|
||||
|
|
@ -344,15 +343,13 @@ test_skips_or_fails = (
|
|||
)
|
||||
|
||||
|
||||
def wrapper_set_seed(op, *args, **kwargs):
|
||||
"""Wrapper to set seed manually for some functions like dropout
|
||||
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
def wrapper_noop_set_seed(op, *args, **kwargs):
|
||||
return op(*args, **kwargs)
|
||||
|
||||
|
||||
torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
|
||||
torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
|
||||
wrapper_noop_set_seed
|
||||
)
|
||||
|
||||
# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
|
||||
torch._dynamo.variables.torch.tensor_dunder_fns.append(
|
||||
|
|
|
|||
|
|
@ -1303,6 +1303,17 @@ _storage_classes = {
|
|||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||
_tensor_classes: Set[Type] = set()
|
||||
|
||||
################################################################################
|
||||
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
||||
################################################################################
|
||||
|
||||
# needs to be before from .functional import * to avoid circular dependencies
|
||||
from ._compile import _disable_dynamo
|
||||
|
||||
################################################################################
|
||||
# Import miscelaneous torch functions
|
||||
################################################################################
|
||||
|
||||
# If you edit these imports, please update torch/__init__.py.in as well
|
||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
||||
from .serialization import save, load
|
||||
|
|
@ -1367,13 +1378,6 @@ for name in dir(_C._VariableFunctions):
|
|||
|
||||
|
||||
|
||||
################################################################################
|
||||
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
||||
################################################################################
|
||||
|
||||
# needs to be before from .functional import * to avoid circular dependencies
|
||||
from ._compile import _disable_dynamo
|
||||
|
||||
################################################################################
|
||||
# Import interface functions defined in Python
|
||||
################################################################################
|
||||
|
|
|
|||
|
|
@ -452,6 +452,9 @@ class TorchVariable(VariableTracker):
|
|||
elif self.value is torch.nn.Parameter:
|
||||
# https://github.com/pytorch/pytorch/issues/99569
|
||||
unimplemented("torch.nn.Parameter not supported")
|
||||
elif self.value is torch.manual_seed:
|
||||
# https://github.com/pytorch/pytorch/issues/107187
|
||||
unimplemented("torch.manual_seed not supported")
|
||||
if (
|
||||
self.value.__name__ == "get_state"
|
||||
and hasattr(self.value, "__self__")
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ def get_rng_state() -> torch.Tensor:
|
|||
return default_generator.get_state()
|
||||
|
||||
|
||||
@torch._disable_dynamo
|
||||
def manual_seed(seed) -> torch._C.Generator:
|
||||
r"""Sets the seed for generating random numbers. Returns a
|
||||
`torch.Generator` object.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user