Revert "Remove remaining global set_default_dtype calls from tests (#107246)"

This reverts commit aa8ea1d787.

Reverted https://github.com/pytorch/pytorch/pull/107246 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107246#issuecomment-1693838522))
This commit is contained in:
PyTorch MergeBot 2023-08-25 19:34:55 +00:00
parent c68d0a7042
commit 161ea463e6
17 changed files with 873 additions and 899 deletions

View File

@ -19,6 +19,8 @@ from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
import torch.nn.functional as F import torch.nn.functional as F
torch.set_default_dtype(torch.double)
NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL") NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
# batched grad doesn't support data parallel # batched grad doesn't support data parallel
@ -38,11 +40,11 @@ class TestDataParallel(TestCase):
def forward(self, x): def forward(self, x):
return x * self.t_rg + self.t_not_rg return x * self.t_rg + self.t_not_rg
m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double)) m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
self.assertTrue(m.t_rg.requires_grad) self.assertTrue(m.t_rg.requires_grad)
dpm = nn.DataParallel(m, [0, 1]) dpm = nn.DataParallel(m, [0, 1])
inp = torch.randn(2, 100, device='cuda', dtype=torch.double) inp = torch.randn(2, 100, device='cuda')
def fn(t): def fn(t):
return dpm(inp) return dpm(inp)
@ -510,11 +512,11 @@ class TestDataParallel(TestCase):
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_cpu(self): def test_scatter_cpu(self):
self._test_scatter(torch.randn((4, 4), dtype=torch.double)) self._test_scatter(torch.randn((4, 4)))
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_gpu(self): def test_scatter_gpu(self):
self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda()) self._test_scatter(torch.randn((4, 4)).cuda())
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed") @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
@ -537,8 +539,8 @@ class TestDataParallel(TestCase):
def _test_gather(self, output_device): def _test_gather(self, output_device):
inputs = ( inputs = (
torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double), torch.randn(2, 4, device='cuda:0', requires_grad=True),
torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double), torch.randn(2, 4, device='cuda:1', requires_grad=True),
) )
result = dp.gather(inputs, output_device) result = dp.gather(inputs, output_device)
self.assertEqual(result.size(), torch.Size([4, 4])) self.assertEqual(result.size(), torch.Size([4, 4]))
@ -548,7 +550,7 @@ class TestDataParallel(TestCase):
self.assertEqual(result.get_device(), output_device) self.assertEqual(result.get_device(), output_device)
else: else:
self.assertFalse(result.is_cuda) self.assertFalse(result.is_cuda)
grad = torch.randn((4, 4), dtype=torch.double) grad = torch.randn((4, 4))
if output_device != -1: if output_device != -1:
grad = grad.cuda(output_device) grad = grad.cuda(output_device)
result.backward(grad) result.backward(grad)
@ -558,8 +560,8 @@ class TestDataParallel(TestCase):
# test scalar inputs, should stack into a vector in this case # test scalar inputs, should stack into a vector in this case
inputs = ( inputs = (
torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double), torch.randn((), device='cuda:0', requires_grad=True),
torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double), torch.randn((), device='cuda:1', requires_grad=True),
) )
result = dp.gather(inputs, output_device) result = dp.gather(inputs, output_device)
self.assertEqual(result.size(), torch.Size([2])) self.assertEqual(result.size(), torch.Size([2]))
@ -569,7 +571,7 @@ class TestDataParallel(TestCase):
self.assertEqual(result.get_device(), output_device) self.assertEqual(result.get_device(), output_device)
else: else:
self.assertFalse(result.is_cuda) self.assertFalse(result.is_cuda)
grad = torch.randn(2, dtype=torch.double) grad = torch.randn(2)
if output_device != -1: if output_device != -1:
grad = grad.cuda(output_device) grad = grad.cuda(output_device)
result.backward(grad) result.backward(grad)

File diff suppressed because it is too large Load Diff

View File

@ -1995,8 +1995,8 @@ class TestFrozenOptimizations(JitTestCase):
torch.set_default_dtype(torch.double) torch.set_default_dtype(torch.double)
def tearDown(self): def tearDown(self):
torch.set_default_dtype(self.default_dtype)
super().tearDown() super().tearDown()
torch.set_default_dtype(self.default_dtype)
def test_conv_bn_folding(self): def test_conv_bn_folding(self):
conv_bias = [True, False] conv_bias = [True, False]

View File

