fixing call_module on subscripting into generator (#81258)

named_modules() returns a generator, which is not subscriptable and causes the node-support query to fail
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81258
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
jjsjann123 2022-07-14 16:41:18 +00:00 committed by PyTorch MergeBot
parent dd73c97ea2
commit cc67a92e74
2 changed files with 33 additions and 1 deletions

View File

@@ -130,6 +130,38 @@ class TestFxNvFuserBackend(TestCase):
return inputs
@skipCUDAIfRocm
@dtypes(torch.float32)
def test_nvfuser_call_module_backend(self, device, dtype):
    """Check that NvFuserBackend.compile accepts a module traced to `call_module` nodes.

    Regression test: the partitioner's node-support query previously
    subscripted the generator returned by `named_modules()`, which raised
    on any graph containing a `call_module` node.
    """
    # Minimal module whose symbolic trace yields only `call_module` nodes
    # (BatchNorm2d and ReLU are kept as opaque submodule calls by FX).
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()

        def forward(self, inp):
            o = self.bn(inp)
            o = self.relu(o)
            return o

    # (N, C, H, W) input matching BatchNorm2d(3)'s expected 3 channels.
    inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
    m = Model().to(dtype=dtype, device=device)

    # note that the traced module here contains only `call_module` node,
    # which isn't fused by nvfuser backend. But `nvfuser.compile` should run without error
    traced = symbolic_trace(m)
    nvfuser = NvFuserBackend()
    compiled_module = nvfuser.compile(traced)

    # Compiled output must match eager execution within float32 tolerance.
    eager_result = m(inp)
    nvfuser_result = compiled_module(inp)
    torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)
@skipCUDAIfRocm
@dtypes(torch.float32)
def test_nvfuser_backend(self, device, dtype):

View File

@@ -90,7 +90,7 @@ class CapabilityBasedPartitioner:
logging.debug("Collecting supported nodes...")
supported_nodes = []
for node in self.graph_module.graph.nodes:
if self.operator_support.is_node_supported(self.graph_module.named_modules(), node):
if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node):
supported_nodes.append(node)
return supported_nodes