mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D25757691: [pytorch][PR] Run mypy over test/test_utils.py
Test Plan: revert-hammer
Differential Revision:
D25757691 (c86cfcd81d)
Original commit changeset: 145ce3ae532c
fbshipit-source-id: 3dfd68f0c42fc074cde15c6213a630b16e9d8879
This commit is contained in:
parent
e442ac1e3f
commit
e3c56ddde6
3
mypy.ini
3
mypy.ini
|
|
@ -26,8 +26,7 @@ files =
|
|||
test/test_numpy_interop.py,
|
||||
test/test_torch.py,
|
||||
test/test_type_hints.py,
|
||||
test/test_type_info.py,
|
||||
test/test_utils.py
|
||||
test/test_type_info.py
|
||||
|
||||
|
||||
# Minimum version supported - variable annotations were introduced
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import unittest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.cuda
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
import torch.utils.cpp_extension
|
||||
|
|
@ -29,7 +28,7 @@ HAS_CUDA = torch.cuda.is_available()
|
|||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
|
||||
class RandomDatasetMock(torch.utils.data.Dataset):
|
||||
class RandomDatasetMock(object):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)])
|
||||
|
|
@ -191,7 +190,7 @@ class TestCheckpoint(TestCase):
|
|||
b = torch.randn(1, 100, requires_grad=True)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg]
|
||||
checkpoint_sequential(model, 1, a, b)
|
||||
|
||||
def test_checkpoint_sequential_deprecated_no_args(self):
|
||||
class Noop(nn.Module):
|
||||
|
|
@ -201,7 +200,7 @@ class TestCheckpoint(TestCase):
|
|||
model = nn.Sequential(Noop())
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
checkpoint_sequential(model, 1) # type: ignore[call-arg]
|
||||
checkpoint_sequential(model, 1)
|
||||
|
||||
def test_checkpoint_rng_cpu(self):
|
||||
for _ in range(5):
|
||||
|
|
@ -278,7 +277,7 @@ class TestCheckpoint(TestCase):
|
|||
out = checkpoint(run_fn, input_var, input_var2)
|
||||
out[0].sum().backward()
|
||||
|
||||
def run_fn2(tensor1, tensor2):
|
||||
def run_fn(tensor1, tensor2):
|
||||
return tensor1
|
||||
input_var = torch.randn(1, 4, requires_grad=False)
|
||||
input_var2 = torch.randn(1, 4, requires_grad=True)
|
||||
|
|
@ -286,7 +285,7 @@ class TestCheckpoint(TestCase):
|
|||
RuntimeError,
|
||||
r"none of output has requires_grad=True, this checkpoint\(\) is not necessary"
|
||||
):
|
||||
out = checkpoint(run_fn2, input_var, input_var2)
|
||||
out = checkpoint(run_fn, input_var, input_var2)
|
||||
out.sum().backward()
|
||||
|
||||
class TestDataLoader(TestCase):
|
||||
|
|
@ -309,10 +308,7 @@ class TestDataLoader(TestCase):
|
|||
self.assertEqual(x1, x2)
|
||||
|
||||
def test_single_keep(self):
|
||||
# self.dataset is a Tensor here; technically not a valid input because
|
||||
# not a Dataset subclass, but needs to stay working so add ignore's
|
||||
# for type checking with mypy
|
||||
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
|
||||
dataloader = torch.utils.data.DataLoader(self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=0,
|
||||
drop_last=False)
|
||||
|
|
@ -320,7 +316,7 @@ class TestDataLoader(TestCase):
|
|||
self.assertEqual(len(list(dataiter)), 2)
|
||||
|
||||
def test_single_drop(self):
|
||||
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
|
||||
dataloader = torch.utils.data.DataLoader(self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=0,
|
||||
drop_last=True)
|
||||
|
|
@ -329,7 +325,7 @@ class TestDataLoader(TestCase):
|
|||
|
||||
@unittest.skip("FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN")
|
||||
def test_multi_keep(self):
|
||||
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
|
||||
dataloader = torch.utils.data.DataLoader(self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=2,
|
||||
drop_last=False)
|
||||
|
|
@ -337,7 +333,7 @@ class TestDataLoader(TestCase):
|
|||
self.assertEqual(len(list(dataiter)), 2)
|
||||
|
||||
def test_multi_drop(self):
|
||||
dataloader : DataLoader = DataLoader(self.dataset, # type: ignore[arg-type]
|
||||
dataloader = torch.utils.data.DataLoader(self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=2,
|
||||
drop_last=True)
|
||||
|
|
@ -351,7 +347,7 @@ test_dir = os.path.abspath(os.path.dirname(str(__file__)))
|
|||
class TestFFI(TestCase):
|
||||
def test_deprecated(self):
|
||||
with self.assertRaisesRegex(ImportError, "torch.utils.ffi is deprecated. Please use cpp extensions instead."):
|
||||
from torch.utils.ffi import create_extension # type: ignore # noqa: F401
|
||||
from torch.utils.ffi import create_extension # noqa: F401
|
||||
|
||||
|
||||
@unittest.skipIf('SKIP_TEST_BOTTLENECK' in os.environ.keys(), 'SKIP_TEST_BOTTLENECK is set')
|
||||
|
|
@ -368,9 +364,9 @@ class TestBottleneck(TestCase):
|
|||
p.kill()
|
||||
output, err = p.communicate()
|
||||
rc = p.returncode
|
||||
output_str = output.decode("ascii")
|
||||
err_str = err.decode("ascii")
|
||||
return (rc, output_str, err_str)
|
||||
output = output.decode("ascii")
|
||||
err = err.decode("ascii")
|
||||
return (rc, output, err)
|
||||
|
||||
def _run_bottleneck(self, test_file, scriptargs=''):
|
||||
curdir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
|
@ -665,7 +661,7 @@ class TestAssert(TestCase):
|
|||
# data can be passed without errors
|
||||
x = torch.randn(4, 4).fill_(1.0)
|
||||
ms(x)
|
||||
with self.assertRaisesRegex(torch.jit.Error, "foo"): # type: ignore[type-var]
|
||||
with self.assertRaisesRegex(torch.jit.Error, "foo"):
|
||||
ms(torch.tensor([False], dtype=torch.bool))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from .file_baton import FileBaton
|
|||
from ._cpp_extension_versioner import ExtensionVersioner
|
||||
from .hipify import hipify_python
|
||||
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from pkg_resources import packaging # type: ignore
|
||||
|
|
@ -980,7 +980,7 @@ def library_paths(cuda: bool = False) -> List[str]:
|
|||
|
||||
|
||||
def load(name,
|
||||
sources: Union[str, List[str]],
|
||||
sources: List[str],
|
||||
extra_cflags=None,
|
||||
extra_cuda_cflags=None,
|
||||
extra_ldflags=None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user