@ -1881,11 +1881,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
x = torch.randn(2, 3, 4).to(torch.int) x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
with common_utils.set_default_dtype(torch.float): prev_default = torch.get_default_dtype()
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
with common_utils.set_default_dtype(torch.double): torch.set_default_dtype(torch.float)
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y)) self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
torch.set_default_dtype(torch.double)
self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
torch.set_default_dtype(prev_default)
# In scripting x, y do not carry shape and dtype info. # In scripting x, y do not carry shape and dtype info.
# The following test only works when onnx shape inference is enabled. # The following test only works when onnx shape inference is enabled.
@ -1901,20 +1905,23 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
x = torch.randn(2, 3, 4).to(torch.int) x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
prev_default = torch.get_default_dtype()
# 1. x,y are int, and output is float. # 1. x,y are int, and output is float.
# This can be handled by the default case, where both are cast to float. # This can be handled by the default case, where both are cast to float.
# It works even if type of x, y are unknown. # It works even if type of x, y are unknown.
with common_utils.set_default_dtype(torch.float): torch.set_default_dtype(torch.float)
self.run_test(torch.jit.script(DivModule()), (x, y)) self.run_test(torch.jit.script(DivModule()), (x, y))
# 2. x,y are int, and output is double. # 2. x,y are int, and output is double.
# This can be handled by the default case, where both are cast to double. # This can be handled by the default case, where both are cast to double.
# It works even if type of x, y are unknown. # It works even if type of x, y are unknown.
with common_utils.set_default_dtype(torch.double): torch.set_default_dtype(torch.double)
self.run_test(torch.jit.script(DivModule()), (x, y)) self.run_test(torch.jit.script(DivModule()), (x, y))
# 3. x is int, y is double, and output is double. # 3. x is int, y is double, and output is double.
# This can only be handled when both type of x and y are known. # This can only be handled when both type of x and y are known.
torch.set_default_dtype(prev_default)
x = torch.randn(2, 3, 4).to(torch.int) x = torch.randn(2, 3, 4).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double) y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
self.run_test(torch.jit.script(DivModule()), (x, y)) self.run_test(torch.jit.script(DivModule()), (x, y))

View File

@ -6,7 +6,7 @@ from torch.testing._internal.common_device_type import (
dtypes, dtypes,
onlyCPU, onlyCPU,
) )
from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_dtype import complex_types from torch.testing._internal.common_dtype import complex_types
devices = (torch.device('cpu'), torch.device('cuda:0')) devices = (torch.device('cpu'), torch.device('cuda:0'))
@ -21,8 +21,10 @@ class TestComplexTensor(TestCase):
@dtypes(torch.float32, torch.float64) @dtypes(torch.float32, torch.float64)
def test_dtype_inference(self, device, dtype): def test_dtype_inference(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/36834 # issue: https://github.com/pytorch/pytorch/issues/36834
with set_default_dtype(dtype): default_dtype = torch.get_default_dtype()
x = torch.tensor([3., 3. + 5.j], device=device) torch.set_default_dtype(dtype)
x = torch.tensor([3., 3. + 5.j], device=device)
torch.set_default_dtype(default_dtype)
self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat) self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
@onlyCPU @onlyCPU

View File

@ -1,5 +1,8 @@
# Owner(s): ["module: cpp"] # Owner(s): ["module: cpp"]
import torch
# NN tests use double as the default dtype
torch.set_default_dtype(torch.double)
import os import os

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: nvfuser"] # Owner(s): ["module: nvfuser"]
import torch import torch
from torch.testing._internal.common_utils import set_default_dtype, TestCase from torch.testing._internal.common_utils import set_default_dtype
try: try:
from _nvfuser.test_torchscript import * # noqa: F403,F401 from _nvfuser.test_torchscript import * # noqa: F403,F401
@ -13,5 +13,4 @@ except ImportError:
if __name__ == '__main__': if __name__ == '__main__':
# TODO: Update nvfuser to work with float default dtype # TODO: Update nvfuser to work with float default dtype
with set_default_dtype(torch.double): with set_default_dtype(torch.double):
TestCase._avoid_default_dtype_check = True
run_tests() run_tests()

View File

@ -35,6 +35,8 @@ from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum import torch.backends.opt_einsum as opt_einsum
# Protects against includes accidentally setting the default dtype # Protects against includes accidentally setting the default dtype
# NOTE: jit_metaprogramming_utils sets the default dtype to double!
torch.set_default_dtype(torch.float32)
assert torch.get_default_dtype() is torch.float32 assert torch.get_default_dtype() is torch.float32
if TEST_SCIPY: if TEST_SCIPY:

