pytorch/torch/ao/pruning/scheduler/lambda_scheduler.py
Yuanyuan Chen a60d9e1f6d Fix flake8 B028 warnings (#166224)
This PR fixes flake8 B028 warning by specifying stacklevel=2 in `warnings.warn`. The advantage is that users can know more contextual information about PyTorch warnings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
2025-10-26 06:18:55 +00:00

66 lines
2.4 KiB
Python

import warnings
from collections.abc import Callable
from typing import Union
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
from .base_scheduler import BaseScheduler
__all__ = ["LambdaSL"]
class LambdaSL(BaseScheduler):
"""Sets the sparsity level of each parameter group to the final sl
times a given function. When last_epoch=-1, sets initial sl as zero.
Args:
sparsifier (BaseSparsifier): Wrapped sparsifier.
sl_lambda (function or list): A function which computes a multiplicative
factor given an integer parameter epoch, or a list of such
functions, one for each group in sparsifier.param_groups.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> # Assuming sparsifier has two groups.
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95**epoch
>>> # xdoctest: +SKIP
>>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""
def __init__(
self,
sparsifier: BaseSparsifier,
sl_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
last_epoch: int = -1,
verbose: bool = False,
) -> None:
self.sparsifier = sparsifier
if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
else:
if len(sl_lambda) != len(sparsifier.groups):
raise ValueError(
f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}"
)
self.sl_lambdas = list(sl_lambda)
super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call]
def get_sl(self) -> list[float]:
if not self._get_sl_called_within_step:
warnings.warn(
"To get the last sparsity level computed by the scheduler, "
"please use `get_last_sl()`.",
stacklevel=2,
)
return [
base_sl * lmbda(self.last_epoch)
for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)
]