pytorch/test/test_fake_tensor.py
eellison ad53b53518 Generate patterns in fp16 and fp32 (#109142)
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917
2023-09-20 06:38:02 +00:00

1183 lines
45 KiB
Python

# Owner(s): ["module: meta tensors"]
from torch.testing._internal.common_utils import (
TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
instantiate_parametrized_tests)
import torch
import torch._dynamo
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
from torch._subclasses.fake_tensor import (
FakeTensor,
FakeTensorMode,
FakeTensorConverter,
DynamicOutputShapeException,
UnsupportedOperatorException,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._dynamo.testing import rand_strided
from torch.testing import FileCheck
import unittest
import torch._prims as prims
import contextlib
import weakref
import copy
import torch._functorch.config
import torch.testing._internal.optests as optests
from unittest.mock import patch
from torch import distributed as dist
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
class FakeTensorTest(TestCase):
def checkType(self, t, device_str, size):
self.assertTrue(isinstance(t, FakeTensor))
self.assertEqual(t.device.type, device_str)
self.assertEqual(list(t.size()), size)
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cuda_initialized(self):
# doesnt error
with FakeTensorMode():
p = torch.randn(4, 2, requires_grad=True, device='cuda')
x = torch.randn(8, 4, device='cuda')
y = torch.mm(x, p).square().sum()
y.backward()
def test_basic(self):
x = torch.empty(2, 2, device="cpu")
y = torch.empty(4, 2, 2, device="cpu")
with FakeTensorMode() as mode:
x = mode.from_tensor(x)
y = mode.from_tensor(y)
z = x + y
self.assertEqual(z.shape, (4, 2, 2))
self.assertEqual(z.device, torch.device("cpu"))
self.assertTrue(isinstance(z, FakeTensor))
def test_basic_forced_memo_only(self):
x = torch.empty(2, 2, device="cpu")
y = torch.empty(4, 2, 2, device="cpu")
with FakeTensorMode() as mode:
x_fake = mode.from_tensor(x)
x2 = mode.from_tensor(x, memoized_only=True)
self.assertTrue(x2 is not None)
y = mode.from_tensor(y, memoized_only=True)
self.assertIs(y, None)
def test_custom_op_fallback(self):
from torch.library import Library, impl
test_lib = Library("my_test_op", "DEF")
test_lib.define('foo(Tensor self) -> Tensor')
@impl(test_lib, 'foo', 'CPU')
def foo_impl(self):
return self.cos()
x = torch.empty(2, 2, device="cpu")
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
with FakeTensorMode(allow_fallback_kernels=True) as mode:
x = mode.from_tensor(x)
torch.ops.my_test_op.foo(x)
def test_parameter_instantiation(self):
with FakeTensorMode():
x = torch.rand([4])
y = torch.nn.parameter.Parameter(x)
self.assertTrue(isinstance(y, torch.nn.Parameter))
@unittest.skipIf(not dist.is_available(), "requires distributed")
def test_fsdp_flat_param(self):
from torch.distributed.fsdp.flat_param import FlatParameter
with FakeTensorMode() as m:
data = torch.randn(2, 2)
param = FlatParameter(data, requires_grad=True)
self.assertIsInstance(param, FlatParameter)
self.assertIsInstance(param, torch.nn.Parameter)
self.assertIsInstance(param, FakeTensor)
def test_non_parameter_grad(self):
mode = FakeTensorMode()
t = torch.rand([4], requires_grad=True)
fake_t = mode.from_tensor(t)
self.assertEqual(fake_t.requires_grad, t.requires_grad)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_index_cuda_with_cpu(self):
with FakeTensorMode():
x = torch.rand([2048], device='cuda')
out = x[torch.zeros([36], dtype=torch.int64)]
self.checkType(out, "cuda", [36])
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_shape_take_not_device(self):
with FakeTensorMode():
x = torch.empty(1, device="cpu")
y = torch.empty(8, 8, device="cuda")
out = x.resize_as_(y)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.device.type, "cpu")
self.assertTrue(isinstance(out, FakeTensor))
def test_repr(self):
with FakeTensorMode():
x = torch.empty(2, 2, device="cpu")
self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
x = torch.empty(2, 2, device="meta")
self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_zero_dim(self):
with FakeTensorMode() as mode:
x = torch.tensor(0.)
y = torch.rand([4, 4], device="cuda")
out = x + y
self.assertEqual(out.shape, (4, 4))
self.assertEqual(out.device, y.device)
self.assertTrue(isinstance(out, FakeTensor))
def test_nan_to_num(self):
with FakeTensorMode():
for dtype in [torch.float16, torch.float32]:
x = torch.rand([4], dtype=dtype)
y = torch.nan_to_num(x, nan=None)
z = torch.nan_to_num(x, 0.0)
self.assertEqual(dtype, y.dtype)
self.assertEqual(dtype, z.dtype)
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_throw(self):
x = torch.tensor(0.) # TODO: tensor() errors
with FakeTensorMode() as mode:
x_conv = mode.from_tensor(x)
y = torch.rand([4, 4], device="cuda")
z = torch.rand([4, 4], device="cpu")
self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_type_as(self):
with FakeTensorMode():
x = torch.rand([16, 1], device="cpu")
y = torch.rand([4, 4], device="cuda")
out = x.type_as(y)
self.assertEqual(out.device.type, "cuda")
self.assertTrue(isinstance(out, FakeTensor))
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_setitem(self):
for device in ["cpu", "cuda"]:
with FakeTensorMode():
x = torch.rand([16, 1], device=device)
x[..., 0] = 0
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_device_inplace_copy(self):
with FakeTensorMode():
x = torch.rand([8, 8], device="cpu")
y = torch.rand([8, 8], device="cuda")
assert x.copy_(y).device.type == "cpu"
assert y.copy_(x).device.type == "cuda"
def test_fake_dispatch_keys(self):
with FakeTensorMode():
x = torch.rand([4])
f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
f.run(torch._C._dispatch_key_set(x))
with torch.inference_mode():
x = torch.rand([4])
y = x + x
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
def test_constructor(self):
with FakeTensorMode():
x = torch.rand([4, 4], device="cpu")
self.assertTrue(isinstance(x, FakeTensor))
self.assertTrue(x.device.type == "cpu")
def test_mode(self):
with FakeTensorMode():
y = torch.rand([4], device="cpu")
out = y + y
self.assertTrue(isinstance(out, FakeTensor))
def test_full(self):
# Test torch.full returns tensor with correct dtype
with torch._subclasses.CrossRefFakeMode():
y = torch.full((4, 4), 1)
def check_function_with_fake(self, fn):
out = fn()
with torch._subclasses.FakeTensorMode():
out_fake = fn()
for a, b in zip(tree_flatten(out), tree_flatten(out_fake)):
if not isinstance(a, FakeTensor):
self.assertTrue(not isinstance(b, FakeTensor))
continue
prims.utils.compare_tensor_meta(a, b, check_strides=True)
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_non_kwarg_device(self):
with FakeTensorMode():
x = torch.rand([16, 1], device="cpu")
y = x.to(torch.device("cpu"))
self.assertIs(x, y)
z = x.to(torch.device("cuda"))
self.assertEqual(z.device.type, "cuda")
def test_non_overlapping_stride_zero(self):
def foo():
x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
return x.half()
self.check_function_with_fake(foo)
def test_fake_mode_error(self):
x = torch.rand([4, 4])
with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
with FakeTensorMode():
y = x[0]
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
def test_fake_grad_copy(self):
x = torch.rand([4, 4], requires_grad=True)
x.grad = torch.rand([4, 4])
mode = FakeTensorMode()
fake_x = mode.from_tensor(x)
prims.utils.compare_tensor_meta(fake_x, x)
prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
self.assertTrue(isinstance(fake_x.grad, FakeTensor))
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_like_constructor(self):
with FakeTensorMode():
x = torch.rand([4, 4])
y = torch.ones_like(x)
self.assertTrue(isinstance(y, FakeTensor))
self.assertEqual(y.device.type, "cpu")
z = torch.ones_like(x, device="cuda")
self.assertTrue(isinstance(z, FakeTensor))
self.assertEqual(z.device.type, "cuda")
def test_binary_op_type_promotion(self):
with FakeTensorMode():
x = torch.empty([2, 2], dtype=torch.float)
y = torch.empty([2, 2], dtype=torch.int64)
out = x / y
self.assertEqual(out.dtype, torch.float)
self.assertEqual(out.device.type, "cpu")
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
def test_from_numpy(self):
with FakeTensorMode():
x = torch.tensor(np.zeros([4, 4]))
self.checkType(x, "cpu", [4, 4])
def test_randperm(self):
x = torch.randperm(10)
y = torch.randperm(5, device="cpu")
with FakeTensorMode():
x1 = torch.randperm(10)
prims.utils.compare_tensor_meta(x, x1)
y1 = torch.randperm(5, device="cpu")
prims.utils.compare_tensor_meta(y, y1)
def test_print_in_fake_mode(self):
x = torch.zeros(2)
# does not fail
with FakeTensorMode():
out = str(x)
assert "FakeTensor" not in out
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_upsample_bilinear_small_channels(self):
out = []
mode = FakeTensorMode()
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
with context():
arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='cuda')
unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))
self.assertTrue(out[1].is_contiguous())
self.checkMetaProps(out[0], out[1])
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cpu_fallback(self):
with FakeTensorMode(allow_fallback_kernels=False):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
self.assertEqual(out.device.type, "cuda")
self.assertEqual(list(out.size()), [1, 8, 5, 5])
with FakeTensorMode(allow_fallback_kernels=True):
# intentionally bad inputs
filters = torch.randn(8, 20, 3, 3).cuda()
inputs = torch.randn(1, 7, 10, 5).cuda()
with self.assertRaises(RuntimeError):
torch.nn.functional.conv2d(inputs, filters, padding=1)
with FakeTensorMode(allow_fallback_kernels=True):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
self.assertEqual(out.device.type, "cuda")
self.assertEqual(list(out.size()), [1, 8, 5, 5])
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_out_multi_device(self):
with FakeTensorMode():
x = torch.rand([4])
y = torch.rand([4], device="cuda")
with self.assertRaisesRegex(Exception, "found two different devices"):
torch.sin(x, out=y)
with self.assertRaisesRegex(Exception, "found two different devices"):
x.add_(y)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_normalize_device(self):
with FakeTensorMode():
x = torch.empty(1, device="cuda")
y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
out = x + y
self.checkType(out, "cuda", [1])
def test_recursive_invocation(self):
mode = FakeTensorMode()
with mode:
x = torch.tensor(2)
mode.in_kernel_invocation = True
y = x + x
self.assertTrue(mode.in_kernel_invocation)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@skipIfRocm
@parametrize("allow_fallback_kernels", [False, True],
lambda a: 'with_fallback' if a else 'without_fallback')
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cudnn_rnn(self, allow_fallback_kernels):
def fn(
a0,
b0,
b1,
b2,
b3,
b4,
b5,
b6,
b7,
b8,
b9,
b10,
b11,
b12,
b13,
b14,
b15,
a3,
a4,
a5,
):
a1 = [
b0,
b1,
b2,
b3,
b4,
b5,
b6,
b7,
b8,
b9,
b10,
b11,
b12,
b13,
b14,
b15,
]
return torch.ops.aten._cudnn_rnn(
a0,
a1,
4,
a3,
a4,
a5,
2,
2048,
0,
2,
False,
0.0,
False,
True,
[],
None,
)
mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
with context():
inps1 = [
torch.randn([92, 8, 2048]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192, 4096]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192, 4096]).cuda(),
torch.randn([8192, 2048]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([8192]).cuda(),
torch.randn([167837696]).cuda(),
torch.randn([4, 8, 2048]).cuda(),
torch.randn([4, 8, 2048]).cuda(),
]
inps2 = inps1
inps2[len(inps2) - 1] = None # argument `cx` can be None
for inps in [inps1, inps2]:
out = fn(*inps)
self.assertIs(out[4], inps[-3])
for ten in out:
if i == 1:
self.assertTrue(isinstance(ten, FakeTensor))
self.assertEqual(ten.device.type, 'cuda')
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cuda_lstm(self):
# Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
with torch.backends.cudnn.flags(enabled=False):
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
with fake_tensor_mode:
N = 5
L = 4
H_in = 2
hidden_size = 3
proj_size = 2
num_layers = 2
bidir = False
D = 2 if bidir else 1
H_out = proj_size if proj_size > 0 else hidden_size
lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
num_layers=num_layers, proj_size=proj_size, batch_first=False,
bias=True, bidirectional=bidir, device='cuda')
h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
inp = torch.randn((L, N, H_in), device='cuda')
(output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
output.sum().backward()
self.assertEqual(output.shape, (L, N, D * H_out))
self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
def test_data_dependent_operator(self):
with FakeTensorMode(allow_fallback_kernels=False):
x = torch.rand([10, 10])
self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
def checkMetaProps(self, t1, t2):
prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
@skipIfCrossRef
def test_deepcopy(self):
with FakeTensorMode() as mode:
pass
mod = torch.nn.BatchNorm2d(10)
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
mod_copied = copy.deepcopy(mod)
def check_copy(mod, mod_copied):
for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
param_copied = getattr(mod_copied, name)
self.checkMetaProps(param, param_copied)
self.assertTrue(isinstance(param_copied, FakeTensor))
self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
self.assertEqual(param.requires_grad, param_copied.requires_grad)
check_copy(mod, mod_copied)
class ModuleNew(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.rand([10, 2])
self.b = self.a
self.c = self.a[0]
mod = ModuleNew()
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
mod_copied = copy.deepcopy(mod)
self.assertIs(mod_copied.a, mod_copied.b)
self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_new(self):
with FakeTensorMode():
a = torch.rand([16, 1])
self.checkType(a.new(10, 10), "cpu", [10, 10])
self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
b = torch.rand([4, 4], device='cuda')
self.checkType(b.new(device='cuda'), "cuda", [0])
self.checkType(a.new(torch.rand([1])), "cpu", [1])
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
def test_scalar_inputs(self):
with FakeTensorMode():
self.checkType(torch.div(3, 2), "cpu", [])
ten = torch.zeros(2, dtype=torch.int32) * 2.0
self.assertEqual(ten.dtype, torch.float)
self.checkType(ten, "cpu", [2])
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
def test_allow_meta(self):
def run_meta():
with FakeTensorMode():
x = torch.rand([4], device="meta")
return x + x
self.checkType(run_meta(), "meta", [4])
with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
self.assertRaises(Exception, run_meta)
def test_embedding_bag_meta(self):
def f():
# This behavior was originally unintentional but we see people
# relying on it
embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
offsets = torch.tensor([0, 4], dtype=torch.long)
return embedding(input, offsets)
real_out = f()
with FakeTensorMode():
fake_out = f()
for r, f in zip(real_out, fake_out):
self.assertEqual(r.size(), f.size())
self.assertEqual(r.device, f.device)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
def test_mixed_real_and_fake_inputs(self):
class _TestPattern(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.bn = torch.nn.BatchNorm2d(1)
def forward(self, input):
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
scale_factor = self.bn.weight / running_std
weight_shape = [1] * len(self.conv.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.conv.weight.shape)
bias_shape[1] = -1
scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
conv_orig = conv / scale_factor.reshape(bias_shape)
conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
conv = self.bn(conv_orig)
return conv
example_inputs = (torch.randn(1, 1, 3, 3),)
mod = _TestPattern()
with FakeTensorMode(allow_non_fake_inputs=True):
out = mod(torch.randn(1, 1, 3, 3))
self.checkType(out, "cpu", (1, 1, 3, 3))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_aten_copy_multi_device(self):
with FakeTensorMode():
x1 = torch.rand(4, device="cpu")
x2 = torch.rand(4, device="cuda")
copy1 = torch.ops.aten.copy.default(x1, x2)
copy2 = torch.ops.aten.copy.default(x2, x1)
out = torch.empty(4, device="cpu")
torch.ops.aten.copy.out(x1, x2, out=out)
self.checkType(copy1, "cpu", (4,))
self.checkType(copy2, "cuda", (4,))
self.checkType(out, "cpu", (4,))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_aten_index_multi_device(self):
with FakeTensorMode():
x1 = torch.rand(4, 4, device="cpu")
x2 = torch.rand(4, 4, device="cuda")
i1 = torch.tensor([0, 1], device="cuda")
i2 = torch.tensor([0, 1], device="cpu")
r1 = torch.ops.aten.index(x1, i1)
r2 = torch.ops.aten.index(x2, i2)
y1 = torch.rand(4, device="cpu")
y2 = torch.rand(4, device="cuda")
j1 = torch.tensor([2], device="cuda")
j2 = torch.tensor([2], device="cpu")
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
self.checkType(r1, "cpu", ())
self.checkType(r2, "cuda", ())
self.checkType(r3, "cpu", (4, 4))
self.checkType(r4, "cuda", (4, 4))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_aten_slice_scatter_multi_device(self):
with FakeTensorMode():
x1 = torch.rand(4, 4, device="cpu")
y1 = torch.rand(2, 4, device="cuda")
x2 = torch.rand(4, 4, device="cuda")
y2 = torch.rand(2, 4, device="cpu")
out = torch.empty(4, 4, device="cpu")
r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
self.checkType(r1, "cpu", (4, 4))
self.checkType(r2, "cuda", (4, 4))
self.checkType(r3, "cpu", (4, 4))
self.checkType(out, "cpu", (4, 4))
def test__adaptive_avg_pool2d_backward(self):
with FakeTensorMode():
grad_out = torch.rand(2, 3, 4, 4)
inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is not None)
def assertNotConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is None)
def test_simple(self):
with FakeTensorMode():
x = torch.tensor(4.)
self.assertEqual(x.item(), 4.)
def test_inplace_add(self):
with FakeTensorMode():
x = torch.tensor(4.)
y = x.add_(1)
self.assertEqual(x.item(), 5.)
self.assertEqual(y.item(), 5.)
self.assertConst(x, y)
def test_shared_storages(self):
with FakeTensorMode():
x = torch.tensor([4.])
y = x[:]
self.assertEqual(x.storage()._cdata, y.storage()._cdata)
self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
def test_constant_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
self.assertConst(x)
y = torch.rand([1])
x.add_(y)
self.assertNotConst(x)
def test_inplace_view_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1])
self.assertConst(x)
x.resize_([2])
self.assertEqual(x.size(0), 2)
self.assertNotConst(x)
def test_fake_tensor_in_intlist_repro(self):
def fn(tensors):
max_size = torch.tensor([800, 1216], dtype=torch.int64)
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
return tensors[0].new_full(batch_shape, 0.0)
with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
with torch._subclasses.fake_tensor.FakeTensorMode():
a = torch.randn(3, 800, 1199)
b = torch.randn(3, 800, 800)
inputs = [a, b]
ref = fn(inputs)
def test_fake_tensor_batch_norm_cpu(self):
with torch._subclasses.CrossRefFakeMode():
m = torch.nn.Sequential(
torch.nn.BatchNorm2d(10),
torch.nn.ReLU(),
)
m.eval()
out = m(torch.randn([2, 10, 8, 8]))
def test_shared_storage_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
y = x[:]
self.assertConst(x, y)
y.add_(torch.rand([1]))
self.assertNotConst(x, y)
def test_aliased_const_write(self):
with FakeTensorMode():
x = torch.tensor([1])
y = x.expand([4])
self.assertNotConst(y)
y[0] = 1
self.assertNotConst(x)
def test_constant_propagate_through_functions(self):
with FakeTensorMode():
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    """Return whether ``maybe_contained_type`` occurs anywhere inside ``type``.

    The match succeeds when ``maybe_contained_type`` is a subtype of ``type``
    itself, or (recursively) of any type reported by
    ``type.containedTypes()``.
    """
    if maybe_contained_type.isSubtypeOf(type):
        return True
    for inner in type.containedTypes():
        if contains_type(inner, maybe_contained_type):
            return True
    return False
