[dynamo] Special path for cloning of torch dispatch tensors (#164081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164081
Approved by: https://github.com/tugsbayasgalan, https://github.com/mlazos
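For context, "torch dispatch tensors" here refers to tensor subclasses that implement __torch_dispatch__ together with the __tensor_flatten__ / __tensor_unflatten__ protocol, which is what torch.utils._python_dispatch.is_traceable_wrapper_subclass() detects. Below is a minimal, hypothetical sketch of such a subclass; WrapperTensor is not part of this PR and is only meant to show the kind of tensor the new clone path targets.

```python
import torch
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


class WrapperTensor(torch.Tensor):
    """Hypothetical __torch_dispatch__ wrapper subclass (not part of this PR)."""

    @staticmethod
    def __new__(cls, elem):
        # The outer tensor mirrors the inner tensor's metadata but owns no storage.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    def __tensor_flatten__(self):
        # (names of inner-tensor attributes, opaque metadata)
        return ["elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return WrapperTensor(inner_tensors["elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap, run the op on the inner tensors, then rewrap tensor outputs.
        unwrap = lambda t: t.elem if isinstance(t, WrapperTensor) else t
        rewrap = lambda t: WrapperTensor(t) if isinstance(t, torch.Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(rewrap, out)


# The new branch in clone_input applies exactly to tensors like this one:
assert is_traceable_wrapper_subclass(WrapperTensor(torch.randn(2, 3)))
```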
This commit is contained in:
parent
ace89350fc
commit
bbf6816f35
@@ -87,6 +87,7 @@ from torch._utils_internal import (
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code
 from torch.monitor import _WaitCounter
 from torch.nn.modules.lazy import LazyModuleMixin
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils._triton import has_triton, has_triton_package
 from torch.utils.hooks import RemovableHandle
@@ -2146,6 +2147,10 @@ def clone_input(
                 x.shape,
                 layout=x.layout,
             )
+        elif is_traceable_wrapper_subclass(x):
+            # Questionable - but this is required to not fail executorch related
+            # torchao tests.
+            return torch_clone(x)

         needed_size = sum(
             (shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
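In effect, clone_input now short-circuits to a plain torch.clone (via the local torch_clone helper) for traceable wrapper subclasses instead of falling through to the stride-preserving re-allocation below; per the in-diff comment, this is needed so executorch-related torchao tests keep passing. A hedged usage sketch, reusing the hypothetical WrapperTensor from the earlier example:

```python
import torch
from torch._dynamo.utils import clone_input

x = WrapperTensor(torch.randn(4, 4))  # hypothetical subclass from the sketch above
y = clone_input(x)                    # takes the new torch_clone(x) branch
assert type(y) is WrapperTensor       # the subclass is preserved by torch.clone
```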