mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Break graph on manual_seed. (#107594)
Fix: #107187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107594 Approved by: https://github.com/eellison
This commit is contained in:
parent
b1b9a3646a
commit
6ad5568cbc
|
|
@ -1133,6 +1133,15 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||||
x = np.random.randn(2, 2)
|
x = np.random.randn(2, 2)
|
||||||
return x - x
|
return x - x
|
||||||
|
|
||||||
|
def test_manual_seed(self):
|
||||||
|
@torch.compile
|
||||||
|
def foo():
|
||||||
|
torch.manual_seed(3)
|
||||||
|
return torch.randint(0, 5, (5,))
|
||||||
|
|
||||||
|
self.assertEqual(foo(), foo())
|
||||||
|
self.assertEqual(foo(), foo())
|
||||||
|
|
||||||
|
|
||||||
def global_func_with_default_tensor_args(
|
def global_func_with_default_tensor_args(
|
||||||
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
|
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
|
||||||
|
|
|
||||||
|
|
@ -1772,14 +1772,6 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||||
res = opt_fn(x, obj)
|
res = opt_fn(x, obj)
|
||||||
self.assertTrue(same(ref, res))
|
self.assertTrue(same(ref, res))
|
||||||
|
|
||||||
def test_manual_seed(self):
|
|
||||||
def fn(a, b):
|
|
||||||
x = a + b
|
|
||||||
torch.manual_seed(9000)
|
|
||||||
return x + 1
|
|
||||||
|
|
||||||
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
|
||||||
|
|
||||||
def test_usr_cls_staticmethod(self):
|
def test_usr_cls_staticmethod(self):
|
||||||
class Foo:
|
class Foo:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -2226,13 +2218,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||||
torch.manual_seed(attention_seed)
|
torch.manual_seed(attention_seed)
|
||||||
return (x,)
|
return (x,)
|
||||||
|
|
||||||
x = torch.randn(100, requires_grad=True)
|
x = torch.randn(10, requires_grad=True)
|
||||||
ref = fn(x)
|
ref = fn(x)
|
||||||
|
|
||||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
# Python code is needed here, since torch.manual_seed graph-breaks.
|
||||||
|
# Refs: https://github.com/pytorch/pytorch/issues/107187
|
||||||
|
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
|
||||||
res = opt_fn(x)
|
res = opt_fn(x)
|
||||||
|
|
||||||
self.assertTrue(same(ref, res))
|
self.assertTrue(same(ref, res))
|
||||||
|
self.assertEqual(cnts.op_count, 1)
|
||||||
|
self.assertEqual(cnts.frame_count, 1)
|
||||||
|
|
||||||
def test_is_tensor_like(self):
|
def test_is_tensor_like(self):
|
||||||
cnts = torch._dynamo.testing.CompileCounter()
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
|
|
||||||
|
|
@ -264,7 +264,6 @@ inductor_expected_failures_single_sample["cuda"] = {
|
||||||
"nn.functional.instance_norm": {f16},
|
"nn.functional.instance_norm": {f16},
|
||||||
"nn.functional.local_response_norm": {f16},
|
"nn.functional.local_response_norm": {f16},
|
||||||
"nn.functional.normalize": {f16},
|
"nn.functional.normalize": {f16},
|
||||||
"nn.functional.rrelu": {f16, f32, f64},
|
|
||||||
"nn.functional.soft_margin_loss": {f16},
|
"nn.functional.soft_margin_loss": {f16},
|
||||||
"nn.functional.softsign": {f16},
|
"nn.functional.softsign": {f16},
|
||||||
"nn.functional.triplet_margin_loss": {f16},
|
"nn.functional.triplet_margin_loss": {f16},
|
||||||
|
|
@ -280,7 +279,6 @@ inductor_expected_failures_single_sample["cuda"] = {
|
||||||
"sparse.sampled_addmm": {f32, f64},
|
"sparse.sampled_addmm": {f32, f64},
|
||||||
("std_mean", "unbiased"): {f16},
|
("std_mean", "unbiased"): {f16},
|
||||||
"to_sparse": {f16, f32, f64},
|
"to_sparse": {f16, f32, f64},
|
||||||
"uniform": {f16, f32, f64},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -344,15 +342,13 @@ test_skips_or_fails = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def wrapper_set_seed(op, *args, **kwargs):
|
def wrapper_noop_set_seed(op, *args, **kwargs):
|
||||||
"""Wrapper to set seed manually for some functions like dropout
|
|
||||||
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
|
|
||||||
"""
|
|
||||||
torch.manual_seed(42)
|
|
||||||
return op(*args, **kwargs)
|
return op(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
|
torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
|
||||||
|
wrapper_noop_set_seed
|
||||||
|
)
|
||||||
|
|
||||||
# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
|
# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
|
||||||
torch._dynamo.variables.torch.tensor_dunder_fns.append(
|
torch._dynamo.variables.torch.tensor_dunder_fns.append(
|
||||||
|
|
|
||||||
|
|
@ -1303,6 +1303,17 @@ _storage_classes = {
|
||||||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||||
_tensor_classes: Set[Type] = set()
|
_tensor_classes: Set[Type] = set()
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
# needs to be before from .functional import * to avoid circular dependencies
|
||||||
|
from ._compile import _disable_dynamo
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Import miscelaneous torch functions
|
||||||
|
################################################################################
|
||||||
|
|
||||||
# If you edit these imports, please update torch/__init__.py.in as well
|
# If you edit these imports, please update torch/__init__.py.in as well
|
||||||
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
|
||||||
from .serialization import save, load
|
from .serialization import save, load
|
||||||
|
|
@ -1367,13 +1378,6 @@ for name in dir(_C._VariableFunctions):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
# needs to be before from .functional import * to avoid circular dependencies
|
|
||||||
from ._compile import _disable_dynamo
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Import interface functions defined in Python
|
# Import interface functions defined in Python
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
|
||||||
|
|
@ -452,6 +452,9 @@ class TorchVariable(VariableTracker):
|
||||||
elif self.value is torch.nn.Parameter:
|
elif self.value is torch.nn.Parameter:
|
||||||
# https://github.com/pytorch/pytorch/issues/99569
|
# https://github.com/pytorch/pytorch/issues/99569
|
||||||
unimplemented("torch.nn.Parameter not supported")
|
unimplemented("torch.nn.Parameter not supported")
|
||||||
|
elif self.value is torch.manual_seed:
|
||||||
|
# https://github.com/pytorch/pytorch/issues/107187
|
||||||
|
unimplemented("torch.manual_seed not supported")
|
||||||
if (
|
if (
|
||||||
self.value.__name__ == "get_state"
|
self.value.__name__ == "get_state"
|
||||||
and hasattr(self.value, "__self__")
|
and hasattr(self.value, "__self__")
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ def get_rng_state() -> torch.Tensor:
|
||||||
return default_generator.get_state()
|
return default_generator.get_state()
|
||||||
|
|
||||||
|
|
||||||
|
@torch._disable_dynamo
|
||||||
def manual_seed(seed) -> torch._C.Generator:
|
def manual_seed(seed) -> torch._C.Generator:
|
||||||
r"""Sets the seed for generating random numbers. Returns a
|
r"""Sets the seed for generating random numbers. Returns a
|
||||||
`torch.Generator` object.
|
`torch.Generator` object.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user