[ao][pruning] Replace assert statements with AssertionError exceptions (#164926)
Replace assert statements with explicit AssertionError exceptions to ensure the validation checks are not removed when Python runs with the optimization flag (-O). This is a draft PR to confirm the process. Partially fixes #164878.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164926
Approved by: https://github.com/fffrog, https://github.com/albanD
Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
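Note on the mechanics (a minimal sketch, not part of the diff): under `python -O`, `__debug__` is False and `assert` statements are compiled away entirely, so assert-based validation silently vanishes; an ordinary `if`/`raise` survives regardless of optimization flags.

def check_with_assert(feature_dim):
    # Compiled away entirely when Python runs with -O / -OO.
    assert feature_dim is not None, "need feature dim to select features"


def check_with_raise(feature_dim):
    # An ordinary statement: survives -O unchanged.
    if feature_dim is None:
        raise AssertionError("need feature dim to select features")


try:
    check_with_raise(None)
except AssertionError as e:
    print(f"raised with or without -O: {e}")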
This commit is contained in:
parent a3fe1825aa
commit fa560e1158
@@ -128,13 +128,15 @@ class ActivationSparsifier:
         # if features are not None, then feature_dim must not be None
         features, feature_dim = args["features"], args["feature_dim"]
         if features is not None:
-            assert feature_dim is not None, "need feature dim to select features"
+            if feature_dim is None:
+                raise AssertionError("need feature dim to select features")

         # all the *_fns should be callable
         fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"]
         for key in fn_keys:
             fn = args[key]
-            assert callable(fn), "function should be callable"
+            if not callable(fn):
+                raise AssertionError(f"{fn} must be callable")

     def _aggregate_hook(self, name):
         """Returns hook that computes aggregate of activations passing through."""
@@ -209,7 +211,8 @@ class ActivationSparsifier:
         - All the functions (fn) passed as argument will be called at a dim, feature level.
         """
         name = module_to_fqn(self.model, layer)
-        assert name is not None, "layer not found in the model"  # satisfy mypy
+        if name is None:
+            raise AssertionError("layer not found in the model")

         if name in self.data_groups:  # unregister layer if already present
             warnings.warn(
@@ -261,14 +264,15 @@ class ActivationSparsifier:
             Hence, if get_mask() is called before model.forward(), an
             error will be raised.
         """
-        assert name is not None or layer is not None, (
-            "Need at least name or layer obj to retrieve mask"
-        )
+        if name is None and layer is None:
+            raise AssertionError("Need at least name or layer obj to retrieve mask")

         if name is None:
-            assert layer is not None
+            if layer is None:
+                raise AssertionError("layer must be provided when name is None")
             name = module_to_fqn(self.model, layer)
-            assert name is not None, "layer not found in the specified model"
+            if name is None:
+                raise AssertionError("layer not found in the specified model")

         if name not in self.state:
             raise ValueError("Error: layer with the given name not found")
@@ -451,7 +455,8 @@ class ActivationSparsifier:
         for name, config in self.data_groups.items():
             # fetch layer
             layer = fqn_to_module(self.model, name)
-            assert layer is not None  # satisfy mypy
+            if layer is None:
+                raise AssertionError(f"layer {name} not found in the model")

             # if agg_mode is True, then layer in aggregate mode
             if "hook_state" in config and config["hook_state"] == "aggregate":
@@ -91,9 +91,10 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
         4. By default, the config of the replaced data is used as config for the replacing data, unless something
            is specified in the config dictionary.
         """
-        assert type(data) in SUPPORTED_TYPES, (
-            "specified data type not supported at the moment"
-        )
+        if type(data) not in SUPPORTED_TYPES:
+            raise AssertionError(
+                f"specified data type:{type(data)} not supported at the moment"
+            )
         local_args = copy.deepcopy(self.defaults)
         local_args.update(config)
         weight = self._extract_weight(data)
@@ -116,9 +117,10 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):

         if reuse_mask:
             current_data = self.get_data(name=name)
-            assert weight.shape == current_data.shape, (
-                "to retain the old mask, the shape of the new data must be the same as the previous one"
-            )
+            if weight.shape != current_data.shape:
+                raise AssertionError(
+                    "to retain the old mask, the shape of the new data must be the same as the previous one"
+                )
             mask = self.get_mask(
                 name=name
             )  # reuse mask instead of creating a new one
@@ -47,7 +47,8 @@ class DataNormSparsifier(BaseDataSparsifier):
         if zeros_per_block is None:
             zeros_per_block = reduce(operator.mul, sparse_block_shape)

-        assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment"
+        if norm not in ["L1", "L2"]:
+            raise AssertionError("only L1 and L2 norm supported at the moment")

         defaults = {
             "sparsity_level": sparsity_level,
@@ -66,17 +66,20 @@ def post_training_sparse_quantize(

     else:
         embedding_modules = []
-        assert isinstance(select_embeddings, list), (
-            "the embedding_modules must be a list of embedding modules"
-        )
+        if not isinstance(select_embeddings, list):
+            raise AssertionError(
+                "the embedding_modules must be a list of embedding modules"
+            )
         for emb in select_embeddings:
-            assert type(emb) in SUPPORTED_MODULES, (
-                "the embedding_modules list must be an embedding or embedding bags"
-            )
+            if type(emb) not in SUPPORTED_MODULES:
+                raise AssertionError(
+                    "the embedding_modules list must be an embedding or embedding bags"
+                )
             fqn_name = module_to_fqn(model, emb)
-            assert fqn_name is not None, (
-                "the embedding modules must be part of input model"
-            )
+            if fqn_name is None:
+                raise AssertionError(
+                    "the embedding modules must be part of input model"
+                )
             embedding_modules.append((fqn_name, emb))

     if sparsify_first:
@@ -114,7 +117,8 @@ def post_training_sparse_quantize(

         for name, _ in embedding_modules:
             quantized_emb = fqn_to_module(model, name)
-            assert quantized_emb is not None  # satisfy mypy
+            if quantized_emb is None:
+                raise AssertionError(f"quantized embedding {name} not found in model")

             quantized_weight = quantized_emb.weight()  # type: ignore[operator]
             quantize_params["scales"][name] = quantized_weight.q_per_channel_scales()
@@ -138,7 +142,8 @@ def post_training_sparse_quantize(

         for name, _ in embedding_modules:
             quantized_emb = fqn_to_module(model, name)
-            assert quantized_emb is not None  # satisfy mypy
+            if quantized_emb is None:
+                raise AssertionError(f"quantized embedding {name} not found in model")
             requantized_vector = torch.quantize_per_channel(
                 quantize_params["dequant_weights"][name],
                 scales=quantize_params["scales"][name],
@@ -28,8 +28,12 @@ class FakeStructuredSparsity(nn.Module):
         self.register_buffer("mask", mask)

     def forward(self, x):
-        assert isinstance(self.mask, torch.Tensor)
-        assert self.mask.shape[0] == x.shape[0]
+        if not isinstance(self.mask, torch.Tensor):
+            raise AssertionError("mask must be a torch.Tensor")
+        if self.mask.shape[0] != x.shape[0]:
+            raise AssertionError(
+                f"mask shape[0] ({self.mask.shape[0]}) must match x shape[0] ({x.shape[0]})"
+            )
         shape = [1] * len(x.shape)
         shape[0] = -1
         return self.mask.reshape(shape) * x
@@ -332,9 +332,10 @@ def prune_conv2d_pool_flatten_linear(
     linear_ic = linear.weight.shape[1]

     conv2d_oc = len(mask)
-    assert linear_ic % conv2d_oc == 0, (
-        f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"
-    )
+    if linear_ic % conv2d_oc != 0:
+        raise AssertionError(
+            f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"
+        )

     flatten_scale = linear_ic // conv2d_oc
     flattened_mask = torch.tensor(
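A worked example of the divisibility requirement above (values invented; `repeat_interleave` stands in for the mask expansion the real function performs): with 8 conv output channels flattened into a linear layer of 32 inputs, each channel owns linear_ic // conv2d_oc = 4 consecutive features.

import torch

conv2d_oc, linear_ic = 8, 32
flatten_scale = linear_ic // conv2d_oc  # 4 features per channel
channel_mask = torch.tensor([True, False] * 4)  # len 8, one flag per channel
flattened_mask = channel_mask.repeat_interleave(flatten_scale)  # len 32
print(int(flattened_mask.sum()))  # 16 linear inputs kept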
@@ -23,7 +23,10 @@ class SaliencyPruner(BaseStructuredSparsifier):
                 "Structured pruning can only be applied to a 2+dim weight tensor!"
             )
         saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
-        assert saliency.shape == mask.shape
+        if saliency.shape != mask.shape:
+            raise AssertionError(
+                f"saliency shape ({saliency.shape}) must match mask shape ({mask.shape})"
+            )

         num_to_pick = int(len(mask) * kwargs["sparsity_level"])
         prune = saliency.topk(num_to_pick).indices
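For context, a standalone sketch of the saliency computation whose shapes are checked above (weights invented): per-row L1 norms are negated so that topk picks the smallest-norm rows, i.e. the rows to prune.

import torch

weights = torch.randn(6, 5)
saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)  # shape (6,)
num_to_pick = int(len(saliency) * 0.5)  # sparsity_level = 0.5
prune = saliency.topk(num_to_pick).indices
print(prune)  # indices of the 3 rows with the smallest L1 norm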
@@ -149,7 +149,8 @@ class BaseSparsifier(abc.ABC):
             for _name, child in module.named_children():
                 if type(child) in SUPPORTED_MODULES:
                     module_fqn = module_to_fqn(model, child)
-                    assert isinstance(module_fqn, str)  # for mypy
+                    if not isinstance(module_fqn, str):
+                        raise AssertionError("module_fqn must be a string")
                     self.config.append({"tensor_fqn": module_fqn + ".weight"})
                 else:
                     stack.append(child)
@@ -172,20 +173,23 @@ class BaseSparsifier(abc.ABC):
         # TODO: Remove the configuration by reference ('module')
         # pyrefly: ignore [not-iterable]
         for module_config in self.config:
-            assert isinstance(module_config, dict), (
-                "config elements should be dicts not modules i.e.:"
-                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
-            )
+            if not isinstance(module_config, dict):
+                raise AssertionError(
+                    "config elements should be dicts not modules i.e.:"
+                    "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
+                )

-            assert isinstance(self.defaults, dict)  # for mypy
+            if not isinstance(self.defaults, dict):
+                raise AssertionError("defaults must be a dict")
             local_args = copy.deepcopy(self.defaults)
             local_args.update(module_config)

             tensor_fqn = local_args.get("tensor_fqn", None)
-            assert tensor_fqn is not None, (
-                "tensor_fqn is a required argument in the sparsity config which"
-                "replaces previous `module` and [module]`fqn` arguments"
-            )
+            if tensor_fqn is None:
+                raise AssertionError(
+                    "tensor_fqn is a required argument in the sparsity config which"
+                    "replaces previous `module` and [module]`fqn` arguments"
+                )

             # populate all information from tensor_fqn
             info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
@@ -194,16 +198,17 @@ class BaseSparsifier(abc.ABC):
             # from tensor_fqn
             for key in info_from_tensor_fqn.keys():
                 if key in local_args:
-                    assert (
-                        info_from_tensor_fqn[key] == local_args[key]
-                        or (
-                            key == "tensor_fqn"
-                            and "." + info_from_tensor_fqn[key] == local_args[key]
-                        )
-                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
-                    ), (
-                        f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
-                    )
+                    if not (
+                        info_from_tensor_fqn[key] == local_args[key]
+                        or (
+                            key == "tensor_fqn"
+                            and "." + info_from_tensor_fqn[key] == local_args[key]
+                        )
+                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
+                    ):
+                        raise AssertionError(
+                            f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
+                        )
             local_args.update(info_from_tensor_fqn)
         self.groups.append(local_args)
         self._prepare()
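A minimal sketch of the config shape these checks enforce (model and values invented): each element is a dict carrying `tensor_fqn`, with optional per-tensor overrides that must agree with what the fqn itself implies.

import torch.nn as nn
from torch.ao.pruning import WeightNormSparsifier

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
config = [
    {"tensor_fqn": "0.weight"},                         # defaults apply
    {"tensor_fqn": "1.weight", "sparsity_level": 0.7},  # per-tensor override
]
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config)  # raises AssertionError on malformed entries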
@@ -53,9 +53,10 @@ def swap_module(
         # respect device affinity when swapping modules
         # pyrefly: ignore [bad-argument-type]
         devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
-        assert len(devices) <= 1, (
-            f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
-        )
+        if len(devices) > 1:
+            raise AssertionError(
+                f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
+            )
         device = next(iter(devices)) if len(devices) > 0 else None
         if device:
             new_mod.to(device)
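A quick sketch of the device-affinity idiom in the hunk above (module invented): every parameter and buffer device is collected into a set; more than one distinct device means the module cannot be swapped safely.

from itertools import chain

import torch.nn as nn

mod = nn.Linear(4, 4)
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
print(devices)  # {device(type='cpu')} -> single device, safe to swap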
@@ -129,7 +130,10 @@ class FakeSparsity(nn.Module):
         self.register_buffer("mask", mask)

     def forward(self, x):
-        assert self.mask.shape == x.shape
+        if self.mask.shape != x.shape:
+            raise AssertionError(
+                f"mask shape ({self.mask.shape}) must match x shape ({x.shape})"
+            )
         return self.mask * x

     def state_dict(self, *args, **kwargs):
@@ -95,7 +95,8 @@ class WeightNormSparsifier(BaseSparsifier):
     ):
         r"""Creates patches of size `block_shape` after scattering the indices."""
         if mask is None:
-            assert input_shape is not None
+            if input_shape is None:
+                raise AssertionError("input_shape must be provided when mask is None")
             mask = torch.ones(input_shape, device=device)
         mask.scatter_(dim=dim, index=indices, value=0)
         mask.data = F.fold(
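A small sketch of the scatter step guarded by the new check (shapes invented): starting from an all-ones mask and zeroing the entries named by `indices` along `dim`.

import torch

mask = torch.ones(2, 6)
indices = torch.tensor([[0, 3], [1, 4]])
mask.scatter_(dim=1, index=indices, value=0)
print(mask)  # zeros at columns (0, 3) of row 0 and (1, 4) of row 1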