mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] add test images to repo (#63717)
Summary: This is better than the status quo: * Test doesn't download files from the internet -> faster and more reliable. * Test doesn't leave the git working directory dirty. Rather than using the original images, I've copied some images from the pytorch/vision repo. This will keep the tests in the two repos in sync, while avoiding adding new assets to the vision repo. See https://github.com/pytorch/vision/pull/4176. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63717 Reviewed By: janeyx99 Differential Revision: D30466016 Pulled By: malfet fbshipit-source-id: 2c56d4c11b5c74db1764576bf1c95ce4ae714574
This commit is contained in:
parent
bafd875f74
commit
f1d865346f
BIN
test/onnx/assets/grace_hopper_517x606.jpg
Normal file
BIN
test/onnx/assets/grace_hopper_517x606.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 72 KiB |
BIN
test/onnx/assets/rgb_pytorch.png
Normal file
BIN
test/onnx/assets/rgb_pytorch.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 575 B |
|
|
@ -496,35 +496,20 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
# Only support CPU version, since tracer is not working in GPU RNN.
|
||||
self.run_test(model, (x, model.hidden))
|
||||
|
||||
def get_image_from_url(self, url, size=(300, 200)):
|
||||
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
|
||||
import os
|
||||
from urllib.parse import urlsplit
|
||||
from urllib import request
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torch._utils_internal import get_writable_path
|
||||
|
||||
filename = os.path.basename(urlsplit(url)[2])
|
||||
data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
|
||||
path = os.path.join(data_dir, filename)
|
||||
data = request.urlopen(url, timeout=15).read()
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
image = Image.open(path).convert("RGB")
|
||||
data_dir = os.path.join(os.path.dirname(__file__), "assets")
|
||||
path = os.path.join(data_dir, *rel_path.split("/"))
|
||||
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
|
||||
|
||||
image = image.resize(size, Image.BILINEAR)
|
||||
return transforms.ToTensor()(image)
|
||||
|
||||
to_tensor = transforms.ToTensor()
|
||||
return to_tensor(image)
|
||||
|
||||
def get_test_images(self):
|
||||
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
|
||||
image = self.get_image_from_url(url=image_url, size=(100, 320))
|
||||
|
||||
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
|
||||
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
|
||||
|
||||
return [image], [image2]
|
||||
def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
return ([self.get_image("grace_hopper_517x606.jpg", (100, 320))],
|
||||
[self.get_image("rgb_pytorch.png", (250, 380))])
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@disableScriptTest() # Faster RCNN model is not scriptable
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user