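"""Cross-reference utilities for validating FakeTensorMode against real execution.

This module provides ``CrossRefFakeMode``, a ``TorchDispatchMode`` that re-runs
each dispatched op on fake tensors and asserts that the fake results agree with
the real results on output count, aliasing behavior, ``requires_grad``, storage
offsets, and tensor metadata.
"""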
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


aten = torch.ops.aten


def outputs_alias_inputs(outputs, inputs):
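    """Return True if any tensor output shares a storage with any tensor input."""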
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
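    """Return True if any tensor output is (by identity) one of the tensor inputs."""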
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
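    """Return True if any two tensor outputs share a storage with each other."""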
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False


class CrossRefFakeMode(TorchDispatchMode):
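    """Cross-check real op execution against ``FakeTensorMode``.

    While this mode is active, every dispatched op is (where supported) also
    run on fake tensors under a fresh ``FakeTensorMode``, and the fake outputs
    are compared against the real outputs: number of outputs, input/output
    aliasing, output-to-output aliasing, ``requires_grad``, storage offsets,
    and tensor metadata (optionally including strides).

    Usage sketch (illustrative, not from the original file)::

        with CrossRefFakeMode():
            out = some_tensor.add(1)
    """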
    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
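        """
        Args:
            ignore_op_fn: optional predicate over ``OpOverload``; ops for which
                it returns True are not cross-checked.
            check_strides: if True, also compare strides when comparing tensor
                metadata.
            check_aliasing: if True, also compare aliasing behavior between the
                real and fake runs.
        """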
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
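        # Only cross-check ops that FakeTensorMode can model faithfully: skip
        # the explicitly excluded ops below, user-ignored ops, and ops tagged
        # as dynamic-output-shape, inplace-view, or data-dependent-output.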
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.inplace_view not in func.tags  # type: ignore[attr-defined]
            and torch.Tag.data_dependent_output not in func.tags  # type: ignore[attr-defined]
        ):
            try:
                with FakeTensorMode() as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor, fake_mode.from_tensor, (args, kwargs)
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

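        # Run the op for real, then compare the fake results (if any) against it.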
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat, _ = tree_flatten(r)
            f_flat, _ = tree_flatten(fake_r)
            assert len(r_flat) == len(
                f_flat
            ), f"Mismatch {len(r_flat)} != {len(f_flat)} on {func}"

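            # The fake run must agree with the real run on aliasing: outputs
            # aliasing inputs, outputs that are inputs, and outputs aliasing
            # each other.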
            if self.check_aliasing:
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"Mismatch on {func}: {r_aliasing} != {f_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"Mismatch on {func}: {r_identity_eq} != {f_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert (
                    r_output_alias_each_other == f_output_alias_each_other
                ), f"Mismatch on {func}: {r_output_alias_each_other} != {f_output_alias_each_other}"

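            # Per-output checks: each real output must match its fake
            # counterpart on tensor-ness, requires_grad, storage offset
            # (when it has storage), and tensor metadata.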
            for r_out, fake_out in zip(tree_flatten(r)[0], tree_flatten(fake_r)[0]):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"Mismatched number of tensor outputs on {func}"
                if r_is_ten:
                    assert (
                        r_out.requires_grad == fake_out.requires_grad
                    ), f"Mismatch on {func}"
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"Mismatch on {func}: {r_offset} != {f_offset}"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out, fake_out, check_strides=self.check_strides
                        )
                    except Exception as e:
                        raise RuntimeError(f"Mismatch on {func}: {e}") from e
        return r
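

# Minimal usage sketch (an addition for illustration, not part of the original
# module): exercising a couple of simple ops under the mode. If the fake and
# real runs disagree, one of the assertions above fails or a RuntimeError is
# raised.
if __name__ == "__main__":
    x = torch.randn(4, 4)
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        y = (x + x).t()  # add.Tensor, then a view op (t)
    print(y.shape)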