pytorch/test/package/package_c/test_module.py
2024-08-01 15:44:51 +00:00

24 lines
438 B
Python

# Owner(s): ["oncall: package/deploy"]
import torch
try:
    from torchvision.models import resnet18
except ImportError:
    # torchvision is optional here: when it cannot be imported, the
    # TorchVisionTest class below is simply not defined.
    pass
else:

    class TorchVisionTest(torch.nn.Module):
        """Module holding a torchvision ``resnet18`` as a submodule.

        Only defined when ``torchvision.models.resnet18`` is importable.
        NOTE(review): presumably a fixture for torch.package tests that need
        a third-party (torchvision) dependency — confirm against callers.
        """

        def __init__(self) -> None:
            super().__init__()
            # Registered as a child module; note that forward() never calls it.
            self.tvmod = resnet18()

        def forward(self, x):
            # Pass through a plain Python helper (not a torch op) first,
            # then apply relu to the shifted result.
            summed = a_non_torch_leaf(x, x)
            return torch.relu(summed + 3.0)
def a_non_torch_leaf(a, b):
    """Return ``a + b`` using only the ``+`` operator (no torch calls).

    Works for any operands supporting ``+`` (numbers, sequences, tensors).
    """
    total = a + b
    return total