class FakeTensorOpInfoTest(TestCase):
    """Runs ``optests.fake_check`` over every entry of ``custom_op_db``."""

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        # Each sample is unpacked into positional args (input first) and
        # kwargs, then handed to fake_check together with the op.
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            positional = (sample.input,) + sample.args
            optests.fake_check(op, positional, sample.kwargs)
class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
mode = FakeTensorMode()
self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))
def test_memoized_conversion_from_meta(self):
x = torch.rand(2, 2).to(device="meta")
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))
def test_separate_tensor_storages_view(self):
x = torch.rand(2, 2, 2)
y = x[0]
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
x_conv = converter(mode, x)
y_conv = converter(mode, y)
self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_separate_tensor_storages_non_view(self):
x = torch.rand(2, 2, 2)
y = torch.rand(4, 2)
y.set_(x.storage())
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
x_conv = converter(mode, x)
y_conv = converter(mode, y)
stor_id = torch._C._storage_id(x_conv)
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
del x
self.assertEqual(len(converter.tensor_memo), 1)
converter.meta_converter.check_for_expired_weak_storages()
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
del y
self.assertEqual(len(converter.tensor_memo), 0)
converter.meta_converter.check_for_expired_weak_storages()
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_dead_weak_ref(self):
x = torch.rand(2, 2, 2)
y = x[0]
mode = FakeTensorMode()
converter = FakeTensorConverter()
x_conv = converter(mode, x)
x_conv_storage = torch._C._storage_id(x_conv)
del x_conv
self.assertFalse(x in converter.tensor_memo)
y_conv = converter(mode, y)
self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_dead_key(self):
x = torch.rand(2, 2, 2)
mode = FakeTensorMode()
converter = FakeTensorConverter()
x_conv = converter(mode, x)
self.assertEqual(len(converter.tensor_memo), 1)
x_conv2 = converter(mode, x)
assert x_conv2 is x_conv
del x
self.assertEqual(len(converter.tensor_memo), 0)
def test_no_active_mode(self):
with FakeTensorMode() as mode:
x = torch.empty(2, 2, device="cpu")
y = torch.empty(2, 2, device="cpu")
out = x + y
self.assertEqual(mode, out.fake_mode)
self.assertTrue(isinstance(out, FakeTensor))
self.assertEqual(out.device.type, "cpu")
def test_multiple_modes(self):
t = torch.rand([4])
t2 = torch.rand([4])
with FakeTensorMode() as m:
with FakeTensorMode() as m2:
t_fake = m.from_tensor(t)
t2_fake = m2.from_tensor(t2)
with self.assertRaisesRegex(Exception, "Mixing fake modes"):
t_fake + t2_fake
def test_separate_mode_error(self):
with FakeTensorMode():
x = torch.empty(2, 2, device="cpu")
with FakeTensorMode():
y = torch.empty(2, 2, device="cpu")
self.assertRaises(Exception, lambda: x, y)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_no_ref_cycle(self):
x = torch.rand([4])
mode = FakeTensorMode()
y = mode.from_tensor(x)
self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
mode_weak = weakref.ref(mode)
y_weak = weakref.ref(mode)
del mode
del y
assert mode_weak() is None
assert y_weak() is None
class FakeTensorOperatorInvariants(TestCase):
@staticmethod
def get_aten_op(schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
assert namespace == "aten"
return getattr(getattr(torch.ops.aten, name), overload)
@staticmethod
def get_all_aten_schemas():
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
continue
yield schema
def test_non_kwarg_only_device(self):
for schema in self.get_all_aten_schemas():
ten_type = torch._C.TensorType.get()
if not any(
contains_type(arg.type, ten_type)
for arg in itertools.chain(schema.arguments, schema.returns)
):
continue
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
has_non_kwarg_device = any(
not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
for arg in schema.arguments
)
if has_non_kwarg_device:
self.assertTrue(
self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
)
def test_tensor_constructors_all_have_kwarg_device(self):
for schema in self.get_all_aten_schemas():
op = self.get_aten_op(schema)
if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
continue
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
has_kwarg_device = any(
arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
for arg in schema.arguments
)
self.assertTrue(
has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
)
@unittest.expectedFailure
def test_sparse_new(self):
with FakeTensorMode():
indices = torch.randn(1, 1, dtype=torch.int64)
values = torch.randn(1)
extra = (2,)
sparse = torch.randn(1).to_sparse()
# This used to segfault, now it does not, but it still raises an
# error
sparse2 = sparse.new(indices, values, extra)
def test_tensor_new(self):
with FakeTensorMode():
x = torch.Tensor([1, 2, 3])
self.assertIsInstance(x, FakeTensor)
def test_like_ops(self):
for schema in self.get_all_aten_schemas():
if "_like" == schema.name[-5:]:
op = self.get_aten_op(schema)
self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)
# at::_embedding_bag has no op info,
# and returns extra tensors that at::embedding bag throws away
def test_embedding_bag_private(self):
args = [
torch.ones(6, 1),
torch.ones(6, dtype=torch.int64),
torch.arange(2, dtype=torch.int64),
False,
2, # mode = max
]
ref_out = torch.ops.aten._embedding_bag(*args)
with FakeTensorMode() as m:
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
meta_out = torch.ops.aten._embedding_bag(*meta_args)
self.assertEqual(len(ref_out), len(meta_out))
for ref_o, meta_o in zip(ref_out, meta_out):
self.assertEqual(ref_o.size(), meta_o.size())
def test_cross_entropy_loss(self):
inp = torch.randn(3, 5)
target = torch.randint(5, (3,), dtype=torch.long)
weight = torch.rand(5)
fn = torch.nn.functional.cross_entropy
for w in (weight, None):
args = (inp, target, w)
ref = fn(*args)
with FakeTensorMode() as m:
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)
self.assertEqual(ref.size(), meta_out.size())
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
def test_flash_attention(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, arg1, arg2, arg3):
torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3, scale=0.17677669529663687)
args_new = [
[
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
],
[
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
]
]
for args_list in args_new:
args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
(bsz, num_heads, seq_len, head_dim) in args_list]
try:
with torch._subclasses.CrossRefFakeMode():
Repro()(*args)
except RuntimeError as e:
# We expect the cross ref to succed for the first output to fail
# for the rng state, see Note [Seed and Offset]
self.assertTrue("output[0]" not in str(e))
self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_conv_c1_backward(self):
    """Run ``aten.convolution_backward`` under CrossRefFakeMode on fp16 CUDA inputs.

    The first tensor has a single channel (size ``(16, 1, 128, 128)``) and the
    second uses channels-last strides; the test passes if the cross-ref run
    completes without raising.
    """
    class Repro(torch.nn.Module):
        def forward(self, arg1, arg2, arg3):
            # Positional arguments follow the aten schema for convolution_backward.
            torch.ops.aten.convolution_backward.default(
                arg1,  # grad_output
                arg2,  # input
                arg3,  # weight
                [1],  # bias_sizes
                [1, 1],  # stride
                [1, 1],  # padding
                [1, 1],  # dilation
                False,  # transposed
                [0, 0],  # output_padding
                1,  # groups
                [True, True, False],  # output_mask
            )

    # Each spec is (size, stride, dtype, device), i.e. the positional
    # arguments of rand_strided.
    tensor_specs = (
        ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
        ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
        ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
    )
    inputs = [rand_strided(*spec) for spec in tensor_specs]

    with torch._subclasses.CrossRefFakeMode():
        Repro()(*inputs)
