mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add FakeTensorMode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77972 Approved by: https://github.com/ezyang
This commit is contained in:
parent
4c18f362a9
commit
cea7dd1646
|
|
@ -1,11 +1,13 @@
|
|||
# Owner(s): ["module: unknown"]
|
||||
# Owner(s): ["module: meta tensors"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import torch
|
||||
import itertools
|
||||
from torch.testing._internal.jit_utils import RUN_CUDA
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
||||
import unittest
|
||||
from torch._subclasses import FakeTensor
|
||||
|
||||
|
||||
class FakeTensorTest(TestCase):
|
||||
def test_basic(self):
|
||||
|
|
@ -44,11 +46,30 @@ class FakeTensorTest(TestCase):
|
|||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_type_as(self):
|
||||
x = FakeTensor.from_tensor(torch.rand([16, 1], device='cpu'))
|
||||
y = FakeTensor.from_tensor(torch.rand([4, 4], device='cuda'))
|
||||
x = FakeTensor.from_tensor(torch.rand([16, 1], device="cpu"))
|
||||
y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
|
||||
out = x.type_as(y)
|
||||
self.assertEqual(out.device.type, "cuda")
|
||||
|
||||
def test_constructor(self):
    """Check that a factory call made under ``FakeTensorMode`` is faked.

    ``torch.rand`` takes no tensor inputs, so the result can only become a
    ``FakeTensor`` through the dispatch mode; the requested device must be
    the one the fake tensor reports.
    """
    with enable_torch_dispatch_mode(FakeTensorMode):
        x = torch.rand([4, 4], device="cpu")

    # Specific assertions report the offending value on failure, unlike a
    # bare assertTrue which only prints "False is not true".
    self.assertIsInstance(x, FakeTensor)
    self.assertEqual(x.device.type, "cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_fake_mode_non_fake_inputs(self):
    """Check that real tensors entering an op under ``FakeTensorMode`` are wrapped.

    Both operands are ordinary tensors created outside the mode (a CPU
    zero-dim tensor and a CUDA tensor); the sum computed inside the mode
    must come back as a ``FakeTensor`` that reports the CUDA device.
    """
    x = torch.tensor(0.1)
    y = torch.rand([4, 4], device="cuda")

    with enable_torch_dispatch_mode(FakeTensorMode):
        out = x + y

    # Specific assertions report the offending value on failure, unlike a
    # bare assertTrue which only prints "False is not true".
    self.assertIsInstance(out, FakeTensor)
    self.assertEqual(out.device.type, "cuda")
|
||||
|
||||
|
||||
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
|
||||
return maybe_contained_type.isSubtypeOf(type) or any(
|
||||
contains_type(e, maybe_contained_type) for e in type.containedTypes()
|
||||
|
|
@ -56,12 +77,14 @@ def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
|
|||
|
||||
|
||||
class FakeTensorOperatorInvariants(TestCase):
|
||||
@staticmethod
|
||||
def get_aten_op(schema):
|
||||
namespace, name = schema.name.split("::")
|
||||
overload = schema.overload_name if schema.overload_name else "default"
|
||||
assert namespace == "aten"
|
||||
return getattr(getattr(torch.ops.aten, name), overload)
|
||||
|
||||
def test_non_kwarg_only_device(self):
|
||||
def get_op(schema):
|
||||
namespace, name = schema.name.split("::")
|
||||
overload = schema.overload_name if schema.overload_name else "default"
|
||||
assert namespace == "aten"
|
||||
return getattr(getattr(torch.ops.aten, name), overload)
|
||||
|
||||
for schema in torch._C._jit_get_all_schemas():
|
||||
namespace = schema.name.split("::")[0]
|
||||
|
|
@ -82,9 +105,29 @@ class FakeTensorOperatorInvariants(TestCase):
|
|||
)
|
||||
if has_non_kwarg_device:
|
||||
self.assertTrue(
|
||||
get_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
|
||||
self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
|
||||
)
|
||||
|
||||
def test_tensor_constructors_all_have_kwarg_device(self):
    """Check that every aten tensor constructor takes ``device`` as a kwarg-only argument.

    Walks all registered schemas, keeps the ``aten`` ones that
    ``_is_tensor_constructor`` accepts, and requires each to declare a
    kwarg-only argument whose type is a subtype of ``Optional[Device]``.
    ``aten._list_to_tensor.default`` is the single allowed exception.
    """
    # Loop-invariant: build the Optional[Device] type once, not per schema.
    opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())

    for schema in torch._C._jit_get_all_schemas():
        namespace = schema.name.split("::")[0]
        if namespace != "aten":
            continue

        op = self.get_aten_op(schema)
        if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
            continue

        has_kwarg_device = any(
            arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
            for arg in schema.arguments
        )

        # Name the offending schema so a failure is actionable.
        self.assertTrue(
            has_kwarg_device or op == torch.ops.aten._list_to_tensor.default,
            f"{schema} is a tensor constructor without a kwarg-only device argument",
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -5,4 +5,5 @@ from torch._subclasses.fake_tensor import FakeTensor, _device_not_kwarg_ops
|
|||
__all__ = [
|
||||
"FakeTensor",
|
||||
"_device_not_kwarg_ops",
|
||||
"_is_tensor_constructor",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from functools import partial
|
|||
from torch.fx.operator_schemas import normalize_function
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from typing import Union
|
||||
from torch._ops import OpOverload
|
||||
import functools
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
|
@ -20,6 +22,29 @@ _device_not_kwarg_ops = (
|
|||
aten._resize_output.out,
|
||||
)
|
||||
|
||||
# this op is never actually used
|
||||
_non_kwarg_device_constructors = (torch.ops.aten._list_to_tensor,)
|
||||
|
||||
|
||||
def contains_tensor_types(type):
    """Return True if ``type`` is, or transitively contains, a tensor type.

    ``type`` is a JIT type object; it matches when it is a subtype of the
    plain ``TensorType`` or when any of its contained types matches
    recursively.
    """
    if type.isSubtypeOf(torch._C.TensorType.get()):
        return True
    # Recurse into contained element types.
    for inner in type.containedTypes():
        if contains_tensor_types(inner):
            return True
    return False
|
||||
|
||||
|
||||
@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    """Return whether ``func`` creates a tensor without taking any tensor input.

    An overload qualifies when none of its schema arguments is, or contains,
    a tensor type and it returns exactly one value whose type is the plain
    ``TensorType``. Results are memoized per overload.

    Raises:
        AssertionError: if ``func`` is not an ``OpOverload``.
    """
    assert isinstance(func, OpOverload)
    schema = func._schema
    for arg in schema.arguments:
        if contains_tensor_types(arg.type):
            return False
    # TODO: no real reason to restrict multiple outputs
    returns = schema.returns
    if len(returns) != 1:
        return False
    return returns[0].type is torch._C.TensorType.get()
|
||||
|
||||
|
||||
# Meta tensors give you the ability to run PyTorch code without having to
|
||||
# actually do computation through tensors allocated on a `meta` device.
|
||||
# Because the device is `meta`, meta tensors do not model device propagation.
|
||||
|
|
@ -60,7 +85,20 @@ class FakeTensor(torch.Tensor):
|
|||
assert len(args) == 1 and isinstance(args[0], FakeTensor)
|
||||
return args[0].fake_device
|
||||
|
||||
# Run the original computation
|
||||
def wrap(e, device=None):
|
||||
if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
|
||||
if device:
|
||||
return FakeTensor(e, device)
|
||||
else:
|
||||
return FakeTensor.from_tensor(e)
|
||||
else:
|
||||
return e
|
||||
|
||||
# if we are in the dispatch mode, we will enter this function even if the inputs
|
||||
# are not FakeTensors, and they need to be wrapped
|
||||
if cls == FakeTensorMode:
|
||||
args = tree_map(wrap, args)
|
||||
kwargs = tree_map(wrap, kwargs)
|
||||
|
||||
# _to_copy fails when run with FakeTensors to cuda device
|
||||
# TODO: debug
|
||||
|
|
@ -71,22 +109,25 @@ class FakeTensor(torch.Tensor):
|
|||
out_device = new_kwargs.pop("device", new_kwargs["input"].device)
|
||||
with no_dispatch():
|
||||
input = new_kwargs.pop("input").to("meta")
|
||||
return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)
|
||||
return FakeTensor(
|
||||
torch.ops.aten._to_copy(input, **new_kwargs), out_device
|
||||
)
|
||||
|
||||
if _is_tensor_constructor(func):
|
||||
assert func not in _non_kwarg_device_constructors
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
# cpu is default device if none is specified
|
||||
out_device = new_kwargs.pop("device", torch.device("cpu"))
|
||||
new_kwargs["device"] = torch.device("meta")
|
||||
r = super().__torch_dispatch__(func, types, (), new_kwargs)
|
||||
return FakeTensor(r, out_device)
|
||||
|
||||
r = super().__torch_dispatch__(func, types, args, kwargs)
|
||||
|
||||
def wrap(e, device):
|
||||
# inplace ops can return fake tensors
|
||||
if isinstance(e, torch.Tensor) and not isinstance(e, cls):
|
||||
return FakeTensor(e, device)
|
||||
else:
|
||||
return e
|
||||
|
||||
# TODO: handle non-kwarg devices
|
||||
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
|
||||
assert (
|
||||
func != aten._pin_memory.default and func != aten.pin_memory.default
|
||||
), f"NYI: {func}"
|
||||
|
||||
# if device is specified, use that
|
||||
if kwargs.get("device", None):
|
||||
|
|
@ -159,3 +200,6 @@ class FakeTensor(torch.Tensor):
|
|||
return common_device
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
class FakeTensorMode(FakeTensor):
    """``FakeTensor`` subclass meant to be installed as a torch dispatch mode.

    It adds no behavior of its own. ``FakeTensor.__torch_dispatch__`` checks
    whether it is running as this class and, if so, wraps non-fake tensor
    arguments before handling the op.
    """
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user