[dynamo] Move skipIf decorator to class level in test_fx_graph_runnable (#157594)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157594
Approved by: https://github.com/xmfan
ghstack dependencies: #157162
This commit is contained in:
Sandeep Narendranath Karjala 2025-07-09 17:14:03 -07:00 committed by PyTorch MergeBot
parent ddd74d10fc
commit e44e05f7ae

View File

@ -50,6 +50,7 @@ class ToyModel(torch.nn.Module):
return x
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
class FxGraphRunnableTest(TestCase):
def setUp(self):
super().setUp()
@ -92,7 +93,6 @@ class FxGraphRunnableTest(TestCase):
)
# basic tests
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_basic_tensor_add(self):
def f(x):
return x + 1
@ -100,7 +100,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(f)(torch.randn(4))
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_two_inputs_matmul(self):
def f(a, b):
return (a @ b).relu()
@ -109,7 +108,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(f)(a, b)
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_scalar_multiply(self):
def f(x):
return x * 2
@ -118,7 +116,6 @@ class FxGraphRunnableTest(TestCase):
self._exec_and_verify_payload()
# testing dynamic shapes
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_dynamic_shapes_run(self):
def f(x):
return (x @ x.transpose(0, 1)).relu()
@ -130,7 +127,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(f)(a)
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_broadcast_add_dynamic(self):
def f(x, y):
return x + y * 2
@ -143,7 +139,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(f)(x, y)
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_toy_model_basic(self):
model = ToyModel(input_size=8, hidden_size=16, output_size=4)
model.eval() # Set to eval mode to avoid dropout randomness
@ -152,7 +147,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(model)(x)
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_toy_model_batch_processing(self):
model = ToyModel(input_size=12, hidden_size=24, output_size=6)
model.eval()
@ -161,7 +155,6 @@ class FxGraphRunnableTest(TestCase):
torch.compile(model)(x)
self._exec_and_verify_payload()
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle")
def test_toy_model_dynamic_batch(self):
model = ToyModel(input_size=10, hidden_size=20, output_size=5)
model.eval()