mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[test][scan] refactor inductor test and prepare for adding bw tests (#161557)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161557 Approved by: https://github.com/zou3519
This commit is contained in:
parent
e78792a70d
commit
8f15d6a0c9
|
|
@ -1,5 +1,5 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
import contextlib
|
||||
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
|
|
@ -1889,27 +1889,56 @@ class ScanTests(TestCase):
|
|||
inputs,
|
||||
device,
|
||||
dynamic,
|
||||
requires_grad=False,
|
||||
autograd=False,
|
||||
):
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
||||
compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
|
||||
model
|
||||
)
|
||||
import copy
|
||||
|
||||
inputs = [inp.requires_grad_(autograd) for inp in inputs]
|
||||
inputs = [inp.to(device=device) for inp in inputs]
|
||||
model = model.to(device=device)
|
||||
cloned_inputs = [inp.clone() for inp in inputs]
|
||||
grad_ctx = contextlib.nullcontext() if requires_grad else torch.no_grad()
|
||||
with grad_ctx:
|
||||
result = model(scan, *cloned_inputs)
|
||||
result_exp = model(_fake_scan, *cloned_inputs)
|
||||
for p in model.parameters():
|
||||
p.requires_grad_(autograd)
|
||||
|
||||
result_compiled = compiled_model(scan, *cloned_inputs)
|
||||
result_compiled_exp = compiled_model(_fake_scan, *cloned_inputs)
|
||||
model1 = copy.deepcopy(model)
|
||||
model2 = copy.deepcopy(model)
|
||||
model3 = copy.deepcopy(model)
|
||||
model4 = copy.deepcopy(model)
|
||||
torch.compile(fullgraph=True, dynamic=dynamic)(model)
|
||||
|
||||
self.assertEqual(result, result_exp)
|
||||
def _run_model(model, inputs):
|
||||
cloned_inputs = [
|
||||
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
|
||||
]
|
||||
fw_result = model(*cloned_inputs)
|
||||
loss = loss_fn(fw_result)
|
||||
if autograd:
|
||||
loss.backward()
|
||||
return (
|
||||
fw_result,
|
||||
loss,
|
||||
[
|
||||
inp.grad
|
||||
for inp in cloned_inputs
|
||||
if isinstance(inp, torch.Tensor)
|
||||
],
|
||||
{n: p.grad for n, p in model.named_parameters()},
|
||||
)
|
||||
else:
|
||||
return fw_result, loss
|
||||
|
||||
result_exp = _run_model(model1, [_fake_scan] + inputs)
|
||||
result_eager = _run_model(model2, [scan] + inputs)
|
||||
result_compiled = _run_model(
|
||||
torch.compile(fullgraph=True, dynamic=dynamic)(model3), [scan] + inputs
|
||||
)
|
||||
result_compiled_exp = _run_model(
|
||||
torch.compile(fullgraph=True, dynamic=dynamic)(model4),
|
||||
[_fake_scan] + inputs,
|
||||
)
|
||||
|
||||
self.assertEqual(result_exp, result_eager)
|
||||
self.assertEqual(result_exp, result_compiled)
|
||||
self.assertEqual(result_compiled, result_compiled_exp)
|
||||
self.assertEqual(result_exp, result_compiled_exp)
|
||||
|
||||
def _compare_result(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user