pytorch/torch/utils/benchmark/utils/fuzzer.py
Yuanyuan Chen 3cda34ebde [2/N] Apply ruff UP035 check in torch files (#164054)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
`TypeAlias` is also imported from `typing`.
This PR is the follow-up of #163947.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-09-29 03:35:32 +00:00

464 lines
18 KiB
Python

# mypy: allow-untyped-defs
import functools
import itertools as it
from typing import Any, Optional, Union
from collections.abc import Callable
import torch
__all__ = [
"Fuzzer",
"FuzzedParameter", "ParameterAlias",
"FuzzedTensor",
]
_DISTRIBUTIONS = (
"loguniform",
"uniform",
)
class FuzzedParameter:
"""Specification for a parameter to be generated during fuzzing."""
def __init__(
self,
name: str,
minval: Optional[Union[int, float]] = None,
maxval: Optional[Union[int, float]] = None,
distribution: Optional[Union[str, dict[Any, float]]] = None,
strict: bool = False,
):
"""
Args:
name:
A string name with which to identify the parameter.
FuzzedTensors can reference this string in their
specifications.
minval:
The lower bound for the generated value. See the description
of `distribution` for type behavior.
maxval:
The upper bound for the generated value. Type behavior is
identical to `minval`.
distribution:
Specifies the distribution from which this parameter should
be drawn. There are three possibilities:
- "loguniform"
Samples between `minval` and `maxval` (inclusive) such
that the probabilities are uniform in log space. As a
concrete example, if minval=1 and maxval=100, a sample
is as likely to fall in [1, 10) as it is [10, 100].
- "uniform"
Samples are chosen with uniform probability between
`minval` and `maxval` (inclusive). If either `minval`
or `maxval` is a float then the distribution is the
continuous uniform distribution; otherwise samples
are constrained to the integers.
- dict:
If a dict is passed, the keys are taken to be choices
for the variables and the values are interpreted as
probabilities. (And must sum to one.)
If a dict is passed, `minval` and `maxval` must not be set.
Otherwise, they must be set.
strict:
If a parameter is strict, it will not be included in the
iterative resampling process which Fuzzer uses to find a
valid parameter configuration. This allows an author to
prevent skew from resampling for a given parameter (for
instance, a low size limit could inadvertently bias towards
Tensors with fewer dimensions) at the cost of more iterations
when generating parameters.
"""
self._name = name
self._minval = minval
self._maxval = maxval
self._distribution = self._check_distribution(distribution)
self.strict = strict
@property
def name(self):
return self._name
def sample(self, state):
if self._distribution == "loguniform":
return self._loguniform(state)
if self._distribution == "uniform":
return self._uniform(state)
if isinstance(self._distribution, dict):
return self._custom_distribution(state)
def _check_distribution(self, distribution):
if not isinstance(distribution, dict):
assert distribution in _DISTRIBUTIONS
else:
assert not any(i < 0 for i in distribution.values()), "Probabilities cannot be negative"
assert abs(sum(distribution.values()) - 1) <= 1e-5, "Distribution is not normalized"
assert self._minval is None
assert self._maxval is None
return distribution
def _loguniform(self, state):
import numpy as np
output = int(2 ** state.uniform(
low=np.log2(self._minval) if self._minval is not None else None,
high=np.log2(self._maxval) if self._maxval is not None else None,
))
if self._minval is not None and output < self._minval:
return self._minval
if self._maxval is not None and output > self._maxval:
return self._maxval
return output
def _uniform(self, state):
if isinstance(self._minval, int) and isinstance(self._maxval, int):
return int(state.randint(low=self._minval, high=self._maxval + 1))
return state.uniform(low=self._minval, high=self._maxval)
def _custom_distribution(self, state):
import numpy as np
# If we directly pass the keys to `choice`, numpy will convert
# them to numpy dtypes.
index = state.choice(
np.arange(len(self._distribution)),
p=tuple(self._distribution.values()))
return list(self._distribution.keys())[index]
class ParameterAlias:
"""Indicates that a parameter should alias the value of another parameter.
When used in conjunction with a custom distribution, this allows fuzzed
tensors to represent a broader range of behaviors. For example, the
following sometimes produces Tensors which broadcast:
Fuzzer(
parameters=[
FuzzedParameter("x_len", 4, 1024, distribution="uniform"),
# `y` will either be size one, or match the size of `x`.
FuzzedParameter("y_len", distribution={
0.5: 1,
0.5: ParameterAlias("x_len")
}),
],
tensors=[
FuzzedTensor("x", size=("x_len",)),
FuzzedTensor("y", size=("y_len",)),
],
)
Chains of alias' are allowed, but may not contain cycles.
"""
def __init__(self, alias_to):
self.alias_to = alias_to
def __repr__(self):
return f"ParameterAlias[alias_to: {self.alias_to}]"
def dtype_size(dtype):
if dtype == torch.bool:
return 1
if dtype.is_floating_point or dtype.is_complex:
return int(torch.finfo(dtype).bits / 8)
return int(torch.iinfo(dtype).bits / 8)
def prod(values, base=1):
"""np.prod can overflow, so for sizes the product should be done in Python.
Even though np.prod type promotes to int64, it can still overflow in which
case the negative value will pass the size check and OOM when attempting to
actually allocate the Tensor.
"""
return functools.reduce(lambda x, y: int(x) * int(y), values, base)
class FuzzedTensor:
def __init__(
self,
name: str,
size: tuple[Union[str, int], ...],
steps: Optional[tuple[Union[str, int], ...]] = None,
probability_contiguous: float = 0.5,
min_elements: Optional[int] = None,
max_elements: Optional[int] = None,
max_allocation_bytes: Optional[int] = None,
dim_parameter: Optional[str] = None,
roll_parameter: Optional[str] = None,
dtype=torch.float32,
cuda=False,
tensor_constructor: Optional[Callable] = None
):
"""
Args:
name:
A string identifier for the generated Tensor.
size:
A tuple of integers or strings specifying the size of the generated
Tensor. String values will replaced with a concrete int during the
generation process, while ints are simply passed as literals.
steps:
An optional tuple with the same length as `size`. This indicates
that a larger Tensor should be allocated, and then sliced to
produce the generated Tensor. For instance, if size is (4, 8)
and steps is (1, 4), then a tensor `t` of size (4, 32) will be
created and then `t[:, ::4]` will be used. (Allowing one to test
Tensors with strided memory.)
probability_contiguous:
A number between zero and one representing the chance that the
generated Tensor has a contiguous memory layout. This is achieved by
randomly permuting the shape of a Tensor, calling `.contiguous()`,
and then permuting back. This is applied before `steps`, which can
also cause a Tensor to be non-contiguous.
min_elements:
The minimum number of parameters that this Tensor must have for a
set of parameters to be valid. (Otherwise they are resampled.)
max_elements:
Like `min_elements`, but setting an upper bound.
max_allocation_bytes:
Like `max_elements`, but for the size of Tensor that must be
allocated prior to slicing for `steps` (if applicable). For
example, a FloatTensor with size (1024, 1024) and steps (4, 4)
would have 1M elements, but would require a 64 MB allocation.
dim_parameter:
The length of `size` and `steps` will be truncated to this value.
This allows Tensors of varying dimensions to be generated by the
Fuzzer.
dtype:
The PyTorch dtype of the generated Tensor.
cuda:
Whether to place the Tensor on a GPU.
tensor_constructor:
Callable which will be used instead of the default Tensor
construction method. This allows the author to enforce properties
of the Tensor (e.g. it can only have certain values). The dtype and
concrete shape of the Tensor to be created will be passed, and
concrete values of all parameters will be passed as kwargs. Note
that transformations to the result (permuting, slicing) will be
performed by the Fuzzer; the tensor_constructor is only responsible
for creating an appropriately sized Tensor.
"""
self._name = name
self._size = size
self._steps = steps
self._probability_contiguous = probability_contiguous
self._min_elements = min_elements
self._max_elements = max_elements
self._max_allocation_bytes = max_allocation_bytes
self._dim_parameter = dim_parameter
self._dtype = dtype
self._cuda = cuda
self._tensor_constructor = tensor_constructor
@property
def name(self):
return self._name
@staticmethod
def default_tensor_constructor(size, dtype, **kwargs):
if dtype.is_floating_point or dtype.is_complex:
return torch.rand(size=size, dtype=dtype, device="cpu")
else:
return torch.randint(1, 127, size=size, dtype=dtype, device="cpu")
def _make_tensor(self, params, state):
import numpy as np
size, steps, allocation_size = self._get_size_and_steps(params)
constructor = (
self._tensor_constructor or
self.default_tensor_constructor
)
raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params)
if self._cuda:
raw_tensor = raw_tensor.cuda()
# Randomly permute the Tensor and call `.contiguous()` to force re-ordering
# of the memory, and then permute it back to the original shape.
dim = len(size)
order = np.arange(dim)
if state.rand() > self._probability_contiguous:
while dim > 1 and np.all(order == np.arange(dim)):
order = state.permutation(raw_tensor.dim())
raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))
slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
tensor = raw_tensor[tuple(slices)]
properties = {
"numel": int(tensor.numel()),
"order": order,
"steps": steps,
"is_contiguous": tensor.is_contiguous(),
"dtype": str(self._dtype),
}
return tensor, properties
def _get_size_and_steps(self, params):
dim = (
params[self._dim_parameter]
if self._dim_parameter is not None
else len(self._size)
)
def resolve(values, dim):
"""Resolve values into concrete integers."""
values = tuple(params.get(i, i) for i in values)
if len(values) > dim:
values = values[:dim]
if len(values) < dim:
values = values + tuple(1 for _ in range(dim - len(values)))
return values
size = resolve(self._size, dim)
steps = resolve(self._steps or (), dim)
allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps))
return size, steps, allocation_size
def satisfies_constraints(self, params):
size, _, allocation_size = self._get_size_and_steps(params)
# Product is computed in Python to avoid integer overflow.
num_elements = prod(size)
assert num_elements >= 0
allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype))
def nullable_greater(left, right):
if left is None or right is None:
return False
return left > right
return not any((
nullable_greater(num_elements, self._max_elements),
nullable_greater(self._min_elements, num_elements),
nullable_greater(allocation_bytes, self._max_allocation_bytes),
))
class Fuzzer:
def __init__(
self,
parameters: list[Union[FuzzedParameter, list[FuzzedParameter]]],
tensors: list[Union[FuzzedTensor, list[FuzzedTensor]]],
constraints: Optional[list[Callable]] = None,
seed: Optional[int] = None
):
"""
Args:
parameters:
List of FuzzedParameters which provide specifications
for generated parameters. Iterable elements will be
unpacked, though arbitrary nested structures will not.
tensors:
List of FuzzedTensors which define the Tensors which
will be created each step based on the parameters for
that step. Iterable elements will be unpacked, though
arbitrary nested structures will not.
constraints:
List of callables. They will be called with params
as kwargs, and if any of them return False the current
set of parameters will be rejected.
seed:
Seed for the RandomState used by the Fuzzer. This will
also be used to set the PyTorch random seed so that random
ops will create reproducible Tensors.
"""
import numpy as np
if seed is None:
seed = int(np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64))
self._seed = seed
self._parameters = Fuzzer._unpack(parameters, FuzzedParameter)
self._tensors = Fuzzer._unpack(tensors, FuzzedTensor)
self._constraints = constraints or ()
p_names = {p.name for p in self._parameters}
t_names = {t.name for t in self._tensors}
name_overlap = p_names.intersection(t_names)
if name_overlap:
raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}")
self._rejections = 0
self._total_generated = 0
@staticmethod
def _unpack(values, cls):
return tuple(it.chain.from_iterable(
[[i] if isinstance(i, cls) else i for i in values]
))
def take(self, n):
import numpy as np
state = np.random.RandomState(self._seed)
torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64))
for _ in range(n):
params = self._generate(state)
tensors = {}
tensor_properties = {}
for t in self._tensors:
tensor, properties = t._make_tensor(params, state)
tensors[t.name] = tensor
tensor_properties[t.name] = properties
yield tensors, tensor_properties, params
@property
def rejection_rate(self):
if not self._total_generated:
return 0.
return self._rejections / self._total_generated
def _generate(self, state):
strict_params: dict[str, Union[float, int, ParameterAlias]] = {}
for _ in range(1000):
candidate_params: dict[str, Union[float, int, ParameterAlias]] = {}
for p in self._parameters:
if p.strict:
if p.name in strict_params:
candidate_params[p.name] = strict_params[p.name]
else:
candidate_params[p.name] = p.sample(state)
strict_params[p.name] = candidate_params[p.name]
else:
candidate_params[p.name] = p.sample(state)
candidate_params = self._resolve_aliases(candidate_params)
self._total_generated += 1
if not all(f(candidate_params) for f in self._constraints):
self._rejections += 1
continue
if not all(t.satisfies_constraints(candidate_params) for t in self._tensors):
self._rejections += 1
continue
return candidate_params
raise ValueError("Failed to generate a set of valid parameters.")
@staticmethod
def _resolve_aliases(params):
params = dict(params)
alias_count = sum(isinstance(v, ParameterAlias) for v in params.values())
keys = list(params.keys())
while alias_count:
for k in keys:
v = params[k]
if isinstance(v, ParameterAlias):
params[k] = params[v.alias_to]
alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values())
if alias_count == alias_count_new:
raise ValueError(f"ParameterAlias cycle detected\n{params}")
alias_count = alias_count_new
return params