View File

@ -26,6 +26,8 @@ from torch.testing._internal.common_utils import (
) )
# Protects against includes accidentally setting the default dtype # Protects against includes accidentally setting the default dtype
# NOTE: jit_metaprogramming_utils sets the default dtype to double!
torch.set_default_dtype(torch.float32)
assert torch.get_default_dtype() is torch.float32 assert torch.get_default_dtype() is torch.float32

View File

@ -9,6 +9,9 @@ from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \ from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes) (instantiate_device_type_tests, ops, OpDTypes)
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033 # TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The # AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33 # issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33

View File

@ -10,6 +10,9 @@ from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import \ from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes) (instantiate_device_type_tests, ops, OpDTypes)
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
# gradcheck requires double precision # gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=[torch.double, torch.cdouble]) allowed_dtypes=[torch.double, torch.cdouble])

View File

@ -15,6 +15,9 @@ from torch.testing._internal.jit_metaprogramming_utils import create_script_fn,
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
# variant testing is only done with torch.float and torch.cfloat to avoid # variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio # excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported, _variant_ops = partial(ops, dtypes=OpDTypes.supported,

View File

@ -15,7 +15,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
set_default_dtype, set_default_tensor_type,
TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo) TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
@ -1966,36 +1965,37 @@ class TestTensorCreation(TestCase):
# TODO: this test should be updated # TODO: this test should be updated
@onlyCPU @onlyCPU
def test_constructor_dtypes(self, device): def test_constructor_dtypes(self, device):
default_type = torch.tensor([]).type()
self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype()) self.assertIs(torch.tensor([]).dtype, torch.get_default_dtype())
self.assertIs(torch.uint8, torch.ByteTensor.dtype) self.assertIs(torch.uint8, torch.ByteTensor.dtype)
self.assertIs(torch.float32, torch.FloatTensor.dtype) self.assertIs(torch.float32, torch.FloatTensor.dtype)
self.assertIs(torch.float64, torch.DoubleTensor.dtype) self.assertIs(torch.float64, torch.DoubleTensor.dtype)
with set_default_tensor_type('torch.FloatTensor'): torch.set_default_tensor_type('torch.FloatTensor')
self.assertIs(torch.float32, torch.get_default_dtype()) self.assertIs(torch.float32, torch.get_default_dtype())
self.assertIs(torch.FloatStorage, torch.Storage) self.assertIs(torch.FloatStorage, torch.Storage)
# only floating-point types are supported as the default type # only floating-point types are supported as the default type
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor')) self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
with set_default_dtype(torch.float64): torch.set_default_dtype(torch.float64)
self.assertIs(torch.float64, torch.get_default_dtype()) self.assertIs(torch.float64, torch.get_default_dtype())
self.assertIs(torch.DoubleStorage, torch.Storage) self.assertIs(torch.DoubleStorage, torch.Storage)
with set_default_tensor_type(torch.FloatTensor): torch.set_default_tensor_type(torch.FloatTensor)
self.assertIs(torch.float32, torch.get_default_dtype()) self.assertIs(torch.float32, torch.get_default_dtype())
self.assertIs(torch.FloatStorage, torch.Storage) self.assertIs(torch.FloatStorage, torch.Storage)
if torch.cuda.is_available(): if torch.cuda.is_available():
with set_default_tensor_type(torch.cuda.FloatTensor): torch.set_default_tensor_type(torch.cuda.FloatTensor)
self.assertIs(torch.float32, torch.get_default_dtype()) self.assertIs(torch.float32, torch.get_default_dtype())
self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype) self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
self.assertIs(torch.cuda.FloatStorage, torch.Storage) self.assertIs(torch.cuda.FloatStorage, torch.Storage)
with set_default_dtype(torch.float64): torch.set_default_dtype(torch.float64)
self.assertIs(torch.float64, torch.get_default_dtype()) self.assertIs(torch.float64, torch.get_default_dtype())
self.assertIs(torch.cuda.DoubleStorage, torch.Storage) self.assertIs(torch.cuda.DoubleStorage, torch.Storage)
# don't allow passing dtype to set_default_tensor_type # don't allow passing dtype to set_default_tensor_type
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32)) self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32))
@ -2008,11 +2008,12 @@ class TestTensorCreation(TestCase):
torch.float, torch.float,
torch.double, torch.double,
torch.bfloat16): torch.bfloat16):
with set_default_dtype(t): torch.set_default_dtype(t)
pass
else: else:
self.assertRaises(TypeError, lambda: torch.set_default_dtype(t)) self.assertRaises(TypeError, lambda: torch.set_default_dtype(t))
torch.set_default_tensor_type(default_type)
# TODO: this test should be updated # TODO: this test should be updated
@onlyCPU @onlyCPU
def test_constructor_device_legacy(self, device): def test_constructor_device_legacy(self, device):
@ -2048,10 +2049,14 @@ class TestTensorCreation(TestCase):
self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu')) self.assertRaises(RuntimeError, lambda: torch.Tensor(i, device='cpu'))
self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu')) self.assertRaises(RuntimeError, lambda: i.new(i, device='cpu'))
with set_default_tensor_type(torch.cuda.FloatTensor): default_type = torch.Tensor().type()
self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu')) torch.set_default_tensor_type(torch.cuda.FloatTensor)
self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu')) self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu')) self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.set_default_tensor_type(default_type)
x = torch.randn((3,), device='cuda') x = torch.randn((3,), device='cuda')
self.assertRaises(RuntimeError, lambda: x.new(device='cpu')) self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu')) self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
@ -2153,6 +2158,8 @@ class TestTensorCreation(TestCase):
@onlyCPU @onlyCPU
def test_tensor_factory_type_inference(self, device): def test_tensor_factory_type_inference(self, device):
def test_inference(default_dtype): def test_inference(default_dtype):
saved_dtype = torch.get_default_dtype()
torch.set_default_dtype(default_dtype)
default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128 default_complex_dtype = torch.complex64 if default_dtype == torch.float32 else torch.complex128
self.assertIs(default_dtype, torch.tensor(()).dtype) self.assertIs(default_dtype, torch.tensor(()).dtype)
self.assertIs(default_dtype, torch.tensor(5.).dtype) self.assertIs(default_dtype, torch.tensor(5.).dtype)
@ -2174,10 +2181,10 @@ class TestTensorCreation(TestCase):
self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
torch.set_default_dtype(saved_dtype)
for dtype in [torch.float64, torch.float32]: test_inference(torch.float64)
with set_default_dtype(dtype): test_inference(torch.float32)
test_inference(dtype)
# TODO: this test should be updated # TODO: this test should be updated
@suppress_warnings @suppress_warnings
@ -2464,6 +2471,8 @@ class TestTensorCreation(TestCase):
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
@onlyCPU @onlyCPU
def test_arange_inference(self, device): def test_arange_inference(self, device):
saved_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)
# end only # end only
self.assertIs(torch.float32, torch.arange(1.).dtype) self.assertIs(torch.float32, torch.arange(1.).dtype)
self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype) self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
@ -2492,6 +2501,7 @@ class TestTensorCreation(TestCase):
torch.arange(torch.tensor(1), torch.arange(torch.tensor(1),
torch.tensor(3), torch.tensor(3),
torch.tensor(1, dtype=torch.int16)).dtype) torch.tensor(1, dtype=torch.int16)).dtype)
torch.set_default_dtype(saved_dtype)
# cannot call storage() on meta tensor # cannot call storage() on meta tensor
@skipMeta @skipMeta
@ -2808,24 +2818,28 @@ class TestTensorCreation(TestCase):
@onlyCUDA @onlyCUDA
def test_tensor_factory_gpu_type_inference(self, device): def test_tensor_factory_gpu_type_inference(self, device):
with set_default_tensor_type(torch.cuda.DoubleTensor): saved_type = torch.tensor([]).type()
with set_default_dtype(torch.float32): torch.set_default_tensor_type(torch.cuda.DoubleTensor)
self.assertIs(torch.float32, torch.tensor(0.).dtype) torch.set_default_dtype(torch.float32)
self.assertEqual(torch.device(device), torch.tensor(0.).device) self.assertIs(torch.float32, torch.tensor(0.).dtype)
with set_default_dtype(torch.float64): self.assertEqual(torch.device(device), torch.tensor(0.).device)
self.assertIs(torch.float64, torch.tensor(0.).dtype) torch.set_default_dtype(torch.float64)
self.assertEqual(torch.device(device), torch.tensor(0.).device) self.assertIs(torch.float64, torch.tensor(0.).dtype)
self.assertEqual(torch.device(device), torch.tensor(0.).device)
torch.set_default_tensor_type(saved_type)
@onlyCUDA @onlyCUDA
def test_tensor_factory_gpu_type(self, device): def test_tensor_factory_gpu_type(self, device):
with set_default_tensor_type(torch.cuda.FloatTensor): saved_type = torch.tensor([]).type()
x = torch.zeros((5, 5)) torch.set_default_tensor_type(torch.cuda.FloatTensor)
self.assertIs(torch.float32, x.dtype) x = torch.zeros((5, 5))
self.assertTrue(x.is_cuda) self.assertIs(torch.float32, x.dtype)
with set_default_tensor_type(torch.cuda.DoubleTensor): self.assertTrue(x.is_cuda)
x = torch.zeros((5, 5)) torch.set_default_tensor_type(torch.cuda.DoubleTensor)
self.assertIs(torch.float64, x.dtype) x = torch.zeros((5, 5))
self.assertTrue(x.is_cuda) self.assertIs(torch.float64, x.dtype)
self.assertTrue(x.is_cuda)
torch.set_default_tensor_type(saved_type)
@skipCPUIf(True, 'compares device with cpu') @skipCPUIf(True, 'compares device with cpu')
@dtypes(torch.int, torch.long, torch.float, torch.double) @dtypes(torch.int, torch.long, torch.float, torch.double)
@ -3067,23 +3081,27 @@ class TestTensorCreation(TestCase):
def test_full_inference(self, device, dtype): def test_full_inference(self, device, dtype):
size = (2, 2) size = (2, 2)
with set_default_dtype(dtype): prev_default = torch.get_default_dtype()
# Tests bool fill value inference torch.set_default_dtype(dtype)
t = torch.full(size, True)
self.assertEqual(t.dtype, torch.bool)
# Tests integer fill value inference # Tests bool fill value inference
t = torch.full(size, 1) t = torch.full(size, True)
self.assertEqual(t.dtype, torch.long) self.assertEqual(t.dtype, torch.bool)
# Tests float fill value inference # Tests integer fill value inference
t = torch.full(size, 1.) t = torch.full(size, 1)
self.assertEqual(t.dtype, dtype) self.assertEqual(t.dtype, torch.long)
# Tests complex inference # Tests float fill value inference
t = torch.full(size, (1 + 1j)) t = torch.full(size, 1.)
ctype = torch.complex128 if dtype is torch.double else torch.complex64 self.assertEqual(t.dtype, dtype)
self.assertEqual(t.dtype, ctype)
# Tests complex inference
t = torch.full(size, (1 + 1j))
ctype = torch.complex128 if dtype is torch.double else torch.complex64
self.assertEqual(t.dtype, ctype)
torch.set_default_dtype(prev_default)
def test_full_out(self, device): def test_full_out(self, device):
size = (5,) size = (5,)

