"""Operator registry for mapping operation names to operator instances."""

from typing import Optional

from torchfuzz.operators.arg import ArgOperator
from torchfuzz.operators.argsort import ArgsortOperator
from torchfuzz.operators.base import Operator
from torchfuzz.operators.constant import ConstantOperator
from torchfuzz.operators.gather import GatherOperator
from torchfuzz.operators.index_select import IndexSelectOperator
from torchfuzz.operators.item import ItemOperator
from torchfuzz.operators.layout import (
    CatOperator,
    ChunkOperator,
    FlattenOperator,
    ReshapeOperator,
    SqueezeOperator,
    StackOperator,
    UnsqueezeOperator,
    ViewOperator,
)
from torchfuzz.operators.masked_select import MaskedSelectOperator
from torchfuzz.operators.matrix_multiply import (
    AddmmOperator,
    BmmOperator,
    MatmulOperator,
    MMOperator,
)
from torchfuzz.operators.nn_functional import (
    BatchNormOperator,
    DropoutOperator,
    ELUOperator,
    EmbeddingOperator,
    GELUOperator,
    GroupNormOperator,
    LayerNormOperator,
    LeakyReLUOperator,
    LinearOperator,
    MultiHeadAttentionForwardOperator,
    ReLUOperator,
    RMSNormOperator,
    ScaledDotProductAttentionOperator,
    SigmoidOperator,
    SiLUOperator,
    SoftmaxOperator,
    TanhOperator,
)
from torchfuzz.operators.nonzero import NonzeroOperator
from torchfuzz.operators.scalar_pointwise import (
    ScalarAddOperator,
    ScalarDivOperator,
    ScalarMulOperator,
    ScalarSubOperator,
)
from torchfuzz.operators.tensor_pointwise import (
    AddOperator,
    ClampOperator,
    CumsumOperator,
    DivOperator,
    MulOperator,
    SubOperator,
)
from torchfuzz.operators.unique import UniqueOperator


class OperatorRegistry:
    """Registry for managing operator instances."""

    def __init__(self):
        """Initialize the registry with default operators."""
        self._operators: dict[str, Operator] = {}
        self._register_default_operators()

    def _register_default_operators(self):
        """Register the default set of operators."""
        # Individual tensor pointwise operators (preferred)
        self.register(AddOperator())
        self.register(MulOperator())
        self.register(SubOperator())
        self.register(DivOperator())
        self.register(ClampOperator())
        self.register(CumsumOperator())

        # Individual scalar pointwise operators (preferred)
        self.register(ScalarAddOperator())
        self.register(ScalarMulOperator())
        self.register(ScalarSubOperator())
        self.register(ScalarDivOperator())

        # Leaf input operators
        self.register(ConstantOperator())
        self.register(ArgOperator())

        # Data-dependent operators
        self.register(NonzeroOperator())
        self.register(MaskedSelectOperator())
        self.register(GatherOperator())
        self.register(IndexSelectOperator())
        self.register(ArgsortOperator())
        self.register(ItemOperator())
        self.register(UniqueOperator())

        # Tensor layout operators
        self.register(ViewOperator())
        self.register(ReshapeOperator())
        self.register(FlattenOperator())
        self.register(SqueezeOperator())
        self.register(UnsqueezeOperator())
        self.register(CatOperator())
        self.register(StackOperator())
        self.register(ChunkOperator())

        # Matrix multiplication operators
        self.register(MMOperator())
        self.register(AddmmOperator())
        self.register(BmmOperator())
        self.register(MatmulOperator())

        # Neural network functional operators
        self.register(EmbeddingOperator())
        self.register(LinearOperator())
        self.register(ScaledDotProductAttentionOperator())
        self.register(MultiHeadAttentionForwardOperator())

        # Activation functions
        self.register(ReLUOperator())
        self.register(LeakyReLUOperator())
        self.register(ELUOperator())
        self.register(GELUOperator())
        self.register(SiLUOperator())
        self.register(SigmoidOperator())
        self.register(TanhOperator())
        self.register(SoftmaxOperator())

        # Normalization layers
        self.register(LayerNormOperator())
        self.register(RMSNormOperator())
        self.register(BatchNormOperator())
        self.register(GroupNormOperator())

        # Regularization
        self.register(DropoutOperator())

    def register(self, operator: Operator):
        """Register an operator in the registry."""
        self._operators[operator.name] = operator

    def get(self, op_name: str) -> Optional[Operator]:
        """Get an operator by name."""
        # Handle special arg_ operations by mapping them to the ArgOperator
        if op_name.startswith("arg_"):
            return self._operators.get("arg")
        return self._operators.get(op_name)

    def list_operators(self) -> dict[str, Operator]:
        """List all registered operators."""
        return self._operators.copy()
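
# Usage sketch: a fuzzing harness can also construct its own isolated registry
# instead of relying on the shared module-level one defined below, e.g.
#   registry = OperatorRegistry()
#   registry.register(MyOperator())   # hypothetical Operator subclass
#   op = registry.get("my_op")        # hypothetical registry name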


# Global registry instance
_global_registry = OperatorRegistry()


def get_operator(op_name: str) -> Optional[Operator]:
    """Get an operator from the global registry."""
    return _global_registry.get(op_name)


def register_operator(operator: Operator):
    """Register an operator in the global registry."""
    _global_registry.register(operator)


def list_operators() -> dict[str, Operator]:
    """List all operators in the global registry."""
    return _global_registry.list_operators()


def set_operator_weight(op_name: str, weight: float) -> None:
    """Set the selection weight for a specific operator.

    Args:
        op_name: The registered operator name (e.g., "add", "arg") or a
            fully-qualified torch op name (e.g., "torch.nn.functional.relu",
            "torch.matmul").
        weight: New relative selection weight (must be > 0).
    """
    if weight <= 0:
        raise ValueError("Operator weight must be > 0")

    # Try by registry key
    op = _global_registry.get(op_name)
    if op is not None:
        op.weight = float(weight)
        return

    # Fallback: try to locate by fully-qualified torch op name
    for candidate in _global_registry.list_operators().values():
        if getattr(candidate, "torch_op_name", None) == op_name:
            candidate.weight = float(weight)
            return

    raise KeyError(f"Operator '{op_name}' not found by registry name or torch op name")


def set_operator_weights(weights: dict[str, float]) -> None:
    """Bulk-update operator weights from a mapping of name -> weight."""
    for name, w in weights.items():
        set_operator_weight(name, w)


def set_operator_weight_by_torch_op(torch_op_name: str, weight: float) -> None:
    """Set operator weight by fully-qualified torch op name."""
    if weight <= 0:
        raise ValueError("Operator weight must be > 0")
    for candidate in _global_registry.list_operators().values():
        if getattr(candidate, "torch_op_name", None) == torch_op_name:
            candidate.weight = float(weight)
            return
    raise KeyError(f"Torch op '{torch_op_name}' not found in registry")


def set_operator_weights_by_torch_op(weights: dict[str, float]) -> None:
    """Bulk-update weights by fully-qualified torch op names."""
    for name, w in weights.items():
        set_operator_weight_by_torch_op(name, w)
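

# Minimal demo sketch of the module-level API; it assumes the example names from
# the set_operator_weight docstring ("add", "torch.matmul") are actually registered.
if __name__ == "__main__":
    # Show every registered operator and its implementing class.
    for name, op in list_operators().items():
        print(name, type(op).__name__)

    # Look up a single operator by registry name ("arg_*" names map to "arg").
    print(get_operator("add"))

    # Re-weight operators by registry name or by fully-qualified torch op name.
    set_operator_weight("add", 2.0)
    set_operator_weights_by_torch_op({"torch.matmul": 0.5})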