[FX] Add requires_grad to TensorMetadata (#60972)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60972

For PyTorch model memory requirement calculation, requires_grad is needed. Output tensors with requires_grad are saved in the module context and increase memory usage during the forward pass.

Test Plan: Existing test cases

Reviewed By: jamesr66a

Differential Revision: D29024932

fbshipit-source-id: def990f8c6ff6fa4537bfc377c646b9d44464ebd
This commit is contained in:
Malay Bag 2021-06-29 23:04:07 -07:00 committed by Facebook GitHub Bot
parent ce232e7847
commit 287c0ab170
3 changed files with 6 additions and 1 deletions

View File

@ -211,6 +211,7 @@ def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict:
weight_dict: Dict[str, Dict] = {name: {}}
weight_dict[name]["dtype"] = str(tensor.dtype)
weight_dict[name]["shape"] = serialize_shape(tensor.shape)
weight_dict[name]["requires_grad"] = str(tensor.requires_grad)
weight_dict[name]["is_quantized"] = tensor.is_quantized
weight_dict[name]["stride"] = serialize_stride(tensor.stride())
@ -303,6 +304,7 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
node_rep = {
"shape": serialize_shape(tensor_meta.shape),
"dtype": str(tensor_meta.dtype),
"requires_grad": str(tensor_meta.requires_grad),
"stride": serialize_stride(tensor_meta.stride),
"is_quantized": tensor_meta.is_quantized,
}

View File

@ -164,6 +164,7 @@ class FxGraphDrawer:
print("tm", tm)
result += "|" + "dtype" + "=" + str(tm.dtype) + r"\l"
result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\l"
result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\l"
result += "|" + "stride" + "=" + str(tm.stride) + r"\l"
if tm.is_quantized:
if tm.qscheme in {

View File

@ -10,6 +10,7 @@ class TensorMetadata(NamedTuple):
# General Tensor metadata
shape : torch.Size
dtype : torch.dtype
requires_grad : bool
stride : Tuple[int]
memory_format : Optional[torch.memory_format]
@ -25,6 +26,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
"""
shape = result.shape
dtype = result.dtype
requires_grad = result.requires_grad
stride = result.stride()
memory_formats = {
@ -54,7 +56,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
return TensorMetadata(
shape, dtype, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)
shape, dtype, requires_grad, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)
class ShapeProp(torch.fx.Interpreter):