pytorch/tools/experimental/torchfuzz/operators/registry.py

206 lines
6.5 KiB
Python

"""Operator registry for mapping operation names to operator instances."""
from typing import Optional
from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.layout import (
FlattenOperator,
ReshapeOperator,
SqueezeOperator,
UnsqueezeOperator,
ViewOperator,
)
from torchfuzz.operators.masked_select import MaskedSelectOperator
from torchfuzz.operators.matrix_multiply import (
AddmmOperator,
BmmOperator,
MatmulOperator,
MMOperator,
)
from torchfuzz.operators.nn_functional import (
BatchNormOperator,
DropoutOperator,
ELUOperator,
EmbeddingOperator,
GELUOperator,
GroupNormOperator,
LayerNormOperator,
LeakyReLUOperator,
LinearOperator,
ReLUOperator,
RMSNormOperator,
SigmoidOperator,
SiLUOperator,
SoftmaxOperator,
TanhOperator,
)
from torchfuzz.operators.nonzero import NonzeroOperator
from torchfuzz.operators.scalar_pointwise import (
ScalarAddOperator,
ScalarDivOperator,
ScalarMulOperator,
ScalarSubOperator,
)
from torchfuzz.operators.tensor_pointwise import (
AddOperator,
DivOperator,
MulOperator,
SubOperator,
)
from torchfuzz.operators.unique import UniqueOperator
class OperatorRegistry:
    """Name-keyed collection of operator instances.

    Each operator is stored under its ``name`` attribute. A newly created
    registry is pre-populated with the default operator set.
    """

    def __init__(self):
        """Create the backing mapping and load the default operators."""
        self._operators: dict[str, Operator] = {}
        self._register_default_operators()

    def _register_default_operators(self):
        """Instantiate and register every default operator, in order."""
        default_operator_classes = (
            # Individual tensor pointwise operators (preferred)
            AddOperator,
            MulOperator,
            SubOperator,
            DivOperator,
            # Individual scalar pointwise operators (preferred)
            ScalarAddOperator,
            ScalarMulOperator,
            ScalarSubOperator,
            ScalarDivOperator,
            # Leaf input operators
            ConstantOperator,
            ArgOperator,
            # Data-dependent operators
            NonzeroOperator,
            MaskedSelectOperator,
            ItemOperator,
            UniqueOperator,
            # Tensor layout operators
            ViewOperator,
            ReshapeOperator,
            FlattenOperator,
            SqueezeOperator,
            UnsqueezeOperator,
            # Matrix multiplication operators
            MMOperator,
            AddmmOperator,
            BmmOperator,
            MatmulOperator,
            # Neural network functional operators
            EmbeddingOperator,
            LinearOperator,
            # Activation functions
            ReLUOperator,
            LeakyReLUOperator,
            ELUOperator,
            GELUOperator,
            SiLUOperator,
            SigmoidOperator,
            TanhOperator,
            SoftmaxOperator,
            # Normalization layers
            LayerNormOperator,
            RMSNormOperator,
            BatchNormOperator,
            GroupNormOperator,
            # Regularization
            DropoutOperator,
        )
        # Instantiate one at a time so registration order matches the
        # order of the tuple above.
        for operator_cls in default_operator_classes:
            self.register(operator_cls())

    def register(self, operator: Operator):
        """Store ``operator`` under ``operator.name``.

        An operator already stored under the same name is replaced.
        """
        key = operator.name
        self._operators[key] = operator

    def get(self, op_name: str) -> Optional[Operator]:
        """Return the operator registered as ``op_name``, or ``None``.

        Any name beginning with ``"arg_"`` is resolved through the ``"arg"``
        entry instead of its own key.
        """
        lookup_key = "arg" if op_name.startswith("arg_") else op_name
        return self._operators.get(lookup_key)

    def list_operators(self) -> dict[str, Operator]:
        """Return a shallow copy of the name -> operator mapping."""
        return dict(self._operators)