View File

@ -32,7 +32,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON, TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
TEST_WITH_CROSSREF, skipIfTorchDynamo, set_default_dtype, TEST_WITH_CROSSREF, skipIfTorchDynamo,
skipCUDAMemoryLeakCheckIf, BytesIOContext, skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
@ -6010,7 +6010,6 @@ class TestTorch(TestCase):
self.assertEqual(added, -tensor) self.assertEqual(added, -tensor)
@skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>") @skipIfTorchInductor("AssertionError: RuntimeError not raised by <lambda>")
@set_default_dtype(torch.double)
def test_index_add_correctness(self): def test_index_add_correctness(self):
# Check whether index_add can get correct result when # Check whether index_add can get correct result when
# alpha is 1, and dtype of index is torch.long, # alpha is 1, and dtype of index is torch.long,
@ -7274,21 +7273,21 @@ tensor([4.0000+0.j, inf+0.j, 1.5000+infj, -inf+4.j, 0.0000+0.j, nan+infj
self.assertExpectedInline(str(y), expected_str) self.assertExpectedInline(str(y), expected_str)
# test dtype # test dtype
with set_default_dtype(torch.float): torch.set_default_dtype(torch.float)
x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
self.assertEqual(x.__repr__(), str(x)) self.assertEqual(x.__repr__(), str(x))
expected_str = '''\ expected_str = '''\
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
inf], dtype=torch.float64)''' inf], dtype=torch.float64)'''
self.assertExpectedInline(str(x), expected_str) self.assertExpectedInline(str(x), expected_str)
# test changing default dtype # test changing default dtype
with set_default_dtype(torch.float64): torch.set_default_dtype(torch.float64)
self.assertEqual(x.__repr__(), str(x)) self.assertEqual(x.__repr__(), str(x))
expected_str = '''\ expected_str = '''\
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
inf])''' inf])'''
self.assertExpectedInline(str(x), expected_str) self.assertExpectedInline(str(x), expected_str)
# test summary # test summary
x = torch.zeros(10000) x = torch.zeros(10000)

