mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Add check in test_cow_input to ensure COW data is never changed (#150723)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150723 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
24aadb40fb
commit
164d2c887b
|
|
@ -1825,6 +1825,7 @@ class TestCompositeCompliance(TestCase):
|
||||||
def check_cow_input(
|
def check_cow_input(
|
||||||
arg,
|
arg,
|
||||||
arg_copy,
|
arg_copy,
|
||||||
|
arg_raw,
|
||||||
idx_or_kw,
|
idx_or_kw,
|
||||||
backward_or_forward="forward",
|
backward_or_forward="forward",
|
||||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
|
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
|
||||||
|
|
@ -1837,6 +1838,13 @@ class TestCompositeCompliance(TestCase):
|
||||||
) + f" during {backward_or_forward} call"
|
) + f" during {backward_or_forward} call"
|
||||||
|
|
||||||
if is_strided_tensor(arg):
|
if is_strided_tensor(arg):
|
||||||
|
self.assertTrue(
|
||||||
|
torch._C._is_cow_tensor(arg_raw),
|
||||||
|
msg=(
|
||||||
|
f"{arg_name} raw input should remain COW, but it "
|
||||||
|
"unexpectedly materialized."
|
||||||
|
),
|
||||||
|
)
|
||||||
is_cow = torch._C._is_cow_tensor(arg)
|
is_cow = torch._C._is_cow_tensor(arg)
|
||||||
|
|
||||||
if supports_cow_input_no_materialize and not check_ignore_materialize(
|
if supports_cow_input_no_materialize and not check_ignore_materialize(
|
||||||
|
|
@ -1861,6 +1869,17 @@ class TestCompositeCompliance(TestCase):
|
||||||
"but the operation mutated its data."
|
"but the operation mutated its data."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
arg_raw, arg_copy, rtol=0, atol=0, equal_nan=True
|
||||||
|
),
|
||||||
|
msg=(
|
||||||
|
f"{arg_name} materialized, which is allowed in this "
|
||||||
|
"case, but the COW input data was mutated, which is "
|
||||||
|
"not allowed."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
args_raw = [sample.input] + list(sample.args)
|
args_raw = [sample.input] + list(sample.args)
|
||||||
|
|
@ -1901,10 +1920,10 @@ class TestCompositeCompliance(TestCase):
|
||||||
|
|
||||||
# Check that COW inputs remain COW after the forward op is executed
|
# Check that COW inputs remain COW after the forward op is executed
|
||||||
for idx, arg in enumerate(args):
|
for idx, arg in enumerate(args):
|
||||||
check_cow_input(arg, args_copy[idx], idx)
|
check_cow_input(arg, args_copy[idx], args_raw[idx], idx)
|
||||||
|
|
||||||
for kw, arg in kwargs.items():
|
for kw, arg in kwargs.items():
|
||||||
check_cow_input(arg, kwargs_copy[kw], kw)
|
check_cow_input(arg, kwargs_copy[kw], kwargs_raw[kw], kw)
|
||||||
|
|
||||||
# Call backward op if it is supported. This part of the test is
|
# Call backward op if it is supported. This part of the test is
|
||||||
# based on `composite_compliance.check_backward_formula`
|
# based on `composite_compliance.check_backward_formula`
|
||||||
|
|
@ -1954,6 +1973,7 @@ class TestCompositeCompliance(TestCase):
|
||||||
check_cow_input(
|
check_cow_input(
|
||||||
arg,
|
arg,
|
||||||
args_copy[idx],
|
args_copy[idx],
|
||||||
|
args_raw[idx],
|
||||||
idx,
|
idx,
|
||||||
backward_or_forward="backward",
|
backward_or_forward="backward",
|
||||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
||||||
|
|
@ -1965,6 +1985,7 @@ class TestCompositeCompliance(TestCase):
|
||||||
check_cow_input(
|
check_cow_input(
|
||||||
output_grad,
|
output_grad,
|
||||||
output_grads_copy[idx],
|
output_grads_copy[idx],
|
||||||
|
output_grads_raw[idx],
|
||||||
f"output grad {idx}",
|
f"output grad {idx}",
|
||||||
backward_or_forward="backward",
|
backward_or_forward="backward",
|
||||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user