[dynamo] Graph break on pack_padded_sequence (#108096)
This works around #93501. It fixes errors in:

```
./benchmarks/dynamo/torchbench.py --inference --performance --no-skip --inductor --freezing --only tacotron2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108096
Approved by: https://github.com/davidberard98
This commit is contained in:
parent d4ff06ec84
commit 73235d08c3
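As context, a minimal sketch of the resulting behavior, assuming a PyTorch build that includes this change; the function, shapes, and lengths below are illustrative, not taken from the tacotron2 benchmark:

```python
# Minimal sketch; assumes a PyTorch build that includes this change.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

def fn(x, lengths):
    # dynamo now graph-breaks at this call, so pack_padded_sequence runs
    # eagerly instead of failing inside the traced graph (issue #93501).
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    return packed.data.sum()

x = torch.randn(4, 10, 8)              # (batch, time, features)
lengths = torch.tensor([10, 7, 5, 2])  # valid lengths per batch element

compiled = torch.compile(fn)
print(compiled(x, lengths))            # succeeds via the graph break
```

With `torch.compile(fn, fullgraph=True)` the same call would instead raise, since graph breaks are disallowed in that mode.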
```diff
@@ -655,6 +655,8 @@ class TorchVariable(VariableTracker):
                     return v

                 return torch.utils._pytree.tree_map(map_fn, tree)
+            elif self.value is torch.nn.utils.rnn.pack_padded_sequence:
+                unimplemented("workaround https://github.com/pytorch/pytorch/issues/93501")
             elif isinstance(self.value, types.ModuleType):
                 unimplemented("TypeError(\"'module' object is not callable\")")
             else:
```
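`unimplemented(...)` raises `torch._dynamo.exc.Unsupported`, which dynamo catches and turns into a graph break at the offending call. A rough sketch of that pattern, assuming the torch._dynamo internals of this era; `handle_call` is a hypothetical stand-in, not the real `TorchVariable.call_function`:

```python
# Hypothetical stand-in for the dispatch added above; only the use of
# unimplemented()/Unsupported reflects real dynamo behavior.
import torch
from torch._dynamo.exc import Unsupported, unimplemented

def handle_call(value):
    if value is torch.nn.utils.rnn.pack_padded_sequence:
        # Raises Unsupported; inside dynamo this is caught and converted
        # into a graph break rather than surfacing to the user.
        unimplemented("workaround https://github.com/pytorch/pytorch/issues/93501")

try:
    handle_call(torch.nn.utils.rnn.pack_padded_sequence)
except Unsupported as e:
    print("would graph break:", e)
```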
```diff
@@ -340,6 +340,7 @@ class TritonTemplateKernel(TritonKernel):
     def call_kernel(self, name: str):
         wrapper = V.graph.wrapper_code
         _, call_args, _ = self.args.python_argdefs()
+        call_args = [str(a) for a in call_args]

         for i in range(len(call_args)):
             if V.graph.is_unspec_arg(call_args[i]):
```
```diff
@@ -354,7 +355,7 @@ class TritonTemplateKernel(TritonKernel):
                 device_index=V.graph.scheduler.current_device.index,
             )
         else:
-            call_args = ", ".join(call_args)
+            call_args = ", ".join(call_args)  # type: ignore[assignment]
             stream_name = wrapper.write_get_cuda_stream(
                 V.graph.scheduler.current_device.index
             )
```
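The inductor hunks stringify `call_args` before the `is_unspec_arg` checks and then rebind the list to a joined string; the `# type: ignore[assignment]` silences mypy's complaint about that rebinding. A toy illustration, with invented values, of the assignment mypy flags:

```python
# Toy example only; the argument names are invented, not real inductor buffers.
from typing import List

call_args: List[str] = ["arg0_1", "buf0", "16"]

# Rebinding a List[str]-annotated name to a plain str is what mypy reports as
# "Incompatible types in assignment", hence the ignore in the hunk above.
call_args = ", ".join(call_args)  # type: ignore[assignment]
print(call_args)  # arg0_1, buf0, 16
```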