View File

@ -1,6 +1,6 @@
# Owner(s): ["module: typing"] # Owner(s): ["module: typing"]
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests, set_default_dtype from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY, load_tests
# load_tests from common_utils is used to automatically filter tests for # load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings # sharding on sandcastle. This line silences flake warnings
@ -38,6 +38,7 @@ class TestDTypeInfo(TestCase):
@unittest.skipIf(not TEST_NUMPY, "Numpy not found") @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_finfo(self): def test_finfo(self):
initial_default_type = torch.get_default_dtype()
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]: for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
x = torch.zeros((2, 2), dtype=dtype) x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.finfo(x.dtype) xinfo = torch.finfo(x.dtype)
@ -51,8 +52,8 @@ class TestDTypeInfo(TestCase):
self.assertEqual(xinfo.resolution, xninfo.resolution) self.assertEqual(xinfo.resolution, xninfo.resolution)
self.assertEqual(xinfo.dtype, xninfo.dtype) self.assertEqual(xinfo.dtype, xninfo.dtype)
if not dtype.is_complex: if not dtype.is_complex:
with set_default_dtype(dtype): torch.set_default_dtype(dtype)
self.assertEqual(torch.finfo(dtype), torch.finfo()) self.assertEqual(torch.finfo(dtype), torch.finfo())
# Special test case for BFloat16 type # Special test case for BFloat16 type
x = torch.zeros((2, 2), dtype=torch.bfloat16) x = torch.zeros((2, 2), dtype=torch.bfloat16)
@ -65,8 +66,11 @@ class TestDTypeInfo(TestCase):
self.assertEqual(xinfo.tiny, xinfo.smallest_normal) self.assertEqual(xinfo.tiny, xinfo.smallest_normal)
self.assertEqual(xinfo.resolution, 0.01) self.assertEqual(xinfo.resolution, 0.01)
self.assertEqual(xinfo.dtype, "bfloat16") self.assertEqual(xinfo.dtype, "bfloat16")
with set_default_dtype(x.dtype): torch.set_default_dtype(x.dtype)
self.assertEqual(torch.finfo(x.dtype), torch.finfo()) self.assertEqual(torch.finfo(x.dtype), torch.finfo())
# Restore the default type to ensure that the test has no side effect
torch.set_default_dtype(initial_default_type)
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()

