mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add mutation checks for tensor inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79078 Approved by: https://github.com/davidberard98, https://github.com/Krovatkin
This commit is contained in:
parent
6bea742c10
commit
83c0a2bc38
|
|
@ -4,7 +4,7 @@ import os
|
|||
import sys
|
||||
import torch
|
||||
|
||||
from torch.testing._internal import schema_check_tensor
|
||||
from torch.testing._internal.schema_check_tensor import SchemaCheckTensor
|
||||
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
|
@ -18,21 +18,51 @@ if __name__ == '__main__':
|
|||
# Tests various schema checking functionalities.
|
||||
class TestSchemaCheck(JitTestCase):
|
||||
def setUp(self):
|
||||
schema_check_tensor.reset_cache()
|
||||
SchemaCheckTensor.reset_cache()
|
||||
|
||||
# Tests that SchemaCheckTensor records operator order with grad
|
||||
def test_schema_check_tensor_operator_order_grad(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["relu.default", "detach.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
|
||||
SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["aten::relu", "aten::detach", "aten::sin"], SchemaCheckTensor.recorded_ops)
|
||||
|
||||
# Tests that SchemaCheckTensor records operator order without grad
|
||||
def test_schema_check_tensor_operator_order_no_grad(self):
|
||||
x = torch.rand((3, 3), requires_grad=False)
|
||||
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["relu.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
|
||||
SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["aten::relu", "aten::sin"], SchemaCheckTensor.recorded_ops)
|
||||
|
||||
# Tests that SchemaCheckTensor wraps torch.Tensor
|
||||
def test_schema_check_tensor_functionality(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
self.assertEqual(x.relu().sin(), schema_check_tensor.SchemaCheckTensor(x).relu().sin().elem)
|
||||
self.assertEqual(x.relu().sin(), SchemaCheckTensor(x).relu().sin().elem)
|
||||
|
||||
# Tests that SchemaCheckTensor wraps torch.Tensor when an argument's default is overriden
|
||||
def test_schema_check_tensor_functionality_default_replaced(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
self.assertEqual(x.add(x, alpha=2), SchemaCheckTensor(x).add(SchemaCheckTensor(x), alpha=2).elem)
|
||||
|
||||
# Tests that SchemaCheckTensor wraps torch.Tensorwith a mutable op
|
||||
def test_schema_check_tensor_functionality_mutable_inputs(self):
|
||||
x = torch.rand((3, 3), requires_grad=False)
|
||||
y = torch.clone(x)
|
||||
x.sinh_()
|
||||
SchemaCheckTensor(y).sinh_()
|
||||
self.assertEqual(x, y)
|
||||
|
||||
# Tests that an exception is raised for a mismatching mutation
|
||||
def test_mutation_check_fail(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
|
||||
batch(SchemaCheckTensor(x))
|
||||
|
||||
# Tests that an exception is raised for a mismatching mutation over multiple ops
|
||||
def test_mutation_check_fail_multiple_operators(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
x.sinh_()
|
||||
x.tanh_()
|
||||
x.relu_()
|
||||
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
|
||||
batch(SchemaCheckTensor(x))
|
||||
|
|
|
|||
|
|
@ -1491,6 +1491,12 @@ void initJITBindings(PyObject* module) {
|
|||
})
|
||||
.def_property_readonly(
|
||||
"is_out", [](Argument& self) { return self.is_out(); })
|
||||
.def_property_readonly(
|
||||
"is_mutable",
|
||||
[](Argument& self) {
|
||||
const AliasInfo* aliasInfo = self.alias_info();
|
||||
return aliasInfo && aliasInfo->isWrite();
|
||||
})
|
||||
.def_property_readonly("kwarg_only", [](Argument& self) -> bool {
|
||||
return self.kwarg_only();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
schema_check_recorded_ops = []
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
from torch.testing._internal.jit_utils import clone_inputs
|
||||
|
||||
# This Tensor Subclass is used to verify op schemas
|
||||
# This Tensor currently:
|
||||
# - Records the called ops and appends to schema_check_records_ops
|
||||
# - Checks for mutations on all inputs
|
||||
|
||||
class SchemaCheckTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
recorded_ops = []
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
|
@ -30,6 +33,14 @@ class SchemaCheckTensor(torch.Tensor):
|
|||
r.elem = elem
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def reset_cache():
|
||||
SchemaCheckTensor.recorded_ops.clear()
|
||||
|
||||
@staticmethod
|
||||
def display_ops():
|
||||
print(*recorded_ops, sep=",")
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"SchemaCheckTensor({self.elem}, grad_fn={self.grad_fn})"
|
||||
|
|
@ -43,14 +54,27 @@ class SchemaCheckTensor(torch.Tensor):
|
|||
def wrap(e):
|
||||
return cls(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
global schema_check_recorded_ops
|
||||
schema_check_recorded_ops.append(func.__name__)
|
||||
def has_mutated(before, after):
|
||||
return not torch.equal(before, after) if isinstance(before, torch.Tensor) and isinstance(after, torch.Tensor) else False
|
||||
|
||||
SchemaCheckTensor.recorded_ops.append(func._schema.name)
|
||||
|
||||
arguments = normalize_function(
|
||||
func,
|
||||
tree_map(unwrap, args),
|
||||
tree_map(unwrap, kwargs),
|
||||
normalize_to_only_use_kwargs=True
|
||||
).kwargs
|
||||
|
||||
cloned_arguments = dict(zip(arguments.keys(), clone_inputs(arguments.values())))
|
||||
out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
|
||||
for argument in func._schema.arguments:
|
||||
name = argument.name if argument.name != "self" else "input"
|
||||
if arguments.get(name) is not None:
|
||||
before = tree_flatten(arguments.get(name))[0]
|
||||
after = tree_flatten(cloned_arguments.get(name))[0]
|
||||
if (any([has_mutated(i, j) for i, j in zip(before, after)]) and not argument.is_mutable):
|
||||
raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
|
||||
|
||||
return tree_map(wrap, out)
|
||||
|
||||
def reset_cache():
|
||||
global schema_check_recorded_ops
|
||||
schema_check_recorded_ops.clear()
|
||||
|
||||
def display_ops():
|
||||
print(*schema_check_recorded_ops, sep=",")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user