[ao][pruning] Replace assert statements with AssertionError exceptions (#164926)

Replace assert statement with explicit ValueError exception to ensure the validation check is not removed when Python runs with optimization flag (-O).

This is a draft PR to confirm the process.

Fixes partially #164878.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164926
Approved by: https://github.com/fffrog, https://github.com/albanD

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
This commit is contained in:
Rohit Singh Rathaur 2025-10-29 17:46:42 +00:00 committed by PyTorch MergeBot
parent a3fe1825aa
commit fa560e1158
10 changed files with 83 additions and 52 deletions

View File

@ -128,13 +128,15 @@ class ActivationSparsifier:
# if features are not None, then feature_dim must not be None
features, feature_dim = args["features"], args["feature_dim"]
if features is not None:
assert feature_dim is not None, "need feature dim to select features"
if feature_dim is None:
raise AssertionError("need feature dim to select features")
# all the *_fns should be callable
fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"]
for key in fn_keys:
fn = args[key]
assert callable(fn), "function should be callable"
if not callable(fn):
raise AssertionError(f"{fn} must be callable")
def _aggregate_hook(self, name):
"""Returns hook that computes aggregate of activations passing through."""
@ -209,7 +211,8 @@ class ActivationSparsifier:
- All the functions (fn) passed as argument will be called at a dim, feature level.
"""
name = module_to_fqn(self.model, layer)
assert name is not None, "layer not found in the model" # satisfy mypy
if name is None:
raise AssertionError("layer not found in the model")
if name in self.data_groups: # unregister layer if already present
warnings.warn(
@ -261,14 +264,15 @@ class ActivationSparsifier:
Hence, if get_mask() is called before model.forward(), an
error will be raised.
"""
assert name is not None or layer is not None, (
"Need at least name or layer obj to retrieve mask"
)
if name is None and layer is None:
raise AssertionError("Need at least name or layer obj to retrieve mask")
if name is None:
assert layer is not None
if layer is None:
raise AssertionError("layer must be provided when name is None")
name = module_to_fqn(self.model, layer)
assert name is not None, "layer not found in the specified model"
if name is None:
raise AssertionError("layer not found in the specified model")
if name not in self.state:
raise ValueError("Error: layer with the given name not found")
@ -451,7 +455,8 @@ class ActivationSparsifier:
for name, config in self.data_groups.items():
# fetch layer
layer = fqn_to_module(self.model, name)
assert layer is not None # satisfy mypy
if layer is None:
raise AssertionError(f"layer {name} not found in the model")
# if agg_mode is True, then layer in aggregate mode
if "hook_state" in config and config["hook_state"] == "aggregate":

View File

@ -91,9 +91,10 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
4. By default, the config of the replaced data is used as config for the replacing data, unless something
is specified in the config dictionary.
"""
assert type(data) in SUPPORTED_TYPES, (
"specified data type not supported at the moment"
)
if type(data) not in SUPPORTED_TYPES:
raise AssertionError(
f"specified data type:{type(data)} not supported at the moment"
)
local_args = copy.deepcopy(self.defaults)
local_args.update(config)
weight = self._extract_weight(data)
@ -116,9 +117,10 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
if reuse_mask:
current_data = self.get_data(name=name)
assert weight.shape == current_data.shape, (
"to retain the old mask, the shape of the new data must be the same as the previous one"
)
if weight.shape != current_data.shape:
raise AssertionError(
"to retain the old mask, the shape of the new data must be the same as the previous one"
)
mask = self.get_mask(
name=name
) # reuse mask instead of creating a new one

View File

@ -47,7 +47,8 @@ class DataNormSparsifier(BaseDataSparsifier):
if zeros_per_block is None:
zeros_per_block = reduce(operator.mul, sparse_block_shape)
assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment"
if norm not in ["L1", "L2"]:
raise AssertionError("only L1 and L2 norm supported at the moment")
defaults = {
"sparsity_level": sparsity_level,

View File

@ -66,17 +66,20 @@ def post_training_sparse_quantize(
else:
embedding_modules = []
assert isinstance(select_embeddings, list), (
"the embedding_modules must be a list of embedding modules"
)
if not isinstance(select_embeddings, list):
raise AssertionError(
"the embedding_modules must be a list of embedding modules"
)
for emb in select_embeddings:
assert type(emb) in SUPPORTED_MODULES, (
"the embedding_modules list must be an embedding or embedding bags"
)
if type(emb) not in SUPPORTED_MODULES:
raise AssertionError(
"the embedding_modules list must be an embedding or embedding bags"
)
fqn_name = module_to_fqn(model, emb)
assert fqn_name is not None, (
"the embedding modules must be part of input model"
)
if fqn_name is None:
raise AssertionError(
"the embedding modules must be part of input model"
)
embedding_modules.append((fqn_name, emb))
if sparsify_first:
@ -114,7 +117,8 @@ def post_training_sparse_quantize(
for name, _ in embedding_modules:
quantized_emb = fqn_to_module(model, name)
assert quantized_emb is not None # satisfy mypy
if quantized_emb is None:
raise AssertionError(f"quantized embedding {name} not found in model")
quantized_weight = quantized_emb.weight() # type: ignore[operator]
quantize_params["scales"][name] = quantized_weight.q_per_channel_scales()
@ -138,7 +142,8 @@ def post_training_sparse_quantize(
for name, _ in embedding_modules:
quantized_emb = fqn_to_module(model, name)
assert quantized_emb is not None # satisfy mypy
if quantized_emb is None:
raise AssertionError(f"quantized embedding {name} not found in model")
requantized_vector = torch.quantize_per_channel(
quantize_params["dequant_weights"][name],
scales=quantize_params["scales"][name],

View File

@ -28,8 +28,12 @@ class FakeStructuredSparsity(nn.Module):
self.register_buffer("mask", mask)
def forward(self, x):
assert isinstance(self.mask, torch.Tensor)
assert self.mask.shape[0] == x.shape[0]
if not isinstance(self.mask, torch.Tensor):
raise AssertionError("mask must be a torch.Tensor")
if self.mask.shape[0] != x.shape[0]:
raise AssertionError(
f"mask shape[0] ({self.mask.shape[0]}) must match x shape[0] ({x.shape[0]})"
)
shape = [1] * len(x.shape)
shape[0] = -1
return self.mask.reshape(shape) * x

View File

@ -332,9 +332,10 @@ def prune_conv2d_pool_flatten_linear(
linear_ic = linear.weight.shape[1]
conv2d_oc = len(mask)
assert linear_ic % conv2d_oc == 0, (
f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"
)
if linear_ic % conv2d_oc != 0:
raise AssertionError(
f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"
)
flatten_scale = linear_ic // conv2d_oc
flattened_mask = torch.tensor(

View File

@ -23,7 +23,10 @@ class SaliencyPruner(BaseStructuredSparsifier):
"Structured pruning can only be applied to a 2+dim weight tensor!"
)
saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
assert saliency.shape == mask.shape
if saliency.shape != mask.shape:
raise AssertionError(
f"saliency shape ({saliency.shape}) must match mask shape ({mask.shape})"
)
num_to_pick = int(len(mask) * kwargs["sparsity_level"])
prune = saliency.topk(num_to_pick).indices

View File

@ -149,7 +149,8 @@ class BaseSparsifier(abc.ABC):
for _name, child in module.named_children():
if type(child) in SUPPORTED_MODULES:
module_fqn = module_to_fqn(model, child)
assert isinstance(module_fqn, str) # for mypy
if not isinstance(module_fqn, str):
raise AssertionError("module_fqn must be a string")
self.config.append({"tensor_fqn": module_fqn + ".weight"})
else:
stack.append(child)
@ -172,20 +173,23 @@ class BaseSparsifier(abc.ABC):
# TODO: Remove the configuration by reference ('module')
# pyrefly: ignore [not-iterable]
for module_config in self.config:
assert isinstance(module_config, dict), (
"config elements should be dicts not modules i.e.:"
"[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
)
if not isinstance(module_config, dict):
raise AssertionError(
"config elements should be dicts not modules i.e.:"
"[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
)
assert isinstance(self.defaults, dict) # for mypy
if not isinstance(self.defaults, dict):
raise AssertionError("defaults must be a dict")
local_args = copy.deepcopy(self.defaults)
local_args.update(module_config)
tensor_fqn = local_args.get("tensor_fqn", None)
assert tensor_fqn is not None, (
"tensor_fqn is a required argument in the sparsity config which"
"replaces previous `module` and [module]`fqn` arguments"
)
if tensor_fqn is None:
raise AssertionError(
"tensor_fqn is a required argument in the sparsity config which"
"replaces previous `module` and [module]`fqn` arguments"
)
# populate all information from tensor_fqn
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
@ -194,16 +198,17 @@ class BaseSparsifier(abc.ABC):
# from tensor_fqn
for key in info_from_tensor_fqn.keys():
if key in local_args:
assert (
if not (
info_from_tensor_fqn[key] == local_args[key]
or (
key == "tensor_fqn"
and "." + info_from_tensor_fqn[key] == local_args[key]
)
# info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
), (
f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
)
):
raise AssertionError(
f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
)
local_args.update(info_from_tensor_fqn)
self.groups.append(local_args)
self._prepare()

View File

@ -53,9 +53,10 @@ def swap_module(
# respect device affinity when swapping modules
# pyrefly: ignore [bad-argument-type]
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
assert len(devices) <= 1, (
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
if len(devices) > 1:
raise AssertionError(
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
if device:
new_mod.to(device)
@ -129,7 +130,10 @@ class FakeSparsity(nn.Module):
self.register_buffer("mask", mask)
def forward(self, x):
assert self.mask.shape == x.shape
if self.mask.shape != x.shape:
raise AssertionError(
f"mask shape ({self.mask.shape}) must match x shape ({x.shape})"
)
return self.mask * x
def state_dict(self, *args, **kwargs):

View File

@ -95,7 +95,8 @@ class WeightNormSparsifier(BaseSparsifier):
):
r"""Creates patches of size `block_shape` after scattering the indices."""
if mask is None:
assert input_shape is not None
if input_shape is None:
raise AssertionError("input_shape must be provided when mask is None")
mask = torch.ones(input_shape, device=device)
mask.scatter_(dim=dim, index=indices, value=0)
mask.data = F.fold(