From 1982ec2d22c5145e4ffeec08064c5ca17e969c25 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Fri, 13 Jun 2025 18:10:40 +0800 Subject: [PATCH] Add api info for torch._C._nn.pyi (#148405) APis involved are as followed: - adaptive_avg_pool2d - adaptive_avg_pool3d - binary_cross_entropy - col2im ISSUE Related: https://github.com/pytorch/pytorch/issues/148404 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148405 Approved by: https://github.com/ezyang --- tools/pyi/gen_pyi.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 78018a3e080..ea8f08a86fa 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -412,6 +412,16 @@ def gen_nn_functional(fm: FileManager) -> None: "tuple[Tensor, Tensor]", ) ], + f"adaptive_avg_pool{d}d": [ + defs( + f"adaptive_avg_pool{d}d", + [ + INPUT, + "output_size: _int | _size", + ], + "Tensor", + ) + ], } ) @@ -516,6 +526,31 @@ def gen_nn_functional(fm: FileManager) -> None: "Tensor", ) ], + "binary_cross_entropy": [ + defs( + "binary_cross_entropy", + [ + INPUT, + "target: Tensor", + "weight: Tensor | None = None", + "reduction: str = ...", + ], + "Tensor", + ) + ], + "col2im": [ + defs( + "col2im", + [ + INPUT, + "output_size: _int | _size", + KERNEL_SIZE, + "dilation: _int | _size", + *STRIDE_PADDING, + ], + "Tensor", + ) + ], } )