# Global registry instance: constructed once when this module is imported
# (which registers the default operators) and used by the module-level helper
# functions below (get_operator, register_operator, list_operators and the
# set_operator_weight* family).
_global_registry = OperatorRegistry()
def get_operator(op_name: str) -> Optional[Operator]:
    """Look up ``op_name`` in the global registry.

    Returns whatever the global registry's ``get`` returns for that name,
    which is ``None`` when nothing matches.
    """
    found = _global_registry.get(op_name)
    return found
def register_operator(operator: Operator):
    """Add ``operator`` to the global registry under its own name."""
    registry = _global_registry
    registry.register(operator)
def list_operators() -> dict[str, Operator]:
    """Return the global registry's name -> operator mapping.

    The result comes from the registry's own ``list_operators`` method.
    """
    snapshot = _global_registry.list_operators()
    return snapshot
def set_operator_weight(op_name: str, weight: float) -> None:
    """Set the selection weight for a specific operator.

    The operator is looked up by registry key first. If that finds nothing,
    the registered operators are scanned for the first one whose
    ``torch_op_name`` attribute equals ``op_name``.

    Args:
        op_name: The registered operator name (e.g., "add", "arg") OR
            fully-qualified torch op (e.g., "torch.nn.functional.relu",
            "torch.matmul")
        weight: New relative selection weight (must be > 0)

    Raises:
        ValueError: If ``weight`` is not strictly positive. NaN is rejected.
        KeyError: If no operator matches ``op_name`` by either lookup.
    """
    # Test `not weight > 0` rather than `weight <= 0`: every ordered
    # comparison with NaN is False, so `weight <= 0` would let NaN through.
    if not weight > 0:
        raise ValueError("Operator weight must be > 0")

    # Try by registry key
    op = _global_registry.get(op_name)
    if op is not None:
        op.weight = float(weight)
        return

    # Fallback: try to locate by fully-qualified torch op name
    for candidate in _global_registry.list_operators().values():
        if getattr(candidate, "torch_op_name", None) == op_name:
            candidate.weight = float(weight)
            return

    raise KeyError(f"Operator '{op_name}' not found by registry name or torch op name")
def set_operator_weights(weights: dict[str, float]) -> None:
    """Apply ``set_operator_weight`` to every (name, weight) pair in ``weights``.

    Pairs are processed in the mapping's iteration order. An exception from
    any single update propagates immediately.
    """
    for pair in weights.items():
        set_operator_weight(*pair)
def set_operator_weight_by_torch_op(torch_op_name: str, weight: float) -> None:
    """Set operator weight by fully-qualified torch op name.

    Scans the global registry and updates the first operator whose
    ``torch_op_name`` attribute equals ``torch_op_name``.

    Args:
        torch_op_name: Fully-qualified torch op name to match.
        weight: New relative selection weight (must be > 0)

    Raises:
        ValueError: If ``weight`` is not strictly positive. NaN is rejected.
        KeyError: If no registered operator has that ``torch_op_name``.
    """
    # Test `not weight > 0` rather than `weight <= 0`: every ordered
    # comparison with NaN is False, so `weight <= 0` would let NaN through.
    if not weight > 0:
        raise ValueError("Operator weight must be > 0")

    for candidate in _global_registry.list_operators().values():
        # Operators without a torch_op_name attribute never match.
        if getattr(candidate, "torch_op_name", None) == torch_op_name:
            candidate.weight = float(weight)
            return

    raise KeyError(f"Torch op '{torch_op_name}' not found in registry")
def set_operator_weights_by_torch_op(weights: dict[str, float]) -> None:
    """Apply ``set_operator_weight_by_torch_op`` to each entry of ``weights``.

    Keys are fully-qualified torch op names. Entries are processed in the
    mapping's iteration order, and an exception from any single update
    propagates immediately.
    """
    for pair in weights.items():
        set_operator_weight_by_torch_op(*pair)