mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Differential Revision: D45416983

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100388
216 lines
6.7 KiB
Python
216 lines
6.7 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import copy
|
|
from typing import Tuple
|
|
import unittest
|
|
|
|
import torch # noqa: F401
|
|
import torch.nn as nn
|
|
import torch._dynamo as torchdynamo
|
|
from functorch import make_fx
|
|
from functorch.experimental import functionalize
|
|
from torch import Tensor
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
from torch._dynamo.eval_frame import is_dynamo_supported
|
|
|
|
from torch._export.verifier import (
|
|
SpecViolationError,
|
|
Verifier,
|
|
ATenDialectVerifier,
|
|
)
|
|
|
|
|
|
@torch.no_grad()
def capture(f, args):
    """Export ``f`` with TorchDynamo and retrace it into a functionalized graph.

    ``f`` is first exported with ``torchdynamo.export(..., aten_graph=True)``
    on a deep copy of ``args``. The exported graph is then run through an FX
    ``Interpreter`` inside ``preserve_node_meta()``. That callable is wrapped
    with ``functionalize(..., remove='mutations_and_views')`` and finally
    retraced with ``make_fx`` in ``'fake'`` tracing mode.

    The Dynamo config overrides needed for export are applied only for the
    duration of this call via ``torchdynamo.config.patch``. Previously they were
    assigned globally and stayed set for every test that ran afterwards in
    the same process.

    Args:
        f: Callable (typically an ``nn.Module``) to capture.
        args: Tuple of example positional inputs for ``f``.

    Returns:
        The graph module produced by ``make_fx``.
    """
    with torchdynamo.config.patch(
        capture_scalar_outputs=True,
        guard_nn_modules=True,
        dynamic_shapes=True,
        allow_rnn=True,
        verbose=True,
    ):
        torchdynamo.reset()
        # Export on a deep copy of the inputs (presumably so the caller's
        # tensors are not touched by export); make_fx below uses the originals.
        graphmodule, _ = torchdynamo.export(
            f,
            *copy.deepcopy(args),
            aten_graph=True,
        )

        def graph_with_interpreter(*args):
            # Presumably keeps node metadata from the exported graph on the
            # nodes produced by the retrace below.
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(graphmodule).run(*args)

        functionalized_callable = functionalize(
            graph_with_interpreter,
            remove='mutations_and_views',
        )
        gm = make_fx(
            functionalized_callable,
            tracing_mode='fake',
            _allow_non_fake_inputs=True,
        )(*args)
    return gm
|
|
|
|
|
|
class Transpose(nn.Module):
    """Module that swaps two dimensions of its input tensor."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` exchanged."""
        swapped = torch.transpose(x, dim0, dim1)
        return swapped
|
|
|
|
|
|
class Mul(nn.Module):
    """Module computing the elementwise product of two tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        """Return ``input * other`` (equivalently ``torch.mul(input, other)``)."""
        product = input * other
        return product

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        """Return two standard-normal example inputs of shape (3, 2)."""
        shape = (3, 2)
        lhs = torch.randn(shape)
        rhs = torch.randn(shape)
        return lhs, rhs
|
|
|
|
|
|
class ElementwiseAdd(nn.Module):
    """Module computing the elementwise sum of two tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Return ``x + y``."""
        total = x + y
        return total

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        """Return two standard-normal example inputs of shape (1, 3)."""
        shape = (1, 3)
        first = torch.randn(shape)
        second = torch.randn(shape)
        return first, second
|
|
|
|
|
|
class Cat(nn.Module):
    """Module that concatenates all of its positional tensor arguments."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *args: Tensor, dim: int = 0) -> Tensor:
        """Concatenate every tensor in ``args`` along ``dim``.

        Args:
            *args: Tensors to concatenate; per ``torch.cat`` they must agree in
                every dimension except ``dim``.
            dim: Keyword-only dimension to concatenate along (default 0).

        Returns:
            The concatenated tensor.
        """
        # ``dim`` is keyword-only, so every positional argument is a tensor.
        # Concatenate all of them instead of slicing off (and silently
        # dropping) the last one.
        return torch.cat(args, dim)
|
|
|
|
|
|
class FeedForwardBlock(nn.Module):
    """LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout block.

    Args:
        input_dim: Size of the input's last dimension (also the output size).
        hidden_dim: Width of the intermediate projection.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Registration order matters: it fixes the order in which the two
        # Linear layers draw from the RNG at init, and the state_dict key order.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the submodules in pipeline order."""
        pipeline = (
            self.layer_norm,
            self.linear1,
            self.dropout1,
            self.relu,
            self.linear2,
            self.dropout2,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return out
|
|
|
|
|
|
class VerifierTest(TestCase):
    """Checks that ``Verifier`` and ``ATenDialectVerifier`` accept valid graphs
    and raise ``SpecViolationError`` on graphs that violate the spec."""

    def _expect_valid(self, verifier, gm) -> None:
        # The verifier must run without raising, and is_valid must report True.
        verifier(gm)
        self.assertTrue(verifier.is_valid(gm))

    def _expect_violation(self, verifier, gm) -> None:
        # The verifier must raise SpecViolationError, and is_valid must report False.
        with self.assertRaises(SpecViolationError):
            verifier(gm)
        self.assertFalse(verifier.is_valid(gm))

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier(self) -> None:
        egm = capture(ElementwiseAdd(), (torch.randn(100), torch.randn(100)))
        self._expect_valid(Verifier(), egm)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier_call_module(self) -> None:
        # A plain symbolic trace leaves call_module nodes for submodules that
        # are not delegates, which the verifier must reject.
        gm = torch.fx.symbolic_trace(FeedForwardBlock(10, 10))
        self._expect_violation(Verifier(), gm)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_verifier_no_functional(self) -> None:
        egm = capture(ElementwiseAdd(), (torch.randn(100), torch.randn(100)))
        # Swap the functional add for its out= overload; the verifier must
        # reject the resulting non-functional graph.
        add_nodes = [n for n in egm.graph.nodes if n.target == torch.ops.aten.add.Tensor]
        for node in add_nodes:
            node.target = torch.ops.aten.add.out
        self._expect_violation(Verifier(), egm)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_dialect(self) -> None:
        egm = capture(ElementwiseAdd(), (torch.randn(100), torch.randn(100)))
        self._expect_valid(ATenDialectVerifier(), egm)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_wrong_mem_format(self) -> None:
        class ChannelsLastParam(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Parameter(
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
                )

            def forward(self, x):
                return self.a + x

        egm = capture(ChannelsLastParam(), (torch.randn(1, 3, 100, 100),))
        # Put the captured module's tensors in channels_last; the ATen dialect
        # verifier must reject that memory format.
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        self._expect_violation(ATenDialectVerifier(), egm)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_aten_wrong_mem_format_buffer(self) -> None:
        class ChannelsLastBuffer(nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "a",
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
                )

            def forward(self, x):
                return self.a + x

        egm = capture(ChannelsLastBuffer(), (torch.randn(1, 3, 100, 100),))
        # Same as the parameter case, but the channels_last tensor is a buffer.
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        self._expect_violation(ATenDialectVerifier(), egm)

    def test_aten_wrong_op(self) -> None:
        class AddRelu(nn.Module):
            def forward(self, x):
                return torch.ops.aten._add_relu(x, x)

        # The traced graph calls aten._add_relu, which the ATen dialect
        # verifier must reject.
        egm = torch.fx.symbolic_trace(AddRelu())
        self._expect_violation(ATenDialectVerifier(), egm)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand off to PyTorch's shared test runner when executed as a script.
    run_tests()
|