def test_no_dispatch_with_like_function(self):
class CountingMode(TorchDispatchMode):
def __init__(self):
self.count = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.count += 1
return func(*args, **kwargs)
with FakeTensorMode():
x = torch.randn(2)
with CountingMode() as mode:
with no_dispatch():
torch.zeros_like(x)
self.assertEqual(mode.count, 0)
class FakeTensorPropTest(TestCase):
    """Tests that run ``FakeTensorProp`` over symbolically traced ``nn.Module`` graphs."""

    def test_fake_tensor_prop_on_nn_module(self):
        """Propagate through a module whose parameters and buffers are fake.

        With the same ``FakeTensorMode`` used for the fake parameters and for
        ``FakeTensorProp``, propagation must yield a ``FakeTensor`` of shape
        ``(5, 2)``. With a ``FakeTensorProp`` that was not given that mode,
        propagation must raise ``AssertionError``.
        """
        class ToyNnModuleWithParameters(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value):
                value = self.layer1(value)
                value = torch.relu(value)
                value = self.layer2(value)
                return value

        model = ToyNnModuleWithParameters()
        value = torch.randn(5, 4)
        # Convert nn.Module to GraphModule so that FakeTensorProp runs.
        graph_model = torch.fx.symbolic_trace(model, (value,))
        # The following block runs FakeTensorProp on graph_module with the same FakeTensorMode
        #
        # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
        # with parameters and buffers.
        with FakeTensorMode() as fake_tensor_mode:

            def to_fake_tensor(x):
                # Convert real tensors only; fake tensors and non-tensors are returned as-is.
                if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
                    return fake_tensor_mode.from_tensor(x)
                return x

            fake_parameters_and_buffers = {
                k: to_fake_tensor(v)
                for k, v in itertools.chain(
                    graph_model.named_parameters(), graph_model.named_buffers()
                )
            }
            with torch.nn.utils.stateless._reparametrize_module(
                graph_model, fake_parameters_and_buffers
            ):
                # This case uses the **same** fake tensor mode to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The result should be correct.
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
                self.assertIsInstance(result, FakeTensor)
                self.assertEqual(result.shape, (5, 2))
                # This case uses the **different** fake tensor modes to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The following code should fail.
                # AssertionError: tensor's device must be `meta`, got cpu instead
                with self.assertRaises(AssertionError):
                    FakeTensorProp(graph_model).propagate(value)

    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
        """Propagate through a module whose ``forward`` has a ``None`` argument in between.

        The test passes if tracing and propagation with
        ``(value, None, another_optional_value)`` complete without raising.
        """
        class OptionalArgumentInBetween(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value, another_value=None, another_optional_value=None):
                # Mimic huggingface's `forward` methods which have several optional arguments.
                # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
                # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
                if another_value is None:
                    another_value = torch.rand_like(value)
                if another_optional_value is None:
                    another_optional_value = torch.rand_like(value)
                value = value + another_value + another_optional_value
                return value * value

        fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
        with fake_mode:
            model = OptionalArgumentInBetween()
            value = torch.randn(5, 4)
            another_optional_value = torch.randn(5, 4)
            graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
            FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
# Module-level test registration; these calls run when the file is imported.
# NOTE(review): presumably expands the parametrized tests declared on
# FakeTensorTest into concrete test methods - confirm against common_utils.
instantiate_parametrized_tests(FakeTensorTest)
# Device types forwarded as `only_for` to the device-type instantiation below.
only_for = ("cpu", "cuda")
# Passes this module's globals() so the helper can register what it generates
# for FakeTensorOpInfoTest (defined earlier in this file) at module level.
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()