View File

@ -1785,14 +1785,14 @@ new_module_tests = [
dict( dict(
fullname='EmbeddingBag_sparse', fullname='EmbeddingBag_sparse',
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)',
input_fn=lambda: torch.randperm(2).repeat(1, 2), input_fn=lambda: torch.randperm(2).repeat(1, 2),
check_gradgrad=False, check_gradgrad=False,
has_sparse_gradients=True, has_sparse_gradients=True,
), ),
dict( dict(
constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)',
input_fn=lambda: torch.randperm(2).repeat(1, 2), input_fn=lambda: torch.randperm(2).repeat(1, 2),
fullname='Embedding_sparse', fullname='Embedding_sparse',
check_gradgrad=False, check_gradgrad=False,
@ -3168,7 +3168,7 @@ criterion_tests = [
), ),
dict( dict(
module_name='MSELoss', module_name='MSELoss',
input_fn=lambda: torch.rand((2, 3, 4, 5), dtype=torch.double), input_size=(2, 3, 4, 5),
target_fn=lambda: torch.randn((2, 3, 4, 5), dtype=torch.double, requires_grad=True), target_fn=lambda: torch.randn((2, 3, 4, 5), dtype=torch.double, requires_grad=True),
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel() reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
if get_reduction(m) == 'mean' else 1)), if get_reduction(m) == 'mean' else 1)),
@ -3314,9 +3314,9 @@ criterion_tests = [
dict( dict(
module_name='MultiMarginLoss', module_name='MultiMarginLoss',
constructor_args=(1, 1., torch.rand(10, dtype=torch.double)), constructor_args=(1, 1., torch.rand(10, dtype=torch.double)),
cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10).to(torch::kFloat64))', cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))',
legacy_constructor_args=(1, torch.rand(10, dtype=torch.double)), legacy_constructor_args=(1, torch.rand(10, dtype=torch.double)),
input_fn=lambda: torch.rand(5, 10, dtype=torch.double), input_size=(5, 10),
target_fn=lambda: torch.rand(5).mul(8).floor().long(), target_fn=lambda: torch.rand(5).mul(8).floor().long(),
reference_fn=lambda i, t, m: reference_fn=lambda i, t, m:
multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)), multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
@ -3403,7 +3403,7 @@ criterion_tests = [
dict( dict(
module_name='BCEWithLogitsLoss', module_name='BCEWithLogitsLoss',
constructor_args=(torch.rand(10, dtype=torch.double),), constructor_args=(torch.rand(10, dtype=torch.double),),
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10).to(torch::kFloat64))', cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))',
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(15, 10).gt(0).to(torch.get_default_dtype()), target_fn=lambda: torch.randn(15, 10).gt(0).to(torch.get_default_dtype()),
desc='weights', desc='weights',
@ -3412,7 +3412,7 @@ criterion_tests = [
dict( dict(
module_name='BCEWithLogitsLoss', module_name='BCEWithLogitsLoss',
constructor_args=(torch.rand((), dtype=torch.double),), constructor_args=(torch.rand((), dtype=torch.double),),
cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}).to(torch::kFloat64))', cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))',
input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2), input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(()).gt(0).to(torch.get_default_dtype()), target_fn=lambda: torch.randn(()).gt(0).to(torch.get_default_dtype()),
desc='scalar_weights', desc='scalar_weights',
@ -3826,7 +3826,7 @@ criterion_tests = [
), ),
dict( dict(
module_name='MSELoss', module_name='MSELoss',
input_fn=lambda: torch.rand((), dtype=torch.double), input_size=(),
target_fn=lambda: torch.randn((), requires_grad=True, dtype=torch.double), target_fn=lambda: torch.randn((), requires_grad=True, dtype=torch.double),
reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
(i.numel() if get_reduction(m) == 'mean' else 1)), (i.numel() if get_reduction(m) == 'mean' else 1)),

