pytorch/tools/experimental/torchfuzz/operators/base.py

80 lines
2.7 KiB
Python

"""Base operator implementation."""
from abc import ABC, abstractmethod
from typing import Optional
from torchfuzz.tensor_fuzzer import Spec
class Operator(ABC):
    """Base class for all operators in torchfuzz.

    Concrete operators must provide ``torch_op_name``, ``can_produce`` and
    ``codegen`` (all abstract). ``fuzz_inputs_specs`` is not declared
    abstract, so a subclass that omits it can still be instantiated, but
    calling it raises ``NotImplementedError``.
    """

    def __init__(self, name: str, weight: float = 1.0):
        """Initialize operator with name and optional selection weight.

        Args:
            name: Unique operator name used in the registry.
            weight: Relative selection weight when sampling among compatible
                operators (default 1.0). Higher values increase selection
                likelihood. The value is coerced with ``float()``.
        """
        self.name = name
        # Coerce once so that get_weight() always hands back a float, even
        # when the caller passed an int.
        self.weight: float = float(weight)

    @property
    @abstractmethod
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name this operator represents.

        Returns:
            The torch operation name (e.g., "torch.ops.aten.add",
            "torch.nonzero"), or None for non-torch operations like "arg"
            and "constant".
        """
        raise NotImplementedError("Subclasses must implement torch_op_name")

    @abstractmethod
    def can_produce(self, output_spec: Spec) -> bool:
        """Check if this operator can produce the given output spec.

        Args:
            output_spec: The spec the caller wants an operator to produce.

        Returns:
            True if this operator can produce ``output_spec``.
        """
        raise NotImplementedError("Subclasses must implement can_produce()")

    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
        """Get input specifications for fuzzing.

        Subclasses are expected to override this to return a list of input
        Specs that, when used with this operator, can produce the given
        ``output_spec``. Leaf operators should return an empty list.

        This method is not marked abstract, so a missing override is only
        detected when the method is called, not at instantiation.

        Args:
            output_spec: The spec the operator's output must satisfy.

        Returns:
            The input Specs to generate for this operator.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement fuzz_inputs_specs()")

    @abstractmethod
    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for this operation.

        Args:
            output_name: Name associated with the operation's output.
            input_names: Names associated with the operation's inputs.
            output_spec: The spec of the output being generated.

        Returns:
            The generated code as a string.
        """
        raise NotImplementedError("Subclasses must implement codegen()")

    def get_weight(
        self,
        *,
        target_spec: Optional[Spec] = None,
        depth: Optional[int] = None,
        stack_size: Optional[int] = None,
        template: Optional[str] = None,
    ) -> float:
        """Return the selection weight for this operator.

        Subclasses may override to implement context-sensitive weighting.
        The default implementation ignores every keyword argument and
        returns the static attribute ``self.weight``.

        Args:
            target_spec: Optional context hint; unused here.
            depth: Optional context hint; unused here.
            stack_size: Optional context hint; unused here.
            template: Optional context hint; unused here.

        Returns:
            The operator's static weight.
        """
        return self.weight

    def __str__(self) -> str:
        """Return ``"<ClassName>(<name>)"`` for this operator."""
        return f"{type(self).__name__}({self.name})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)