from typing import Any import torch from ..decorators import substitute_in_graph @substitute_in_graph( # type: ignore[arg-type] torch.Tensor._make_subclass ) def make_subclass( cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any ) -> Any: with torch._C.DisableTorchFunctionSubclass(): # This is a rough approximation of `THPVariable_make_subclass`. It should # suffice for most of Dynamo tracing purposes. # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 assert len(kwargs) == 0, ( "_make_subclass only supports requires_grad as keyword arg" ) data = data.detach() # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. if data.requires_grad != requires_grad: data.requires_grad = requires_grad # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. if cls is torch.Tensor: return torch.Tensor(data) # Calling `as_subclass` because # 1. Dynamo knows how to handle it # 2. the C impls match at this point -- both `THPVariable_make_subclass` and # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. return data.as_subclass(cls) __all__ = [ "make_subclass", ]