View File

@ -1648,15 +1648,6 @@ def set_default_dtype(dtype):
finally: finally:
torch.set_default_dtype(saved_dtype) torch.set_default_dtype(saved_dtype)
@contextlib.contextmanager
def set_default_tensor_type(tensor_type):
saved_tensor_type = torch.tensor([]).type()
torch.set_default_tensor_type(tensor_type)
try:
yield
finally:
torch.set_default_tensor_type(saved_tensor_type)
def iter_indices(tensor): def iter_indices(tensor):
if tensor.dim() == 0: if tensor.dim() == 0:
return range(0) return range(0)
@ -2245,8 +2236,6 @@ class TestCase(expecttest.TestCase):
_precision: float = 0 _precision: float = 0
_rel_tol: float = 0 _rel_tol: float = 0
_avoid_default_dtype_check: bool = False
# checker to early terminate test suite if unrecoverable failure occurs. # checker to early terminate test suite if unrecoverable failure occurs.
def _should_stop_test_suite(self): def _should_stop_test_suite(self):
if torch.cuda.is_initialized(): if torch.cuda.is_initialized():
@ -2563,9 +2552,6 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
# decorator to disable the invariant checks. # decorator to disable the invariant checks.
torch.sparse.check_sparse_tensor_invariants.enable() torch.sparse.check_sparse_tensor_invariants.enable()
if not self._avoid_default_dtype_check:
assert torch.get_default_dtype() == torch.float
def tearDown(self): def tearDown(self):
# There exists test cases that override TestCase.setUp # There exists test cases that override TestCase.setUp
# definition, so we cannot assume that _check_invariants # definition, so we cannot assume that _check_invariants
@ -2577,9 +2563,6 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
else: else:
torch.sparse.check_sparse_tensor_invariants.disable() torch.sparse.check_sparse_tensor_invariants.disable()
if not self._avoid_default_dtype_check:
assert torch.get_default_dtype() == torch.float
@staticmethod @staticmethod
def _make_crow_indices(n_rows, n_cols, nnz, def _make_crow_indices(n_rows, n_cols, nnz,
*, device, dtype, random=True): *, device, dtype, random=True):