mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Remove remaining global set_default_dtype calls from tests (#107246)"
This reverts commit aa8ea1d787.
Reverted https://github.com/pytorch/pytorch/pull/107246 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107246#issuecomment-1693838522))
This commit is contained in:
parent
c68d0a7042
commit
161ea463e6
|
|
@ -19,6 +19,8 @@ from torch.testing._internal.common_utils import dtype2prec_DONTUSE
|
||||||
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
|
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.double)
|
||||||
|
|
||||||
NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
|
NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
|
||||||
|
|
||||||
# batched grad doesn't support data parallel
|
# batched grad doesn't support data parallel
|
||||||
|
|
@ -38,11 +40,11 @@ class TestDataParallel(TestCase):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x * self.t_rg + self.t_not_rg
|
return x * self.t_rg + self.t_not_rg
|
||||||
|
|
||||||
m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
|
m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
|
||||||
self.assertTrue(m.t_rg.requires_grad)
|
self.assertTrue(m.t_rg.requires_grad)
|
||||||
|
|
||||||
dpm = nn.DataParallel(m, [0, 1])
|
dpm = nn.DataParallel(m, [0, 1])
|
||||||
inp = torch.randn(2, 100, device='cuda', dtype=torch.double)
|
inp = torch.randn(2, 100, device='cuda')
|
||||||
|
|
||||||
def fn(t):
|
def fn(t):
|
||||||
return dpm(inp)
|
return dpm(inp)
|
||||||
|
|
@ -510,11 +512,11 @@ class TestDataParallel(TestCase):
|
||||||
|
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
|
||||||
def test_scatter_cpu(self):
|
def test_scatter_cpu(self):
|
||||||
self._test_scatter(torch.randn((4, 4), dtype=torch.double))
|
self._test_scatter(torch.randn((4, 4)))
|
||||||
|
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
|
||||||
def test_scatter_gpu(self):
|
def test_scatter_gpu(self):
|
||||||
self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
|
self._test_scatter(torch.randn((4, 4)).cuda())
|
||||||
|
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
|
||||||
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
|
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
|
||||||
|
|
@ -537,8 +539,8 @@ class TestDataParallel(TestCase):
|
||||||
|
|
||||||
def _test_gather(self, output_device):
|
def _test_gather(self, output_device):
|
||||||
inputs = (
|
inputs = (
|
||||||
torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
|
torch.randn(2, 4, device='cuda:0', requires_grad=True),
|
||||||
torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
|
torch.randn(2, 4, device='cuda:1', requires_grad=True),
|
||||||
)
|
)
|
||||||
result = dp.gather(inputs, output_device)
|
result = dp.gather(inputs, output_device)
|
||||||
self.assertEqual(result.size(), torch.Size([4, 4]))
|
self.assertEqual(result.size(), torch.Size([4, 4]))
|
||||||
|
|
@ -548,7 +550,7 @@ class TestDataParallel(TestCase):
|
||||||
self.assertEqual(result.get_device(), output_device)
|
self.assertEqual(result.get_device(), output_device)
|
||||||
else:
|
else:
|
||||||
self.assertFalse(result.is_cuda)
|
self.assertFalse(result.is_cuda)
|
||||||
grad = torch.randn((4, 4), dtype=torch.double)
|
grad = torch.randn((4, 4))
|
||||||
if output_device != -1:
|
if output_device != -1:
|
||||||
grad = grad.cuda(output_device)
|
grad = grad.cuda(output_device)
|
||||||
result.backward(grad)
|
result.backward(grad)
|
||||||
|
|
@ -558,8 +560,8 @@ class TestDataParallel(TestCase):
|
||||||
|
|
||||||
# test scalar inputs, should stack into a vector in this case
|
# test scalar inputs, should stack into a vector in this case
|
||||||
inputs = (
|
inputs = (
|
||||||
torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
|
torch.randn((), device='cuda:0', requires_grad=True),
|
||||||
torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
|
torch.randn((), device='cuda:1', requires_grad=True),
|
||||||
)
|
)
|
||||||
result = dp.gather(inputs, output_device)
|
result = dp.gather(inputs, output_device)
|
||||||
self.assertEqual(result.size(), torch.Size([2]))
|
self.assertEqual(result.size(), torch.Size([2]))
|
||||||
|
|
@ -569,7 +571,7 @@ class TestDataParallel(TestCase):
|
||||||
self.assertEqual(result.get_device(), output_device)
|
self.assertEqual(result.get_device(), output_device)
|
||||||
else:
|
else:
|
||||||
self.assertFalse(result.is_cuda)
|
self.assertFalse(result.is_cuda)
|
||||||
grad = torch.randn(2, dtype=torch.double)
|
grad = torch.randn(2)
|
||||||
if output_device != -1:
|
if output_device != -1:
|
||||||
grad = grad.cuda(output_device)
|
grad = grad.cuda(output_device)
|
||||||
result.backward(grad)
|
result.backward(grad)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1995,8 +1995,8 @@ class TestFrozenOptimizations(JitTestCase):
|
||||||
torch.set_default_dtype(torch.double)
|
torch.set_default_dtype(torch.double)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
torch.set_default_dtype(self.default_dtype)
|
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
torch.set_default_dtype(self.default_dtype)
|
||||||
|
|
||||||
def test_conv_bn_folding(self):
|
def test_conv_bn_folding(self):
|
||||||
conv_bias = [True, False]
|
conv_bias = [True, False]
|
||||||
|
|
|
||||||
|
|
@ -1881,11 +1881,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
x = torch.randn(2, 3, 4).to(torch.int)
|
x = torch.randn(2, 3, 4).to(torch.int)
|
||||||
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
||||||
|
|
||||||
with common_utils.set_default_dtype(torch.float):
|
prev_default = torch.get_default_dtype()
|
||||||
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
|
||||||
|
|
||||||
with common_utils.set_default_dtype(torch.double):
|
torch.set_default_dtype(torch.float)
|
||||||
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.double)
|
||||||
|
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
|
||||||
|
|
||||||
|
torch.set_default_dtype(prev_default)
|
||||||
|
|
||||||
# In scripting x, y do not carry shape and dtype info.
|
# In scripting x, y do not carry shape and dtype info.
|
||||||
# The following test only works when onnx shape inference is enabled.
|
# The following test only works when onnx shape inference is enabled.
|
||||||
|
|
@ -1901,20 +1905,23 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
x = torch.randn(2, 3, 4).to(torch.int)
|
x = torch.randn(2, 3, 4).to(torch.int)
|
||||||
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
|
||||||
|
|
||||||
|
prev_default = torch.get_default_dtype()
|
||||||
|
|
||||||
# 1. x,y are int, and output is float.
|
# 1. x,y are int, and output is float.
|
||||||
# This can be handled by the default case, where both are cast to float.
|
# This can be handled by the default case, where both are cast to float.
|
||||||
# It works even if type of x, y are unknown.
|
# It works even if type of x, y are unknown.
|
||||||
with common_utils.set_default_dtype(torch.float):
|
torch.set_default_dtype(torch.float)
|
||||||
self.run_test(torch.jit.script(DivModule()), (x, y))
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
||||||
|
|
||||||
# 2. x,y are int, and output is double.
|
# 2. x,y are int, and output is double.
|
||||||
# This can be handled by the default case, where both are cast to double.
|
# This can be handled by the default case, where both are cast to double.
|
||||||
# It works even if type of x, y are unknown.
|
# It works even if type of x, y are unknown.
|
||||||
with common_utils.set_default_dtype(torch.double):
|
torch.set_default_dtype(torch.double)
|
||||||
self.run_test(torch.jit.script(DivModule()), (x, y))
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
||||||
|
|
||||||
# 3. x is int, y is double, and output is double.
|
# 3. x is int, y is double, and output is double.
|
||||||
# This can only be handled when both type of x and y are known.
|
# This can only be handled when both type of x and y are known.
|
||||||
|
torch.set_default_dtype(prev_default)
|
||||||
x = torch.randn(2, 3, 4).to(torch.int)
|
x = torch.randn(2, 3, 4).to(torch.int)
|
||||||
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
|
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
|
||||||
self.run_test(torch.jit.script(DivModule()), (x, y))
|
self.run_test(torch.jit.script(DivModule()), (x, y))
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from torch.testing._internal.common_device_type import (
|
||||||
dtypes,
|
dtypes,
|
||||||
onlyCPU,
|
onlyCPU,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||||
from torch.testing._internal.common_dtype import complex_types
|
from torch.testing._internal.common_dtype import complex_types
|
||||||
|
|
||||||
devices = (torch.device('cpu'), torch.device('cuda:0'))
|
devices = (torch.device('cpu'), torch.device('cuda:0'))
|
||||||
|
|
@ -21,8 +21,10 @@ class TestComplexTensor(TestCase):
|
||||||
@dtypes(torch.float32, torch.float64)
|
@dtypes(torch.float32, torch.float64)
|
||||||
def test_dtype_inference(self, device, dtype):
|
def test_dtype_inference(self, device, dtype):
|
||||||
# issue: https://github.com/pytorch/pytorch/issues/36834
|
# issue: https://github.com/pytorch/pytorch/issues/36834
|
||||||
with set_default_dtype(dtype):
|
default_dtype = torch.get_default_dtype()
|
||||||
x = torch.tensor([3., 3. + 5.j], device=device)
|
torch.set_default_dtype(dtype)
|
||||||
|
x = torch.tensor([3., 3. + 5.j], device=device)
|
||||||
|
torch.set_default_dtype(default_dtype)
|
||||||
self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
|
self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
# Owner(s): ["module: cpp"]
|
# Owner(s): ["module: cpp"]
|
||||||
|
|
||||||
|
import torch
|
||||||
|
# NN tests use double as the default dtype
|
||||||
|
torch.set_default_dtype(torch.double)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# Owner(s): ["module: nvfuser"]
|
# Owner(s): ["module: nvfuser"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.common_utils import set_default_dtype, TestCase
|
from torch.testing._internal.common_utils import set_default_dtype
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from _nvfuser.test_torchscript import * # noqa: F403,F401
|
from _nvfuser.test_torchscript import * # noqa: F403,F401
|
||||||
|
|
@ -13,5 +13,4 @@ except ImportError:
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# TODO: Update nvfuser to work with float default dtype
|
# TODO: Update nvfuser to work with float default dtype
|
||||||
with set_default_dtype(torch.double):
|
with set_default_dtype(torch.double):
|
||||||
TestCase._avoid_default_dtype_check = True
|
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,8 @@ from torch.distributions.binomial import Binomial
|
||||||
import torch.backends.opt_einsum as opt_einsum
|
import torch.backends.opt_einsum as opt_einsum
|
||||||
|
|
||||||
# Protects against includes accidentally setting the default dtype
|
# Protects against includes accidentally setting the default dtype
|
||||||
|
# NOTE: jit_metaprogramming_utils sets the default dtype to double!
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
assert torch.get_default_dtype() is torch.float32
|
assert torch.get_default_dtype() is torch.float32
|
||||||
|
|
||||||
if TEST_SCIPY:
|
if TEST_SCIPY:
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ from torch.testing._internal.common_utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
# Protects against includes accidentally setting the default dtype
|
# Protects against includes accidentally setting the default dtype
|
||||||
|
# NOTE: jit_metaprogramming_utils sets the default dtype to double!
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
assert torch.get_default_dtype() is torch.float32
|
assert torch.get_default_dtype() is torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,9 @@ from torch.testing._internal.common_methods_invocations import op_db
|
||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, ops, OpDTypes)
|
(instantiate_device_type_tests, ops, OpDTypes)
|
||||||
|
|
||||||
|
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
|
# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
|
||||||
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
|
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
|
||||||
# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33
|
# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,9 @@ from torch.testing._internal.custom_op_db import custom_op_db
|
||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, ops, OpDTypes)
|
(instantiate_device_type_tests, ops, OpDTypes)
|
||||||
|
|
||||||
|
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
# gradcheck requires double precision
|
# gradcheck requires double precision
|
||||||
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||||
allowed_dtypes=[torch.double, torch.cdouble])
|
allowed_dtypes=[torch.double, torch.cdouble])
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,9 @@ from torch.testing._internal.jit_metaprogramming_utils import create_script_fn,
|
||||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
# variant testing is only done with torch.float and torch.cfloat to avoid
|
# variant testing is only done with torch.float and torch.cfloat to avoid
|
||||||
# excessive test times and maximize signal to noise ratio
|
# excessive test times and maximize signal to noise ratio
|
||||||
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
|
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
||||||
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
|
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
|
||||||
set_default_dtype, set_default_tensor_type,
|
|
||||||
TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo)
|
TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo)
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
|
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
|
||||||
|
|
@ -1966,36 +1965,37 @@ class TestTensorCreation(TestCase):
|
||||||
# TODO: this test should be updated
|
# TODO: this test should be updated
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_constructor_dtypes(self, device):
|
def test_constructor_dtypes(self, device):
|
||||||
|
default_type = torch.tensor([]).type()
|
||||||
self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype())
|
self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype())
|
||||||
|
|
||||||
self.assertIs(torch.uint8, torch.ByteTensor.dtype)
|
self.assertIs(torch.uint8, torch.ByteTensor.dtype)
|
||||||
self.assertIs(torch.float32, torch.FloatTensor.dtype)
|
self.assertIs(torch.float32, torch.FloatTensor.dtype)
|
||||||
self.assertIs(torch.float64, torch.DoubleTensor.dtype)
|
self.assertIs(torch.float64, torch.DoubleTensor.dtype)
|
||||||
|
|
||||||
with set_default_tensor_type('torch.FloatTensor'):
|
torch.set_default_tensor_type('torch.FloatTensor')
|
||||||
self.assertIs(torch.float32, torch.get_default_dtype())
|
self.assertIs(torch.float32, torch.get_default_dtype())
|
||||||
self.assertIs(torch.FloatStorage, torch.Storage)
|
self.assertIs(torch.FloatStorage, torch.Storage)
|
||||||
|
|
||||||
# only floating-point types are supported as the default type
|
# only floating-point types are supported as the default type
|
||||||
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
|
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
|
||||||
|
|
||||||
with set_default_dtype(torch.float64):
|
torch.set_default_dtype(torch.float64)
|
||||||
self.assertIs(torch.float64, torch.get_default_dtype())
|
self.assertIs(torch.float64, torch.get_default_dtype())
|
||||||
self.assertIs(torch.DoubleStorage, torch.Storage)
|
self.assertIs(torch.DoubleStorage, torch.Storage)
|
||||||
|
|
||||||
with set_default_tensor_type(torch.FloatTensor):
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
self.assertIs(torch.float32, torch.get_default_dtype())
|
self.assertIs(torch.float32, torch.get_default_dtype())
|
||||||
self.assertIs(torch.FloatStorage, torch.Storage)
|
self.assertIs(torch.FloatStorage, torch.Storage)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
with set_default_tensor_type(torch.cuda.FloatTensor):
|
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||||
self.assertIs(torch.float32, torch.get_default_dtype())
|
self.assertIs(torch.float32, torch.get_default_dtype())
|
||||||
self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
|
self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
|
||||||
self.assertIs(torch.cuda.FloatStorage, torch.Storage)
|
self.assertIs(torch.cuda.FloatStorage, torch.Storage)
|
||||||
|
|
||||||
with set_default_dtype(torch.float64):
|
torch.set_default_dtype(torch.float64)
|
||||||
self.assertIs(torch.float64, torch.get_default_dtype())
|
self.assertIs(torch.float64, torch.get_default_dtype())
|
||||||
self.assertIs(torch.cuda.DoubleStorage, torch.Storage)
|
self.assertIs(torch.cuda.DoubleStorage, torch.Storage)
|
||||||
|
|
||||||
# don't allow passing dtype to set_default_tensor_type
|
# don't allow passing dtype to set_default_tensor_type
|
||||||
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32))
|
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32))
|
||||||
|
|
@ -2008,11 +2008,12 @@ class TestTensorCreation(TestCase):
|
||||||
torch.float,
|
torch.float,
|
||||||
torch.double,
|
torch.double,
|
||||||
torch.bfloat16):
|
torch.bfloat16):
|
||||||
with set_default_dtype(t):
|
torch.set_default_dtype(t)
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
self.assertRaises(TypeError, lambda: torch.set_default_dtype(t))
|
self.assertRaises(TypeError, lambda: torch.set_default_dtype(t))
|
||||||
|
|
||||||
|
torch.set_default_tensor_type(default_type)
|
||||||
|
|
||||||
# TODO: this test should be updated
|
# TODO: this test should be updated
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_constructor_device_legacy(self, device):
|
def test_constructor_device_legacy(self, device):
|
||||||
|
|
@ -2048,10 +2049,14 @@ class TestTensorCreation(TestCase):
|
||||||
self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu'))
|
self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu'))
|
||||||
self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu'))
|
self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu'))
|
||||||
|
|
||||||
with set_default_tensor_type(torch.cuda.FloatTensor):
|
default_type = torch.Tensor().type()
|
||||||
self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
|
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||||
self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
|
self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
|
||||||
self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
|
self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
|
||||||
|
self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
|
||||||
|
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||||
|
torch.set_default_tensor_type(default_type)
|
||||||
|
|
||||||
x = torch.randn((3,), device='cuda')
|
x = torch.randn((3,), device='cuda')
|
||||||
self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
|
self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
|
||||||
self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
|
self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
|
||||||
|
|
@ -2153,6 +2158,8 @@ class TestTensorCreation(TestCase):
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_tensor_factory_type_inference(self, device):
|
def test_tensor_factory_type_inference(self, device):
|
||||||
def test_inference(default_dtype):
|
def test_inference(default_dtype):
|
||||||
|
saved_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(default_dtype)
|
||||||
default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128
|
default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128
|
||||||
self.assertIs(default_dtype, torch.tensor(()).dtype)
|
self.assertIs(default_dtype, torch.tensor(()).dtype)
|
||||||
self.assertIs(default_dtype, torch.tensor(5.).dtype)
|
self.assertIs(default_dtype, torch.tensor(5.).dtype)
|
||||||
|
|
@ -2174,10 +2181,10 @@ class TestTensorCreation(TestCase):
|
||||||
self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
|
self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
|
||||||
self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
|
self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
|
||||||
self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
|
self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
|
||||||
|
torch.set_default_dtype(saved_dtype)
|
||||||
|
|
||||||
for dtype in [torch.float64, torch.float32]:
|
test_inference(torch.float64)
|
||||||
with set_default_dtype(dtype):
|
test_inference(torch.float32)
|
||||||
test_inference(dtype)
|
|
||||||
|
|
||||||
# TODO: this test should be updated
|
# TODO: this test should be updated
|
||||||
@suppress_warnings
|
@suppress_warnings
|
||||||
|
|
@ -2464,6 +2471,8 @@ class TestTensorCreation(TestCase):
|
||||||
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
|
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
def test_arange_inference(self, device):
|
def test_arange_inference(self, device):
|
||||||
|
saved_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
# end only
|
# end only
|
||||||
self.assertIs(torch.float32, torch.arange(1.).dtype)
|
self.assertIs(torch.float32, torch.arange(1.).dtype)
|
||||||
self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
|
self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
|
||||||
|
|
@ -2492,6 +2501,7 @@ class TestTensorCreation(TestCase):
|
||||||
torch.arange(torch.tensor(1),
|
torch.arange(torch.tensor(1),
|
||||||
torch.tensor(3),
|
torch.tensor(3),
|
||||||
torch.tensor(1, dtype=torch.int16)).dtype)
|
torch.tensor(1, dtype=torch.int16)).dtype)
|
||||||
|
torch.set_default_dtype(saved_dtype)
|
||||||
|
|
||||||
# cannot call storage() on meta tensor
|
# cannot call storage() on meta tensor
|
||||||
@skipMeta
|
@skipMeta
|
||||||
|
|
@ -2808,24 +2818,28 @@ class TestTensorCreation(TestCase):
|
||||||
|
|
||||||
@onlyCUDA
|
@onlyCUDA
|
||||||
def test_tensor_factory_gpu_type_inference(self, device):
|
def test_tensor_factory_gpu_type_inference(self, device):
|
||||||
with set_default_tensor_type(torch.cuda.DoubleTensor):
|
saved_type = torch.tensor([]).type()
|
||||||
with set_default_dtype(torch.float32):
|
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
|
||||||
self.assertIs(torch.float32, torch.tensor(0.).dtype)
|
torch.set_default_dtype(torch.float32)
|
||||||
self.assertEqual(torch.device(device), torch.tensor(0.).device)
|
self.assertIs(torch.float32, torch.tensor(0.).dtype)
|
||||||
with set_default_dtype(torch.float64):
|
self.assertEqual(torch.device(device), torch.tensor(0.).device)
|
||||||
self.assertIs(torch.float64, torch.tensor(0.).dtype)
|
torch.set_default_dtype(torch.float64)
|
||||||
self.assertEqual(torch.device(device), torch.tensor(0.).device)
|
self.assertIs(torch.float64, torch.tensor(0.).dtype)
|
||||||
|
self.assertEqual(torch.device(device), torch.tensor(0.).device)
|
||||||
|
torch.set_default_tensor_type(saved_type)
|
||||||
|
|
||||||
@onlyCUDA
|
@onlyCUDA
|
||||||
def test_tensor_factory_gpu_type(self, device):
|
def test_tensor_factory_gpu_type(self, device):
|
||||||
with set_default_tensor_type(torch.cuda.FloatTensor):
|
saved_type = torch.tensor([]).type()
|
||||||
x = torch.zeros((5, 5))
|
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||||
self.assertIs(torch.float32, x.dtype)
|
x = torch.zeros((5, 5))
|
||||||
self.assertTrue(x.is_cuda)
|
self.assertIs(torch.float32, x.dtype)
|
||||||
with set_default_tensor_type(torch.cuda.DoubleTensor):
|
self.assertTrue(x.is_cuda)
|
||||||
x = torch.zeros((5, 5))
|
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
|
||||||
self.assertIs(torch.float64, x.dtype)
|
x = torch.zeros((5, 5))
|
||||||
self.assertTrue(x.is_cuda)
|
self.assertIs(torch.float64, x.dtype)
|
||||||
|
self.assertTrue(x.is_cuda)
|
||||||
|
torch.set_default_tensor_type(saved_type)
|
||||||
|
|
||||||
@skipCPUIf(True, 'compares device with cpu')
|
@skipCPUIf(True, 'compares device with cpu')
|
||||||
@dtypes(torch.int, torch.long, torch.float, torch.double)
|
@dtypes(torch.int, torch.long, torch.float, torch.double)
|
||||||
|
|
@ -3067,23 +3081,27 @@ class TestTensorCreation(TestCase):
|
||||||
def test_full_inference(self, device, dtype):
|
def test_full_inference(self, device, dtype):
|
||||||
size = (2, 2)
|
size = (2, 2)
|
||||||
|
|
||||||
with set_default_dtype(dtype):
|
prev_default = torch.get_default_dtype()
|
||||||
# Tests bool fill value inference
|
torch.set_default_dtype(dtype)
|
||||||
t = torch.full(size, True)
|
|
||||||
self.assertEqual(t.dtype, torch.bool)
|
|
||||||
|
|
||||||
# Tests integer fill value inference
|
# Tests bool fill value inference
|
||||||
t = torch.full(size, 1)
|
t = torch.full(size, True)
|
||||||
self.assertEqual(t.dtype, torch.long)
|
self.assertEqual(t.dtype, torch.bool)
|
||||||
|
|
||||||
# Tests float fill value inference
|
# Tests integer fill value inference
|
||||||
t = torch.full(size, 1.)
|
t = torch.full(size, 1)
|
||||||
self.assertEqual(t.dtype, dtype)
|
self.assertEqual(t.dtype, torch.long)
|
||||||
|
|
||||||
# Tests complex inference
|
# Tests float fill value inference
|
||||||
t = torch.full(size, (1 + 1j))
|
t = torch.full(size, 1.)
|
||||||
ctype = torch.complex128 if dtype is torch.double else torch.complex64
|
self.assertEqual(t.dtype, dtype)
|
||||||
self.assertEqual(t.dtype, ctype)
|
|
||||||
|
# Tests complex inference
|
||||||
|
t = torch.full(size, (1 + 1j))
|
||||||
|
ctype = torch.complex128 if dtype is torch.double else torch.complex64
|
||||||
|
self.assertEqual(t.dtype, ctype)
|
||||||
|
|
||||||
|
torch.set_default_dtype(prev_default)
|
||||||
|
|
||||||
def test_full_out(self, device):
|
def test_full_out(self, device):
|
||||||
size = (5,)
|
size = (5,)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||||
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
|
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
|
||||||
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
|
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
|
||||||
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
|
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
|
||||||
TEST_WITH_CROSSREF, skipIfTorchDynamo, set_default_dtype,
|
TEST_WITH_CROSSREF, skipIfTorchDynamo,
|
||||||
skipCUDAMemoryLeakCheckIf, BytesIOContext,
|
skipCUDAMemoryLeakCheckIf, BytesIOContext,
|
||||||
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
|
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
|
||||||
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
|
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
|
||||||
|
|
@ -6010,7 +6010,6 @@ class TestTorch(TestCase):
|
||||||
self.assertEqual(added, -tensor)
|
self.assertEqual(added, -tensor)
|
||||||
|
|
||||||
@skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>")
|
@skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>")
|
||||||
@set_default_dtype(torch.double)
|
|
||||||
def test_index_add_correctness(self):
|
def test_index_add_correctness(self):
|
||||||
# Check whether index_add can get correct result when
|
# Check whether index_add can get correct result when
|
||||||
# alpha is 1, and dtype of index is torch.long,
|
# alpha is 1, and dtype of index is torch.long,
|
||||||
|
|
@ -7274,21 +7273,21 @@ tensor([4.0000+0.j, inf+0.j, 1.5000+infj, -inf+4.j, 0.0000+0.j, nan+infj
|
||||||
self.assertExpectedInline(str(y), expected_str)
|
self.assertExpectedInline(str(y), expected_str)
|
||||||
|
|
||||||
# test dtype
|
# test dtype
|
||||||
with set_default_dtype(torch.float):
|
torch.set_default_dtype(torch.float)
|
||||||
x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
|
x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
|
||||||
self.assertEqual(x.__repr__(), str(x))
|
self.assertEqual(x.__repr__(), str(x))
|
||||||
expected_str = '''\
|
expected_str = '''\
|
||||||
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
|
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
|
||||||
inf], dtype=torch.float64)'''
|
inf], dtype=torch.float64)'''
|
||||||
self.assertExpectedInline(str(x), expected_str)
|
self.assertExpectedInline(str(x), expected_str)
|
||||||
|
|
||||||
# test changing default dtype
|
# test changing default dtype
|
||||||
with set_default_dtype(torch.float64):
|
torch.set_default_dtype(torch.float64)
|
||||||
self.assertEqual(x.__repr__(), str(x))
|
self.assertEqual(x.__repr__(), str(x))
|
||||||
expected_str = '''\
|
expected_str = '''\
|
||||||
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
|
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
|
||||||
inf])'''
|
inf])'''
|
||||||
self.assertExpectedInline(str(x), expected_str)
|
self.assertExpectedInline(str(x), expected_str)
|
||||||
|
|
||||||
# test summary
|
# test summary
|
||||||
x = torch.zeros(10000)
|
x = torch.zeros(10000)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Owner(s): ["module: typing"]
|
# Owner(s): ["module: typing"]
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests, set_default_dtype
|
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
|
|
@ -38,6 +38,7 @@ class TestDTypeInfo(TestCase):
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||||
def test_finfo(self):
|
def test_finfo(self):
|
||||||
|
initial_default_type = torch.get_default_dtype()
|
||||||
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
|
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
|
||||||
x = torch.zeros((2, 2), dtype=dtype)
|
x = torch.zeros((2, 2), dtype=dtype)
|
||||||
xinfo = torch.finfo(x.dtype)
|
xinfo = torch.finfo(x.dtype)
|
||||||
|
|
@ -51,8 +52,8 @@ class TestDTypeInfo(TestCase):
|
||||||
self.assertEqual(xinfo.resolution, xninfo.resolution)
|
self.assertEqual(xinfo.resolution, xninfo.resolution)
|
||||||
self.assertEqual(xinfo.dtype, xninfo.dtype)
|
self.assertEqual(xinfo.dtype, xninfo.dtype)
|
||||||
if not dtype.is_complex:
|
if not dtype.is_complex:
|
||||||
with set_default_dtype(dtype):
|
torch.set_default_dtype(dtype)
|
||||||
self.assertEqual(torch.finfo(dtype), torch.finfo())
|
self.assertEqual(torch.finfo(dtype), torch.finfo())
|
||||||
|
|
||||||
# Special test case for BFloat16 type
|
# Special test case for BFloat16 type
|
||||||
x = torch.zeros((2, 2), dtype=torch.bfloat16)
|
x = torch.zeros((2, 2), dtype=torch.bfloat16)
|
||||||
|
|
@ -65,8 +66,11 @@ class TestDTypeInfo(TestCase):
|
||||||
self.assertEqual(xinfo.tiny, xinfo.smallest_normal)
|
self.assertEqual(xinfo.tiny, xinfo.smallest_normal)
|
||||||
self.assertEqual(xinfo.resolution, 0.01)
|
self.assertEqual(xinfo.resolution, 0.01)
|
||||||
self.assertEqual(xinfo.dtype, "bfloat16")
|
self.assertEqual(xinfo.dtype, "bfloat16")
|
||||||
with set_default_dtype(x.dtype):
|
torch.set_default_dtype(x.dtype)
|
||||||
self.assertEqual(torch.finfo(x.dtype), torch.finfo())
|
self.assertEqual(torch.finfo(x.dtype), torch.finfo())
|
||||||
|
|
||||||
|
# Restore the default type to ensure that the test has no side effect
|
||||||
|
torch.set_default_dtype(initial_default_type)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -1785,14 +1785,14 @@ new_module_tests = [
|
||||||
dict(
|
dict(
|
||||||
fullname='EmbeddingBag_sparse',
|
fullname='EmbeddingBag_sparse',
|
||||||
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
|
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
|
||||||
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
|
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)',
|
||||||
input_fn=lambda: torch.randperm(2).repeat(1, 2),
|
input_fn=lambda: torch.randperm(2).repeat(1, 2),
|
||||||
check_gradgrad=False,
|
check_gradgrad=False,
|
||||||
has_sparse_gradients=True,
|
has_sparse_gradients=True,
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
|
constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
|
||||||
cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
|
cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)',
|
||||||
input_fn=lambda: torch.randperm(2).repeat(1, 2),
|
input_fn=lambda: torch.randperm(2).repeat(1, 2),
|
||||||
fullname='Embedding_sparse',
|
fullname='Embedding_sparse',
|
||||||
check_gradgrad=False,
|
check_gradgrad=False,
|
||||||
|
|
@ -3168,7 +3168,7 @@ criterion_tests = [
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
module_name='MSELoss',
|
module_name='MSELoss',
|
||||||
input_fn=lambda: torch.rand((2, 3, 4, 5), dtype=torch.double),
|
input_size=(2, 3, 4, 5),
|
||||||
target_fn=lambda: torch.randn((2, 3, 4, 5), dtype=torch.double, requires_grad=True),
|
target_fn=lambda: torch.randn((2, 3, 4, 5), dtype=torch.double, requires_grad=True),
|
||||||
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
|
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
|
||||||
if get_reduction(m) == 'mean' else 1)),
|
if get_reduction(m) == 'mean' else 1)),
|
||||||
|
|
@ -3314,9 +3314,9 @@ criterion_tests = [
|
||||||
dict(
|
dict(
|
||||||
module_name='MultiMarginLoss',
|
module_name='MultiMarginLoss',
|
||||||
constructor_args=(1, 1., torch.rand(10, dtype=torch.double)),
|
constructor_args=(1, 1., torch.rand(10, dtype=torch.double)),
|
||||||
cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10).to(torch::kFloat64))',
|
cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))',
|
||||||
legacy_constructor_args=(1, torch.rand(10, dtype=torch.double)),
|
legacy_constructor_args=(1, torch.rand(10, dtype=torch.double)),
|
||||||
input_fn=lambda: torch.rand(5, 10, dtype=torch.double),
|
input_size=(5, 10),
|
||||||
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
|
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
|
||||||
reference_fn=lambda i, t, m:
|
reference_fn=lambda i, t, m:
|
||||||
multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
|
multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
|
||||||
|
|
@ -3403,7 +3403,7 @@ criterion_tests = [
|
||||||
dict(
|
dict(
|
||||||
module_name='BCEWithLogitsLoss',
|
module_name='BCEWithLogitsLoss',
|
||||||
constructor_args=(torch.rand(10, dtype=torch.double),),
|
constructor_args=(torch.rand(10, dtype=torch.double),),
|
||||||
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10).to(torch::kFloat64))',
|
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))',
|
||||||
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
|
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
|
||||||
target_fn=lambda: torch.randn(15, 10).gt(0).to(torch.get_default_dtype()),
|
target_fn=lambda: torch.randn(15, 10).gt(0).to(torch.get_default_dtype()),
|
||||||
desc='weights',
|
desc='weights',
|
||||||
|
|
@ -3412,7 +3412,7 @@ criterion_tests = [
|
||||||
dict(
|
dict(
|
||||||
module_name='BCEWithLogitsLoss',
|
module_name='BCEWithLogitsLoss',
|
||||||
constructor_args=(torch.rand((), dtype=torch.double),),
|
constructor_args=(torch.rand((), dtype=torch.double),),
|
||||||
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}).to(torch::kFloat64))',
|
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))',
|
||||||
input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
|
input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
|
||||||
target_fn=lambda: torch.randn(()).gt(0).to(torch.get_default_dtype()),
|
target_fn=lambda: torch.randn(()).gt(0).to(torch.get_default_dtype()),
|
||||||
desc='scalar_weights',
|
desc='scalar_weights',
|
||||||
|
|
@ -3826,7 +3826,7 @@ criterion_tests = [
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
module_name='MSELoss',
|
module_name='MSELoss',
|
||||||
input_fn=lambda: torch.rand((), dtype=torch.double),
|
input_size=(),
|
||||||
target_fn=lambda: torch.randn((), requires_grad=True, dtype=torch.double),
|
target_fn=lambda: torch.randn((), requires_grad=True, dtype=torch.double),
|
||||||
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
|
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
|
||||||
(i.numel() if get_reduction(m) == 'mean' else 1)),
|
(i.numel() if get_reduction(m) == 'mean' else 1)),
|
||||||
|
|
|
||||||
|
|
@ -1648,15 +1648,6 @@ def set_default_dtype(dtype):
|
||||||
finally:
|
finally:
|
||||||
torch.set_default_dtype(saved_dtype)
|
torch.set_default_dtype(saved_dtype)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def set_default_tensor_type(tensor_type):
|
|
||||||
saved_tensor_type = torch.tensor([]).type()
|
|
||||||
torch.set_default_tensor_type(tensor_type)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
torch.set_default_tensor_type(saved_tensor_type)
|
|
||||||
|
|
||||||
def iter_indices(tensor):
|
def iter_indices(tensor):
|
||||||
if tensor.dim() == 0:
|
if tensor.dim() == 0:
|
||||||
return range(0)
|
return range(0)
|
||||||
|
|
@ -2245,8 +2236,6 @@ class TestCase(expecttest.TestCase):
|
||||||
_precision: float = 0
|
_precision: float = 0
|
||||||
_rel_tol: float = 0
|
_rel_tol: float = 0
|
||||||
|
|
||||||
_avoid_default_dtype_check: bool = False
|
|
||||||
|
|
||||||
# checker to early terminate test suite if unrecoverable failure occurs.
|
# checker to early terminate test suite if unrecoverable failure occurs.
|
||||||
def _should_stop_test_suite(self):
|
def _should_stop_test_suite(self):
|
||||||
if torch.cuda.is_initialized():
|
if torch.cuda.is_initialized():
|
||||||
|
|
@ -2563,9 +2552,6 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
|
||||||
# decorator to disable the invariant checks.
|
# decorator to disable the invariant checks.
|
||||||
torch.sparse.check_sparse_tensor_invariants.enable()
|
torch.sparse.check_sparse_tensor_invariants.enable()
|
||||||
|
|
||||||
if not self._avoid_default_dtype_check:
|
|
||||||
assert torch.get_default_dtype() == torch.float
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# There exists test cases that override TestCase.setUp
|
# There exists test cases that override TestCase.setUp
|
||||||
# definition, so we cannot assume that _check_invariants
|
# definition, so we cannot assume that _check_invariants
|
||||||
|
|
@ -2577,9 +2563,6 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
|
||||||
else:
|
else:
|
||||||
torch.sparse.check_sparse_tensor_invariants.disable()
|
torch.sparse.check_sparse_tensor_invariants.disable()
|
||||||
|
|
||||||
if not self._avoid_default_dtype_check:
|
|
||||||
assert torch.get_default_dtype() == torch.float
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _make_crow_indices(n_rows, n_cols, nnz,
|
def _make_crow_indices(n_rows, n_cols, nnz,
|
||||||
*, device, dtype, random=True):
|
*, device, dtype, random=True):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user