mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
builtins.getattr is not serializable, so we replace it with a custom op that has a more refined schema. Differential Revision: [D68899421](https://our.internmc.facebook.com/intern/diff/D68899421) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145772 Approved by: https://github.com/bdhirsh
24 lines · 742 B · Python
import torch
|
|
|
|
|
|
lib = torch.library.Library("export", "FRAGMENT") # noqa: TOR901
|
|
|
|
lib.define(
|
|
"access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
|
|
)
|
|
|
|
|
|
@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    """Return the tensor stored as attribute ``attr`` on a wrapper subclass.

    Kernel for ``export::access_subclass_inner_tensor``, registered on the
    Autograd dispatch key. It serves as a serializable replacement for calling
    ``builtins.getattr`` on a traceable wrapper tensor subclass.

    Args:
        src_subclass_tensor: The tensor to read from. It must satisfy
            ``is_traceable_wrapper_subclass``.
        attr: Name of the attribute that holds the inner tensor.

    Returns:
        The ``torch.Tensor`` found at ``getattr(src_subclass_tensor, attr)``.

    Raises:
        RuntimeError: If ``src_subclass_tensor`` is not a traceable wrapper
            subclass, or if ``attr`` is missing or is not a ``torch.Tensor``.
    """
    # Kept as a function-local import, as in the original.
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    # Explicit check rather than ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if not is_traceable_wrapper_subclass(src_subclass_tensor):
        raise RuntimeError(
            f"Expected a traceable wrapper subclass tensor, got {type(src_subclass_tensor)}"
        )

    # A missing attribute yields None, which the type check below rejects.
    val = getattr(src_subclass_tensor, attr, None)
    # None is never a torch.Tensor, so one isinstance test covers both the
    # "doesn't exist" and "not a tensor" cases.
    if not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val
|