[FX] Add requires_grad to TensorMetadata (#60972)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60972

requires_grad is needed for calculating a PyTorch model's memory requirements: output tensors with requires_grad set are saved in the module context and increase memory use during the forward pass.

Test Plan: existing test cases

Reviewed By: jamesr66a

Differential Revision: D29024932

fbshipit-source-id: def990f8c6ff6fa4537bfc377c646b9d44464ebd
parent ce232e7847
commit 287c0ab170
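To make the motivation above concrete, here is a minimal sketch (not part of this commit) of reading the new requires_grad field back after shape propagation to get a rough upper bound on activation memory held for the backward pass. ShapeProp and node.meta["tensor_meta"] are the existing torch.fx APIs; the TwoLayer module and the byte-counting loop are illustrative assumptions only.

import math

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class TwoLayer(torch.nn.Module):
    """Hypothetical toy model, used only to illustrate the memory estimate."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


gm = torch.fx.symbolic_trace(TwoLayer())
ShapeProp(gm).propagate(torch.randn(4, 16))  # populates node.meta["tensor_meta"]

# Tensors that require grad are kept alive for the backward pass, so summing
# their sizes gives a rough upper bound on activation memory.
saved_bytes = 0
for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")
    if tm is not None and tm.requires_grad:
        numel = math.prod(tm.shape)
        saved_bytes += numel * torch.empty((), dtype=tm.dtype).element_size()

print(f"approx. activation memory held for backward: {saved_bytes} bytes")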
@@ -211,6 +211,7 @@ def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
    weight_dict: Dict[str, Dict] = {name: {}}
    weight_dict[name]["dtype"] = str(tensor.dtype)
    weight_dict[name]["shape"] = serialize_shape(tensor.shape)
+   weight_dict[name]["requires_grad"] = str(tensor.requires_grad)
    weight_dict[name]["is_quantized"] = tensor.is_quantized
    weight_dict[name]["stride"] = serialize_stride(tensor.stride())
@@ -303,6 +304,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
        node_rep = {
            "shape": serialize_shape(tensor_meta.shape),
            "dtype": str(tensor_meta.dtype),
+           "requires_grad": str(tensor_meta.requires_grad),
            "stride": serialize_stride(tensor_meta.stride),
            "is_quantized": tensor_meta.is_quantized,
        }
@@ -164,6 +164,7 @@ class FxGraphDrawer:
            print("tm", tm)
            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\l"
            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\l"
+           result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\l"
            result += "|" + "stride" + "=" + str(tm.stride) + r"\l"
            if tm.is_quantized:
                if tm.qscheme in {
@@ -10,6 +10,7 @@ class TensorMetadata(NamedTuple):
    # General Tensor metadata
    shape : torch.Size
    dtype : torch.dtype
+   requires_grad : bool
    stride : Tuple[int]
    memory_format : Optional[torch.memory_format]
@@ -25,6 +26,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    """
    shape = result.shape
    dtype = result.dtype
+   requires_grad = result.requires_grad
    stride = result.stride()

    memory_formats = {
@@ -54,7 +56,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    return TensorMetadata(
-       shape, dtype, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)
+       shape, dtype, requires_grad, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)


class ShapeProp(torch.fx.Interpreter):
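For a quick end-to-end check of the new field, extract_tensor_metadata can also be called directly on a tensor. A minimal sketch, assuming the module path torch.fx.passes.shape_prop for the import:

import torch
from torch.fx.passes.shape_prop import extract_tensor_metadata

t = torch.randn(2, 3, requires_grad=True)
meta = extract_tensor_metadata(t)

# After this change, requires_grad sits between dtype and stride in the namedtuple.
print(meta.shape, meta.dtype, meta.requires_grad, meta.stride)
# torch.Size([2, 3]) torch.float32 True (3, 1)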