# Owner(s): ["high priority"]

import sys
import os
import contextlib
import http
import io
import re
import shutil
import random
import subprocess
import tempfile
import textwrap
import unittest
import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.data import DataLoader
import torch.cuda
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.utils.cpp_extension
import torch.hub as hub
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
from torch.testing._internal.common_utils import has_breakpad, load_tests, retry, IS_SANDCASTLE, IS_WINDOWS, TEST_WITH_ASAN
from urllib.error import URLError

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

HAS_CUDA = torch.cuda.is_available()


from torch.testing._internal.common_utils import TestCase, run_tests


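# RandomDatasetMock draws each sample from both torch's RNG (torch.rand) and
# Python's `random` module, so the DataLoader seeding test below exercises the
# seeding of both generators in worker processes.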
class RandomDatasetMock(torch.utils.data.Dataset):

    def __getitem__(self, index):
        return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])

    def __len__(self):
        return 1000


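# torch.utils.checkpoint trades compute for memory: activations inside a
# checkpointed segment are not saved, and the segment's forward is rerun
# during backward. These tests verify that outputs and gradients match the
# uncheckpointed model and that RNG state and non-tensor arguments are handled.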
class TestCheckpoint(TestCase):

    # This runs checkpoint_sequential on each of the nets in
    # module_lists_to_compare, and compares them against the uncheckpointed model.
    # To compare, it checks outputs as well as input gradients and parameter gradients
    def _check_checkpoint_sequential(
        self,
        model,
        module_lists_to_compare,
        num_chunks,
        input,
    ):

        # not checkpointed
        out = model(input)
        out_not_checkpointed = out.detach().clone()
        model.zero_grad()
        out.sum().backward()
        grad_not_checkpointed = {
            name: param.grad.detach().clone()
            for name, param in model.named_parameters()
        }
        input_grad_not_checkpointed = input.grad.detach().clone()
        for model_to_compare in module_lists_to_compare:
            # checkpointed model by passing list of modules
            detached = input.detach()
            detached.requires_grad = True

            # pass list of modules to checkpoint
            out = checkpoint_sequential(model_to_compare, num_chunks, detached)
            out_checkpointed = out.detach().clone()
            model.zero_grad()
            out.sum().backward()
            grad_checkpointed = {
                name: param.grad.detach().clone()
                for name, param in model.named_parameters()
            }
            input_grad_checkpointed = detached.grad.detach().clone()
            # compare outputs as well as the gradients of input and parameters
            self.assertEqual(out_checkpointed, out_not_checkpointed)
            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
            for name in grad_checkpointed:
                self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])

    # Test whether checkpoint is being triggered or not. For this, we check
    # the number of times forward pass happens
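    # (With 2 chunks over 10 modules, checkpoint_sequential recomputes the
    # checkpointed segment's forward during backward, so the counters of the
    # first half reach 2, while the final, non-checkpointed segment stays at 1,
    # matching the assertions below.)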
    def test_checkpoint_trigger(self):

        class Net(nn.Module):

            def __init__(self):
                super(Net, self).__init__()
                self.counter = 0

            def forward(self, input_var):
                self.counter += 1
                return input_var

        # checkpointed
        modules = [Net() for _ in range(10)]
        for m in modules:
            self.assertEqual(m.counter, 0)
        input_var = torch.randn(3, 4, requires_grad=True)
        out = checkpoint_sequential(modules, 2, input_var)
        for m in modules:
            self.assertEqual(m.counter, 1)
        out.sum().backward()
        for m in modules[:(len(modules) // 2)]:
            self.assertEqual(m.counter, 2)
        for m in modules[(len(modules) // 2):]:
            self.assertEqual(m.counter, 1)

    def test_checkpoint_valid(self):
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU()
        )

        input_var = torch.randn(1, 100, requires_grad=True)

        # checkpointed
        chunks = 2
        modules = list(model.children())
        out = checkpoint_sequential(modules, chunks, input_var)
        with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
            torch.autograd.grad(
                outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
            )

    def test_checkpoint(self):
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU()
        )

        # Compare uncheckpointed model with its checkpointed counterparts.
        # In addition to running checkpoint_sequential on the nn.Sequential
        # instance, we also run the function on the list of modules within
        # the module.
        self._check_checkpoint_sequential(
            model,
            [list(model.children()), model],
            2,
            torch.randn(1, 100, requires_grad=True)
        )

    def test_checkpoint_module_list(self):
        class ModuleListNet(nn.Module):
            def __init__(self):
                super(ModuleListNet, self).__init__()
                module_list = [
                    nn.Linear(100, 50),
                    nn.ReLU(),
                    nn.Linear(50, 20),
                    nn.ReLU(),
                    nn.Linear(20, 5),
                    nn.ReLU(),
                ]
                self.module_list = nn.ModuleList(module_list)

            def forward(self, input):
                for layer in self.module_list:
                    input = layer(input)
                return input

        model = ModuleListNet()

        # Compare uncheckpointed model with its checkpointed counterparts.
        self._check_checkpoint_sequential(
            model,
            [list(model.module_list.children()), model.module_list],
            2,
            torch.randn(1, 100, requires_grad=True),
        )

    def test_checkpoint_sequential_deprecated_multiple_args(self):
        class Two(nn.Module):
            def forward(self, a, b):
                return a, b

        model = nn.Sequential(Two())
        a = torch.randn(1, 100, requires_grad=True)
        b = torch.randn(1, 100, requires_grad=True)

        with self.assertRaises(TypeError):
            checkpoint_sequential(model, 1, a, b)  # type: ignore[call-arg]

    def test_checkpoint_sequential_deprecated_no_args(self):
        class Noop(nn.Module):
            def forward(self):
                pass

        model = nn.Sequential(Noop())

        with self.assertRaises(TypeError):
            checkpoint_sequential(model, 1)  # type: ignore[call-arg]

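    # checkpoint() saves and restores the RNG state around the recomputed
    # forward, so stochastic layers like Dropout draw the same mask during
    # recomputation; the gradients should therefore match a plain run started
    # from the same RNG state.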
    def test_checkpoint_rng_cpu(self):
        for _ in range(5):
            inp = torch.randn(20000, device='cpu').requires_grad_()
            phase1 = torch.nn.Dropout()
            phase2 = torch.nn.Dropout()

            def run_fn(input):
                return phase2(input)

            state = torch.get_rng_state()

            out = phase1(inp)
            out = checkpoint(run_fn, out)
            out.sum().backward()
            grad_with_checkpointing = inp.grad

            torch.set_rng_state(state)

            inp.grad = None

            out = phase1(inp)
            out = run_fn(out)
            out.sum().backward()
            grad_no_checkpointing = inp.grad

            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)

    @unittest.skipIf(not HAS_CUDA, 'No CUDA')
    def test_checkpoint_rng_cuda(self):
        for _ in range(5):
            inp = torch.randn(20000, device='cuda').requires_grad_()
            phase1 = torch.nn.Dropout()
            phase2 = torch.nn.Dropout()

            def run_fn(input):
                return phase2(input)

            state = torch.cuda.get_rng_state()

            out = phase1(inp)
            out = checkpoint(run_fn, out)
            out.sum().backward()
            grad_with_checkpointing = inp.grad

            torch.cuda.set_rng_state(state)

            inp.grad = None

            out = phase1(inp)
            out = run_fn(out)
            out.sum().backward()
            grad_no_checkpointing = inp.grad

            self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)

    def test_checkpoint_non_tensor(self):

        def run_fn(tensor1, tensor2):
            if tensor2 is None:
                return tensor1
            return tensor1 + tensor2

        input_var = torch.randn(1, 100, requires_grad=True)
        out = checkpoint(run_fn, input_var, None)
        out.sum().backward()

    def test_checkpoint_non_tensor_inputs_outputs(self):
        def foo(t1, t2, scale, t3):
            t4 = t1 + t2 * t3
            t5 = t1 * t2 + t3
            t4 *= scale
            t5 *= scale
            return scale, t4, None, True, t5, "bar", t1

        t1 = torch.rand(10, requires_grad=True)
        t2 = torch.rand(10, requires_grad=True)
        t3 = torch.rand(10)
        scale = random.randint(0, 10)
        res = checkpoint(foo, t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        res[1].sum().backward(retain_graph=True)
        res[4].sum().backward(retain_graph=True)
        res[6].sum().backward()
        with self.assertRaisesRegex(RuntimeError, "Trying to backward through the graph a second time"):
            res[6].sum().backward()
        t1_grad = t1.grad
        t2_grad = t2.grad

        # Reset grads, run without checkpoint and validate we receive same grads.
        t1.grad = None
        t2.grad = None
        res = foo(t1, t2, scale, t3)
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertEqual(t1.grad, t1_grad)
        self.assertEqual(t2.grad, t2_grad)

    def test_checkpoint_no_tensors(self):
        def foo(t1, t2, scale, t3):
            t4 = t1 + t2 * t3
            t5 = t1 * t2 + t3
            t4 *= scale
            t5 *= scale
            return scale, t4, None, True, t5, "bar", t1

        t1 = random.random()
        t2 = random.random()
        t3 = random.random()
        scale = random.randint(0, 10)
        res = checkpoint(foo, t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

    def test_checkpoint_partial_grad(self):
        def run_fn(tensor1, tensor2):
            # tensor 2 is used for other application logic
            return tensor1, tensor2
        input_var = torch.randn(1, 4, requires_grad=True)
        input_var2 = torch.randn(1, 4, requires_grad=False)
        out = checkpoint(run_fn, input_var, input_var2)
        out[0].sum().backward()

        def run_fn2(tensor1, tensor2):
            return tensor1
        input_var = torch.randn(1, 4, requires_grad=False)
        input_var2 = torch.randn(1, 4, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError,
            r"none of output has requires_grad=True, this checkpoint\(\) is not necessary"
        ):
            out = checkpoint(run_fn2, input_var, input_var2)
            out.sum().backward()


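# TestDataLoaderUtils.setUp builds a 5-sample "dataset" (a plain tensor of
# shape (5, 3, 3, 2)) with batch_size=3, so iterating yields 2 batches when
# drop_last=False and only the single full batch when drop_last=True.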
class TestDataLoaderUtils(TestCase):
    def setUp(self):
        self.dataset = torch.randn(5, 3, 3, 2)
        self.batch_size = 3

    def test_random_seed(self):
        def run():
            dataloader = torch.utils.data.DataLoader(RandomDatasetMock(),
                                                     batch_size=2,
                                                     num_workers=4,
                                                     shuffle=True)
            return next(iter(dataloader))

        torch.manual_seed(2018)
        x1 = run()
        torch.manual_seed(2018)
        x2 = run()
        self.assertEqual(x1, x2)

    def test_single_keep(self):
        # self.dataset is a Tensor here; technically not a valid input because
        # it is not a Dataset subclass, but it needs to stay working, so add
        # ignores for type checking with mypy
        dataloader: DataLoader = DataLoader(self.dataset,  # type: ignore[arg-type]
                                            batch_size=self.batch_size,
                                            num_workers=0,
                                            drop_last=False)
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 2)

    def test_single_drop(self):
        dataloader: DataLoader = DataLoader(self.dataset,  # type: ignore[arg-type]
                                            batch_size=self.batch_size,
                                            num_workers=0,
                                            drop_last=True)
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 1)

    @unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN")
    def test_multi_keep(self):
        dataloader: DataLoader = DataLoader(self.dataset,  # type: ignore[arg-type]
                                            batch_size=self.batch_size,
                                            num_workers=2,
                                            drop_last=False)
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 2)

    def test_multi_drop(self):
        dataloader: DataLoader = DataLoader(self.dataset,  # type: ignore[arg-type]
                                            batch_size=self.batch_size,
                                            num_workers=2,
                                            drop_last=True)
        dataiter = iter(dataloader)
        self.assertEqual(len(list(dataiter)), 1)


test_dir = os.path.abspath(os.path.dirname(str(__file__)))


class TestFFI(TestCase):
    def test_deprecated(self):
        with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
            from torch.utils.ffi import create_extension  # type: ignore[attr-defined]  # noqa: F401


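# torch.utils.bottleneck is exercised end-to-end: each test launches
# `python -m torch.utils.bottleneck <script>` in a subprocess and greps the
# captured stdout for the expected report sections (Environment Summary,
# cProfile output, autograd profiler output, and the CUDA-mode note).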
@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
class TestBottleneck(TestCase):
    def _run(self, command, timeout=30):
        """Returns (return-code, stdout, stderr)"""
        import subprocess

        p = subprocess.Popen(command, stdout=subprocess.PIPE,  # noqa: P204
                             stderr=subprocess.PIPE, shell=True)
        try:
            output, err = p.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            p.kill()
            output, err = p.communicate()
        rc = p.returncode
        output_str = output.decode("ascii")
        err_str = err.decode("ascii")
        return (rc, output_str, err_str)

    def _run_bottleneck(self, test_file, scriptargs=''):
        curdir = os.path.dirname(os.path.abspath(__file__))
        filepath = '{}/{}'.format(curdir, test_file)
        if scriptargs != '':
            scriptargs = ' {}'.format(scriptargs)
        rc, out, err = self._run(
            '{} -m torch.utils.bottleneck {}{}'.format(sys.executable, filepath, scriptargs))
        return rc, out, err

    def _check_run_args(self):
        # Check that this fails due to missing args
        rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py')
        self.assertEqual(rc, 2, atol=0, rtol=0, msg=self._fail_msg('Missing args should error', out + err))

        # This should succeed
        rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py', '--foo foo --bar bar')
        self.assertEqual(rc, 0, atol=0, rtol=0, msg=self._fail_msg('Should pass args to script', out + err))

    def _fail_msg(self, msg, output):
        return '{}, output was:\n{}'.format(msg, output)

    def _check_environment_summary(self, output):
        results = re.search('Environment Summary', output)
        self.assertIsNotNone(results, self._fail_msg('Should have Environment Summary', output))

        # Up to five lines away from the heading, there should be the version number
        results = re.search(r'Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+', output)
        self.assertIsNotNone(results, self._fail_msg('Should have PyTorch version', output))

    def _check_cprof_summary(self, output):
        results = re.search('cProfile output', output)
        self.assertIsNotNone(results, self._fail_msg('Should have cProfile output', output))

        # This assumes that after the cProfile output section we have
        # the autograd profiler output
        results = re.search(r'cProfile output.*(\n.*){6,50}\n.*autograd profiler output', output)
        self.assertIsNotNone(results, self._fail_msg(
            'Distance between cProfile and autograd prof out not in [6, 50] lines', output))

    def _check_autograd_summary(self, output):
        results = re.search('autograd profiler output', output)
        self.assertIsNotNone(results, self._fail_msg('Should have autograd profiler output', output))

        # This assumes that after the autograd profiler output is the end of the
        # output.
        results = re.search(r'autograd profiler output.*(\n.*){6,100}', output)
        self.assertIsNotNone(results, self._fail_msg(
            'Distance between autograd prof output and end of output not in [6, 100] lines', output))

    def _check_cuda(self, output):
        if HAS_CUDA:
            results = re.search('CUDA mode', output)
            self.assertIsNotNone(results, self._fail_msg('Should tell users CUDA', output))
        else:
            results = re.search('CUDA mode', output)
            self.assertIsNone(results, self._fail_msg('Should not tell users about CUDA', output))

    @unittest.skipIf(HAS_CUDA, 'CPU-only test')
    def test_bottleneck_cpu_only(self):
        rc, out, err = self._run_bottleneck('bottleneck_test/test.py')
        self.assertEqual(rc, 0, msg='Run failed with\n{}'.format(err))

        self._check_run_args()
        self._check_environment_summary(out)
        self._check_autograd_summary(out)
        self._check_cprof_summary(out)
        self._check_cuda(out)

    @unittest.skipIf(not HAS_CUDA, 'No CUDA')
    def test_bottleneck_cuda(self):
        rc, out, err = self._run_bottleneck('bottleneck_test/test_cuda.py')
        self.assertEqual(rc, 0, msg='Run failed with\n{}'.format(err))

        self._check_run_args()
        self._check_environment_summary(out)
        self._check_autograd_summary(out)
        self._check_cprof_summary(out)
        self._check_cuda(out)


from torch.utils.collect_env import get_pretty_env_info


class TestCollectEnv(TestCase):
    def test_smoke(self):
        info_output = get_pretty_env_info()
        self.assertTrue(info_output.count('\n') >= 17)


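# _prepare_onnx_paddings converts PyTorch's pad list (given last dimension
# first as [left, right, left, right, ...]) into ONNX's layout, which lists
# the begin pads for every dimension followed by the end pads; dimensions not
# covered by the pad list get 0, as the expected value below shows.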
class TestONNXUtils(TestCase):
    def test_prepare_onnx_paddings(self):
        sizes = [2, 3, 4]
        pad = [1, 2, 3, 4]
        paddings = _prepare_onnx_paddings(len(sizes), pad)
        self.assertEqual(paddings, [0, 3, 1, 0, 4, 2])

    def test_check_onnx_broadcast(self):

        def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail):
            broadcast = True
            fail = False
            try:
                broadcast = check_onnx_broadcast(dims1, dims2)
            except ValueError:
                fail = True
            self.assertEqual(broadcast, expect_broadcast)
            self.assertEqual(fail, expect_fail)

        # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1
        dims1 = [3, 4]
        dims2 = [2, 3, 4]
        try_check_onnx_broadcast(dims1, dims2, True, True)

        # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1
        dims1 = [3, 4]
        dims2 = [1, 1, 1]
        try_check_onnx_broadcast(dims1, dims2, True, False)

        # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1
        dims1 = [1, 1]
        dims2 = [1]
        try_check_onnx_broadcast(dims1, dims2, True, False)

        # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2
        dims1 = [2, 3, 4]
        dims2 = [3, 4]
        try_check_onnx_broadcast(dims1, dims2, True, False)

        # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2
        dims1 = [2, 3, 4]
        dims2 = [1, 4]
        try_check_onnx_broadcast(dims1, dims2, True, True)

        # Case 6, check the equal case, no broadcast
        dims1 = [3, 4]
        dims2 = [3, 4]
        try_check_onnx_broadcast(dims1, dims2, False, False)

        # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2
        dims1 = [3, 4]
        dims2 = [1, 4]
        try_check_onnx_broadcast(dims1, dims2, True, True)

        # Case 8, check the case when len(dims1) == len(dims2) and numel(dims2) == 1
        dims1 = [3, 4]
        dims2 = [1, 1]
        try_check_onnx_broadcast(dims1, dims2, True, False)


def sum_of_state_dict(state_dict):
    s = 0
    for _, v in state_dict.items():
        s += v.sum()
    return s


SUM_OF_HUB_EXAMPLE = 431080
TORCHHUB_EXAMPLE_RELEASE_URL = 'https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones'


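# These hub tests download from GitHub at runtime (ailzhang/torchhub_example
# and the release asset above), so they are skipped on Sandcastle and wrapped
# in @retry(URLError, ...) to tolerate transient network failures; each loaded
# checkpoint is verified by summing its state_dict against SUM_OF_HUB_EXAMPLE.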
@unittest.skipIf(IS_SANDCASTLE, 'Sandcastle cannot ping external')
class TestHub(TestCase):
    @retry(URLError, tries=3)
    def test_load_from_github(self):
        hub_model = hub.load(
            'ailzhang/torchhub_example',
            'mnist',
            source='github',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    def test_load_from_local_dir(self):
        local_dir = hub._get_cache_or_reload(
            'ailzhang/torchhub_example', force_reload=False)
        hub_model = hub.load(
            local_dir,
            'mnist',
            source='local',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    def test_load_from_branch(self):
        hub_model = hub.load(
            'ailzhang/torchhub_example:ci/test_slash',
            'mnist',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    def test_set_dir(self):
        temp_dir = tempfile.gettempdir()
        hub.set_dir(temp_dir)
        hub_model = hub.load(
            'ailzhang/torchhub_example',
            'mnist',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)
        assert os.path.exists(temp_dir + '/ailzhang_torchhub_example_master')
        shutil.rmtree(temp_dir + '/ailzhang_torchhub_example_master')

    @retry(URLError, tries=3)
    def test_list_entrypoints(self):
        entry_lists = hub.list('ailzhang/torchhub_example', force_reload=True)
        self.assertObjectIn('mnist', entry_lists)

    @retry(URLError, tries=3)
    def test_download_url_to_file(self):
        temp_file = os.path.join(tempfile.gettempdir(), 'temp')
        hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, temp_file, progress=False)
        loaded_state = torch.load(temp_file)
        self.assertEqual(sum_of_state_dict(loaded_state),
                         SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    @retry(http.client.RemoteDisconnected, tries=3)
    def test_load_state_dict_from_url(self):
        loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
        self.assertEqual(sum_of_state_dict(loaded_state),
                         SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    def test_load_zip_checkpoint(self):
        hub_model = hub.load(
            'ailzhang/torchhub_example',
            'mnist_zip',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)

    # Test the default zipfile serialization format produced by >=1.6 release.
    @retry(URLError, tries=3)
    def test_load_zip_1_6_checkpoint(self):
        hub_model = hub.load(
            'ailzhang/torchhub_example',
            'mnist_zip_1_6',
            pretrained=True,
            verbose=False)
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                         SUM_OF_HUB_EXAMPLE)

    def test_hub_dir(self):
        with tempfile.TemporaryDirectory('hub_dir') as dirname:
            torch.hub.set_dir(dirname)
            self.assertEqual(torch.hub.get_dir(), dirname)

    @retry(URLError, tries=3)
    def test_hub_parse_repo_info(self):
        # If the branch is specified we just parse the input and return
        self.assertEqual(
            torch.hub._parse_repo_info('a/b:c'),
            ('a', 'b', 'c')
        )
        # For torchvision, the default branch is main
        self.assertEqual(
            torch.hub._parse_repo_info('pytorch/vision'),
            ('pytorch', 'vision', 'main')
        )
        # For the torchhub_example repo, the default branch is still master
        self.assertEqual(
            torch.hub._parse_repo_info('ailzhang/torchhub_example'),
            ('ailzhang', 'torchhub_example', 'master')
        )

    @retry(URLError, tries=3)
    def test_load_state_dict_from_url_with_name(self):
        with tempfile.TemporaryDirectory('hub_dir') as dirname:
            torch.hub.set_dir(dirname)
            file_name = 'test_file'
            loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name)
            self.assertTrue(os.path.exists(os.path.join(dirname, 'checkpoints', file_name)))
            self.assertEqual(sum_of_state_dict(loaded_state),
                             SUM_OF_HUB_EXAMPLE)

    @retry(URLError, tries=3)
    def test_load_commit_from_forked_repo(self):
        with self.assertRaisesRegex(
                ValueError,
                'If it\'s a commit from a forked repo'):
            model = torch.hub.load('pytorch/vision:4e2c216', 'resnet18', force_reload=True)


class TestHipify(TestCase):
    def test_import_hipify(self):
        from torch.utils.hipify import hipify_python  # noqa: F401


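# torch._assert accepts either a Python bool or a boolean tensor and, unlike a
# bare `assert` statement, can be compiled by torch.jit.script; both
# properties are checked below.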
class TestAssert(TestCase):
    def test_assert_true(self):
        # verify assertions work as expected
        # bool argument
        torch._assert(True, "foo")
        with self.assertRaisesRegex(AssertionError, "bar"):
            torch._assert(False, "bar")
        # tensor argument
        torch._assert(torch.tensor([True], dtype=torch.bool), "foo")
        with self.assertRaisesRegex(AssertionError, "bar"):
            torch._assert(torch.tensor([False], dtype=torch.bool), "bar")

    def test_assert_scriptable(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.sum() > 0, "foo")
                return x

        m = M()
        # scriptable
        ms = torch.jit.script(m)
        # data can be passed without errors
        x = torch.randn(4, 4).fill_(1.0)
        ms(x)
        with self.assertRaisesRegex(torch.jit.Error, "foo"):
            ms(torch.tensor([False], dtype=torch.bool))


class TestCrashHandler(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "ASAN disables the crash handler's signal handler")
    @unittest.skipIf(not has_breakpad(), "Built without breakpad")
    def test_python_exception_writing(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            torch.utils._crash_handler.enable_minidumps(temp_dir)
            torch.utils._crash_handler.enable_minidumps_on_exceptions()

            files = os.listdir(temp_dir)
            self.assertEqual(len(files), 0)

            f = io.StringIO()
            with contextlib.redirect_stderr(f):
                try:
                    @torch.jit.script
                    def x(i: int):
                        return i + "2"  # type: ignore[operator]
                except RuntimeError as e:
                    pass

            files = os.listdir(temp_dir)
            self.assertEqual(len(files), 1)
            self.assertTrue(files[0].endswith(".dmp"))
            torch.utils._crash_handler.disable_minidumps()


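# With is_python_module=False and is_standalone=True, cpp_extension.load()
# compiles the source into a standalone executable and returns its path rather
# than importing an extension module; the test then runs the binary and checks
# that it prints the 3x3 identity matrix.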
@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
class TestStandaloneCPPJIT(TestCase):
    def test_load_standalone(self):
        build_dir = tempfile.mkdtemp()
        try:
            src_path = os.path.join(build_dir, "main.cpp")
            src = textwrap.dedent("""\
                #include <iostream>
                #include <torch/torch.h>
                int main() {
                    auto x = torch::eye(3);
                    std::cout << x << std::endl;
                }
            """)
            with open(src_path, "wt") as f:
                f.write(src)

            exec_path = torch.utils.cpp_extension.load(
                "standalone_load_test",
                src_path,
                build_directory=build_dir,
                is_python_module=False,
                is_standalone=True,
            )

            ext = ".exe" if IS_WINDOWS else ""
            self.assertEqual(
                exec_path,
                os.path.join(build_dir, f"standalone_load_test{ext}")
            )

            for shell in [True, False]:
                r = subprocess.run(
                    [exec_path],
                    shell=shell,
                    stdout=subprocess.PIPE,
                )
                self.assertEqual(r.returncode, 0)
                self.assertEqual(
                    # Windows prints "\r\n" for newlines.
                    textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"),
                    textwrap.dedent("""\
                    1 0 0
                    0 1 0
                    0 0 1
                    [ CPUFloatType{3,3} ]
                    """)
                )

        finally:
            shutil.rmtree(build_dir)


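# torch._register_device_module attaches a backend module as an attribute of
# `torch` (torch.xpu below). Built-in device names such as 'cuda' cannot be
# re-registered, unknown device types are rejected, and a name can only be
# registered once.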
class DummyXPUModule(object):
    @staticmethod
    def is_available():
        return True


class TestExtensionUtils(TestCase):
    def test_external_module_register(self):
        # Built-in module
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module('cuda', torch.cuda)

        # Wrong device type
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module('dummmy', DummyXPUModule)

        with self.assertRaises(AttributeError):
            torch.xpu.is_available()  # type: ignore[attr-defined]

        torch._register_device_module('xpu', DummyXPUModule)

        torch.xpu.is_available()  # type: ignore[attr-defined]

        # No support for overriding an already-registered module
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module('xpu', DummyXPUModule)


if __name__ == '__main__':
    run_tests()