Add mutation checks for tensor inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79078

Approved by: https://github.com/davidberard98, https://github.com/Krovatkin
goldenxuett 2022-06-09 15:15:32 -07:00 committed by PyTorch MergeBot
parent 6bea742c10
commit 83c0a2bc38
3 changed files with 79 additions and 19 deletions

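In effect, wrapping an input in SchemaCheckTensor now raises if an op mutates an argument that its schema does not declare mutable. A minimal sketch mirroring the new test_mutation_check_fail below (import path as in the changed files; error text abbreviated):

import torch
from torch.testing._internal.schema_check_tensor import SchemaCheckTensor

x = torch.rand((3, 3), requires_grad=True)
# In training mode, batch_norm updates its running stats in place even though
# the schema does not mark them as mutable, which is exactly the mismatch the
# new check reports.
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
try:
    batch(SchemaCheckTensor(x))
except RuntimeError as err:
    print(err)  # "Argument ... is not defined as mutable but was mutated"
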
View File

@@ -4,7 +4,7 @@ import os
import sys
import torch
from torch.testing._internal import schema_check_tensor
from torch.testing._internal.schema_check_tensor import SchemaCheckTensor
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
@@ -18,21 +18,51 @@ if __name__ == '__main__':
# Tests various schema checking functionalities.
class TestSchemaCheck(JitTestCase):
def setUp(self):
schema_check_tensor.reset_cache()
SchemaCheckTensor.reset_cache()
# Tests that SchemaCheckTensor records operator order with grad
def test_schema_check_tensor_operator_order_grad(self):
x = torch.rand((3, 3), requires_grad=True)
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
self.assertEqual(["relu.default", "detach.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
SchemaCheckTensor(x).relu().sin()
self.assertEqual(["aten::relu", "aten::detach", "aten::sin"], SchemaCheckTensor.recorded_ops)
# Tests that SchemaCheckTensor records operator order without grad
def test_schema_check_tensor_operator_order_no_grad(self):
x = torch.rand((3, 3), requires_grad=False)
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
self.assertEqual(["relu.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
SchemaCheckTensor(x).relu().sin()
self.assertEqual(["aten::relu", "aten::sin"], SchemaCheckTensor.recorded_ops)
# Tests that SchemaCheckTensor wraps torch.Tensor
def test_schema_check_tensor_functionality(self):
x = torch.rand((3, 3), requires_grad=True)
self.assertEqual(x.relu().sin(), schema_check_tensor.SchemaCheckTensor(x).relu().sin().elem)
self.assertEqual(x.relu().sin(), SchemaCheckTensor(x).relu().sin().elem)
# Tests that SchemaCheckTensor wraps torch.Tensor when an argument's default is overridden
def test_schema_check_tensor_functionality_default_replaced(self):
x = torch.rand((3, 3), requires_grad=True)
self.assertEqual(x.add(x, alpha=2), SchemaCheckTensor(x).add(SchemaCheckTensor(x), alpha=2).elem)
# Tests that SchemaCheckTensor wraps torch.Tensor with a mutable op
def test_schema_check_tensor_functionality_mutable_inputs(self):
x = torch.rand((3, 3), requires_grad=False)
y = torch.clone(x)
x.sinh_()
SchemaCheckTensor(y).sinh_()
self.assertEqual(x, y)
# Tests that an exception is raised for a mismatching mutation
def test_mutation_check_fail(self):
with self.assertRaises(RuntimeError):
x = torch.rand((3, 3), requires_grad=True)
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
batch(SchemaCheckTensor(x))
# Tests that an exception is raised for a mismatching mutation over multiple ops
def test_mutation_check_fail_multiple_operators(self):
with self.assertRaises(RuntimeError):
x = torch.rand((3, 3), requires_grad=True)
x.sinh_()
x.tanh_()
x.relu_()
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
batch(SchemaCheckTensor(x))

View File

@@ -1491,6 +1491,12 @@ void initJITBindings(PyObject* module) {
})
.def_property_readonly(
"is_out", [](Argument& self) { return self.is_out(); })
.def_property_readonly(
"is_mutable",
[](Argument& self) {
const AliasInfo* aliasInfo = self.alias_info();
return aliasInfo && aliasInfo->isWrite();
})
.def_property_readonly("kwarg_only", [](Argument& self) -> bool {
return self.kwarg_only();
});

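The new is_mutable binding on Argument can be read from Python once a schema is parsed. A small sketch (torch._C.parse_schema is an existing helper; the schema string here is illustrative):

import torch

# The (a!) alias annotation marks a write, so is_mutable is expected to be True
# for self; an argument without a write annotation reports False.
schema = torch._C.parse_schema("aten::sinh_(Tensor(a!) self) -> Tensor(a!)")
self_arg = schema.arguments[0]
print(self_arg.name, self_arg.is_mutable)
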
View File

@@ -1,15 +1,18 @@
import torch
from torch.utils._pytree import tree_map
schema_check_recorded_ops = []
from torch.utils._pytree import tree_map, tree_flatten
from torch.fx.operator_schemas import normalize_function
from torch.testing._internal.jit_utils import clone_inputs
# This Tensor Subclass is used to verify op schemas
# This Tensor currently:
# - Records the called ops and appends to recorded_ops
# - Checks for mutations on all inputs
class SchemaCheckTensor(torch.Tensor):
elem: torch.Tensor
recorded_ops = []
__slots__ = ['elem']
__torch_function__ = torch._C._disabled_torch_function_impl
@@ -30,6 +33,14 @@ class SchemaCheckTensor(torch.Tensor):
r.elem = elem
return r
@staticmethod
def reset_cache():
SchemaCheckTensor.recorded_ops.clear()
@staticmethod
def display_ops():
print(*SchemaCheckTensor.recorded_ops, sep=",")
def __repr__(self):
if self.grad_fn:
return f"SchemaCheckTensor({self.elem}, grad_fn={self.grad_fn})"
@@ -43,14 +54,27 @@ class SchemaCheckTensor(torch.Tensor):
def wrap(e):
return cls(e) if isinstance(e, torch.Tensor) else e
global schema_check_recorded_ops
schema_check_recorded_ops.append(func.__name__)
def has_mutated(before, after):
return not torch.equal(before, after) if isinstance(before, torch.Tensor) and isinstance(after, torch.Tensor) else False
SchemaCheckTensor.recorded_ops.append(func._schema.name)
arguments = normalize_function(
func,
tree_map(unwrap, args),
tree_map(unwrap, kwargs),
normalize_to_only_use_kwargs=True
).kwargs
cloned_arguments = dict(zip(arguments.keys(), clone_inputs(arguments.values())))
out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for argument in func._schema.arguments:
name = argument.name if argument.name != "self" else "input"
if arguments.get(name) is not None:
before = tree_flatten(arguments.get(name))[0]
after = tree_flatten(cloned_arguments.get(name))[0]
if (any([has_mutated(i, j) for i, j in zip(before, after)]) and not argument.is_mutable):
raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
return tree_map(wrap, out)
def reset_cache():
global schema_check_recorded_ops
schema_check_recorded_ops.clear()
def display_ops():
print(*schema_check_recorded_ops, sep=",")
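
For the recording side, usage matches the updated tests: ops are appended to the class-level recorded_ops list as qualified schema names. A brief sketch mirroring test_schema_check_tensor_operator_order_no_grad:

import torch
from torch.testing._internal.schema_check_tensor import SchemaCheckTensor

SchemaCheckTensor.reset_cache()          # clear the class-level op log
y = torch.rand((3, 3), requires_grad=False)
SchemaCheckTensor(y).relu().sin()        # each dispatched op is recorded
print(SchemaCheckTensor.recorded_ops)    # expected: ['aten::relu', 'aten::sin']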