mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Move skipIf decorator to class level in test_fx_graph_runnable (#157594)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157594 Approved by: https://github.com/xmfan ghstack dependencies: #157162
This commit is contained in:
parent
ddd74d10fc
commit
e44e05f7ae
|
|
@ -50,6 +50,7 @@ class ToyModel(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
class FxGraphRunnableTest(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
|
@ -92,7 +93,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
)
|
||||
|
||||
# basic tests
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_basic_tensor_add(self):
|
||||
def f(x):
|
||||
return x + 1
|
||||
|
|
@ -100,7 +100,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(f)(torch.randn(4))
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_two_inputs_matmul(self):
|
||||
def f(a, b):
|
||||
return (a @ b).relu()
|
||||
|
|
@ -109,7 +108,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(f)(a, b)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_scalar_multiply(self):
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
|
@ -118,7 +116,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
self._exec_and_verify_payload()
|
||||
|
||||
# testing dynamic shapes
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_dynamic_shapes_run(self):
|
||||
def f(x):
|
||||
return (x @ x.transpose(0, 1)).relu()
|
||||
|
|
@ -130,7 +127,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(f)(a)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_broadcast_add_dynamic(self):
|
||||
def f(x, y):
|
||||
return x + y * 2
|
||||
|
|
@ -143,7 +139,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(f)(x, y)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_toy_model_basic(self):
|
||||
model = ToyModel(input_size=8, hidden_size=16, output_size=4)
|
||||
model.eval() # Set to eval mode to avoid dropout randomness
|
||||
|
|
@ -152,7 +147,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(model)(x)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_toy_model_batch_processing(self):
|
||||
model = ToyModel(input_size=12, hidden_size=24, output_size=6)
|
||||
model.eval()
|
||||
|
|
@ -161,7 +155,6 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(model)(x)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
|
||||
def test_toy_model_dynamic_batch(self):
|
||||
model = ToyModel(input_size=10, hidden_size=20, output_size=5)
|
||||
model.eval()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user