mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
24 lines
438 B
Python
24 lines
438 B
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import torch
|
|
|
|
|
|
try:
    from torchvision.models import resnet18
except ImportError:
    # torchvision is optional: when it cannot be imported, TorchVisionTest
    # is simply not defined in this module.
    pass
else:

    class TorchVisionTest(torch.nn.Module):
        """Module holding a torchvision ResNet-18 submodule.

        ``forward`` does not use ``self.tvmod``. It adds the input to itself
        via ``a_non_torch_leaf``, shifts the result by 3.0 and applies ReLU.
        """

        def __init__(self) -> None:
            super().__init__()
            # Registered as a submodule even though forward never calls it.
            # NOTE(review): presumably kept so the module carries a
            # torchvision dependency for the package tests — confirm.
            self.tvmod = resnet18()

        def forward(self, x):
            """Return ``torch.relu(a_non_torch_leaf(x, x) + 3.0)``."""
            combined = a_non_torch_leaf(x, x)
            shifted = combined + 3.0
            return torch.relu(shifted)
|
|
|
|
|
|
def a_non_torch_leaf(a, b):
    """Return ``a + b``.

    A plain Python helper that makes no torch calls itself; it is invoked by
    ``TorchVisionTest.forward`` in this module.
    """
    total = a + b
    return total
|