mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
fixing call_module on subscripting into generator (#81258)
named_modules() return a generator, which is not subscriptable and causes node support query to fail Pull Request resolved: https://github.com/pytorch/pytorch/pull/81258 Approved by: https://github.com/SherlockNoMad
This commit is contained in:
parent
dd73c97ea2
commit
cc67a92e74
|
|
@ -130,6 +130,38 @@ class TestFxNvFuserBackend(TestCase):
|
|||
return inputs
|
||||
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(torch.float32)
|
||||
def test_nvfuser_call_module_backend(self, device, dtype):
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.bn = torch.nn.BatchNorm2d(3)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
def forward(self, inp):
|
||||
o = self.bn(inp)
|
||||
o = self.relu(o)
|
||||
return o
|
||||
|
||||
inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
|
||||
m = Model().to(dtype=dtype, device=device)
|
||||
|
||||
# note that the traced module here contains only `call_module` node,
|
||||
# which isn't fused by nvfuser backend. But `nvfuser.compile` should run without error
|
||||
traced = symbolic_trace(m)
|
||||
|
||||
nvfuser = NvFuserBackend()
|
||||
compiled_module = nvfuser.compile(traced)
|
||||
|
||||
eager_result = m(inp)
|
||||
nvfuser_result = compiled_module(inp)
|
||||
|
||||
torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(torch.float32)
|
||||
def test_nvfuser_backend(self, device, dtype):
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class CapabilityBasedPartitioner:
|
|||
logging.debug("Collecting supported nodes...")
|
||||
supported_nodes = []
|
||||
for node in self.graph_module.graph.nodes:
|
||||
if self.operator_support.is_node_supported(self.graph_module.named_modules(), node):
|
||||
if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node):
|
||||
supported_nodes.append(node)
|
||||
return supported_nodes
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user