Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75359

For some models in torchbench (e.g. pyhpc_isoneutral_mixing), dynamo will generate Fx graphs that have side effects. Those graphs may
- return an empty tuple
- change tensors passed in as forward method arguments in-place

This makes the Dynamo+LTC integration fail, since we extract the compiled graph based on the lazy tensors returned from the forward method. From an empty tuple, we extract nothing.

To solve this problem, we extract the compiled graph from `union(argument lazy tensors, returned lazy tensors)` instead. The in-place mutations applied to argument lazy tensors are captured this way.

Test Plan:
```
pytest test/lazy/test_extract_compiled_graph.py
```
```
LTC_TS_CUDA=1 gpui time python torchbench.py --speedup-ltc -dcuda --nvfuser --randomize-input --only pyhpc_isoneutral_mixing
```

Reviewed By: ZolotukhinM

Differential Revision: D35478799

Pulled By: shunting314

fbshipit-source-id: 8116768fc50fe7630e481e6039319ddf5c6a9416
(cherry picked from commit 2e6531d2c80c35ae99c11d49ca01dcdb7fc032f2)
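To illustrate the failure mode, here is a minimal sketch (hypothetical, not the exact graph dynamo emits) of a forward that only mutates its arguments in place and returns an empty tuple; extracting the compiled graph from the returned lazy tensors alone would miss the mutation, while the union with the argument lazy tensors captures it:
```
import torch

# Hypothetical sketch of a dynamo-generated graph with side effects; the class
# name is illustrative and does not appear in the test file below.
class InplaceOnlyGraph(torch.nn.Module):
    def forward(self, a, b):
        a.sub_(b)   # in-place mutation of an argument tensor
        return ()   # empty tuple: the returned lazy tensors alone trace nothing

# The fix extracts the compiled graph from
# union(argument lazy tensors, returned lazy tensors),
# so the subtraction applied to `a` is still captured.
```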
# Owner(s): ["oncall: jit"]

import unittest

from torch._lazy.ts_backend import init as init_ts_backend
init_ts_backend()
from torch._lazy import config
from torch._lazy.extract_compiled_graph import extract_compiled_graph
import torch
from torch import nn
import dis
import inspect
from torch import fx
import re
from contextlib import contextmanager
import copy

class ModuleConstScale(nn.Module):
    def __init__(self):
        super(ModuleConstScale, self).__init__()

    def forward(self, a):
        return a * 2

class ModuleSub(nn.Module):
    def __init__(self):
        super(ModuleSub, self).__init__()

    def forward(self, a, b):
        return a - b

class ModuleAddcmul(nn.Module):
    """
    The addcmul function takes an at::Scalar, which results in a special TSData
    containing a Scalar rather than a Tensor.
    """
    def __init__(self):
        super(ModuleAddcmul, self).__init__()

    def forward(self, a, b, c):
        return torch.addcmul(a, b, c, value=5)

class ModuleReturnMulti(nn.Module):
    def __init__(self):
        super(ModuleReturnMulti, self).__init__()

    def forward(self, a, b):
        return (b + 1, a - 1)

# The default fx tracer will convert torch.randn to a constant. We may need
# a custom tracer.
# class ModuleEagerTensor(nn.Module):
#     def __init__(self):
#         super(ModuleEagerTensor, self).__init__()
#
#     def forward(self, a):
#         b = torch.randn(2, 3, device="cpu")  # eager device
#         return a + b

# The module was planned to cover the case that an Fx graph returns an eager
# tensor on the default device. It's harder than ModuleEagerTensor because
# we cannot just override the device argument to Lazy since there is no
# explicit device argument.
#
# Unfortunately, the default fx tracer converts the return value of the forward
# method to a constant. Commented out for now.
# class ModuleReturnEagerTensorOnDefaultDevice(nn.Module):
#     def __init__(self):
#         super(ModuleReturnEagerTensorOnDefaultDevice, self).__init__()
#
#     def forward(self):
#         return torch.tensor((2, 3), dtype=torch.float32)

class ModuleReturnDupTensor(nn.Module):
    """
    Handle the corner case that the same tensor appears multiple times in the
    returned tuple. torchbench models like drq hit this corner case when
    running through torchdynamo.
    """
    def __init__(self):
        super(ModuleReturnDupTensor, self).__init__()

    def forward(self, a, b):
        c = a + b
        return a - b, c, a + 1, c

class ModuleInplaceUpdate(nn.Module):
    def __init__(self):
        super(ModuleInplaceUpdate, self).__init__()

    def forward(self, a, b):
        a.sub_(b)
        return b - 1, b + 1

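# Force the given op (e.g. "aten::sub") onto the LTC fallback path while the
# context is active; used by test_ltc_fallback to exercise the error path in
# extract_compiled_graph when the traced graph contains a fallback op.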
@contextmanager
def force_fallback_ctx_mgr(fallback_op):
    oldconfig = config.get_force_fallback()
    config.set_force_fallback(fallback_op)
    try:
        yield None
    finally:
        config.set_force_fallback(oldconfig)

@contextmanager
def nop_ctx_mgr():
    try:
        yield None
    finally:
        pass

def gen_rand_args(mod):
    args = []
    for _ in range(len(inspect.signature(mod.forward).parameters)):
        args.append(torch.randn(2, 3))
    return args

def allclose(expected, actual):
    def unwrap(cont):
        if isinstance(cont, (list, tuple)) and len(cont) == 1:
            return cont[0]
        return cont
    expected = unwrap(expected)
    actual = unwrap(actual)

    if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor):
        return torch.allclose(expected, actual)
    elif isinstance(expected, (tuple, list)) and isinstance(actual, (tuple, list)):
        return len(expected) == len(actual) and all(torch.allclose(a, b) for a, b in zip(expected, actual))
    else:
        raise RuntimeError("Unexpected types")

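# Trace `mod` with fx, extract a compiled graph via LTC, and check on `ncase`
# fresh random inputs that the compiled graph matches eager mode, both in the
# returned values and in any in-place updates applied to the arguments.
# If `exception_msg_pattern` is given, extraction is expected to fail with a
# matching RuntimeError instead.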
def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10):
    args = gen_rand_args(mod)
    out = mod(*args)

    dis.dis(mod.forward)

    try:
        optimized_mod = extract_compiled_graph(fx.symbolic_trace(mod), args)
    except RuntimeError as e:
        if exception_msg_pattern is None:
            raise e  # reraise the exception
        exception_message = str(e)
        if not re.search(exception_msg_pattern, exception_message):
            raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}")
        else:
            # We are done for the test case that expects an exception
            return

    if exception_msg_pattern is not None:
        raise RuntimeError(f"Expected an exception matching pattern {exception_msg_pattern}")
    print("return value of optimized_mod", optimized_mod(*args))

    # check correctness
    failed_index = []
    for i in range(ncase):
        rand_args = gen_rand_args(mod)
        rand_args_copy = copy.deepcopy(rand_args)
        expected = mod(*rand_args)
        actual = optimized_mod(*rand_args_copy)

        if not allclose(expected, actual):
            print(f"Incorrect results. expected {expected}, actual {actual}")
            failed_index.append(i)
            continue

        # make sure arguments match after calling the model forward method to handle
        # inplace updates.
        if not allclose(rand_args, rand_args_copy):
            print(f"Incorrect updated arguments. expected {rand_args}, actual {rand_args_copy}")
            failed_index.append(i)
            continue

    if len(failed_index) > 0:
        raise RuntimeError(f"Failed {len(failed_index)}/{ncase} cases")

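# Build a test method for OptimizeTest: instantiate `module_cls` and run
# verify_reusing_compiled_graph on it, optionally inside `ctxmgr` (e.g. the
# force-fallback context) and optionally expecting a matching exception.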
def maketest(module_cls, exception_msg_pattern=None, ctxmgr=None):
    def wrapper(self):
        nonlocal ctxmgr
        if not ctxmgr:
            ctxmgr = nop_ctx_mgr()
        with ctxmgr:
            verify_reusing_compiled_graph(module_cls(), exception_msg_pattern)

    return wrapper

class OptimizeTest(unittest.TestCase):
    test_sub = maketest(ModuleSub)
    # Same as test_sub but force aten::sub to fall back.
    # We expect an exception to be raised because of the LTC fallback.
    test_ltc_fallback = maketest(ModuleSub, exception_msg_pattern="fallback.*aten::sub", ctxmgr=force_fallback_ctx_mgr("aten::sub"))
    test_const_scale = maketest(ModuleConstScale)
    test_addcmul = maketest(ModuleAddcmul)
    test_return_multi = maketest(ModuleReturnMulti)
    test_return_dup_tensor = maketest(ModuleReturnDupTensor)
    test_inplace_update = maketest(ModuleInplaceUpdate)