[dynamo] Graph break on pack_padded_sequence (#108096)

This is a workaround for #93501.

Fixes errors in:
```
./benchmarks/dynamo/torchbench.py --inference --performance --no-skip --inductor --freezing --only tacotron2
```
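
For illustration only (not part of the patch), a minimal sketch of the kind of call this graph break targets; the `encode` function, tensor shapes, and lengths below are made up:

```
import torch
from torch.nn.utils.rnn import pack_padded_sequence

def encode(padded, lengths):
    # With this change, dynamo graph-breaks on pack_padded_sequence, so the
    # call runs in eager mode instead of being traced into the graph.
    return pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

compiled = torch.compile(encode)
padded = torch.randn(4, 10, 8)           # (batch, max_seq_len, features)
lengths = torch.tensor([10, 7, 5, 2])    # valid length of each sequence
packed = compiled(padded, lengths)       # returns a PackedSequence
```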

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108096
Approved by: https://github.com/davidberard98
Authored by Jason Ansel on 2023-08-28 14:04:30 -07:00; committed by PyTorch MergeBot
parent d4ff06ec84
commit 73235d08c3
2 changed files with 4 additions and 1 deletion


```
@@ -655,6 +655,8 @@ class TorchVariable(VariableTracker):
                 return v
             return torch.utils._pytree.tree_map(map_fn, tree)
+        elif self.value is torch.nn.utils.rnn.pack_padded_sequence:
+            unimplemented("workaround https://github.com/pytorch/pytorch/issues/93501")
         elif isinstance(self.value, types.ModuleType):
             unimplemented("TypeError(\"'module' object is not callable\")")
         else:
```
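
`unimplemented(...)` raises dynamo's `Unsupported` exception, which the tracer turns into a graph break at this call rather than a hard compile error. As a purely illustrative user-level analogue (not what this patch does), a comparable effect can be requested for an arbitrary callable with `torch._dynamo.disallow_in_graph`, assuming that API is available in your build:

```
import torch
import torch._dynamo

# Illustration only: ask dynamo to graph-break whenever this callable is hit,
# roughly the effect the in-tree `unimplemented(...)` above has.
torch._dynamo.disallow_in_graph(torch.nn.utils.rnn.pack_padded_sequence)
```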


```
@@ -340,6 +340,7 @@ class TritonTemplateKernel(TritonKernel):
     def call_kernel(self, name: str):
         wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
+        call_args = [str(a) for a in call_args]
         for i in range(len(call_args)):
             if V.graph.is_unspec_arg(call_args[i]):
@@ -354,7 +355,7 @@ class TritonTemplateKernel(TritonKernel):
                 device_index=V.graph.scheduler.current_device.index,
             )
         else:
-            call_args = ", ".join(call_args)
+            call_args = ", ".join(call_args)  # type: ignore[assignment]
             stream_name = wrapper.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
```
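
The added `str(...)` mapping is presumably there because `python_argdefs()` can yield non-string entries (e.g. symbolic size expressions), and `", ".join` accepts only strings; the `# type: ignore[assignment]` is then needed because the `call_args` list variable is rebound to a `str`. A standalone sketch of the failure mode, with made-up argument values standing in for the real ones:

```
# Stand-in for mixed argument output; the int plays the role of a
# non-string entry such as a symbolic size.
call_args = ["arg0_1", "buf0", 16]

# ", ".join(call_args)  # TypeError: sequence item 2: expected str instance, int found

call_args = [str(a) for a in call_args]  # normalize every entry to str, as the patch does
print(", ".join(call_args))              # -> arg0_1, buf0, 16
```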