[torchfuzz] make pointwise subclasses define torch_op_name (#166220)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166220
Approved by: https://github.com/pianpwk
ghstack dependencies: #166187, #166188
This commit is contained in:
bobrenjc93 2025-10-27 13:34:15 -07:00 committed by PyTorch MergeBot
parent 46d17e8871
commit f2450798cd
2 changed files with 47 additions and 36 deletions

View File

@ -15,14 +15,8 @@ from torchfuzz.tensor_fuzzer import Spec, TensorSpec
class MatrixMultiplyOperator(Operator):
"""Base class for matrix multiplication operations."""
def __init__(self, name: str, torch_op: str):
def __init__(self, name: str):
super().__init__(name)
self._torch_op = torch_op
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return self._torch_op
def can_produce(self, output_spec: Spec) -> bool:
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
@ -47,12 +41,6 @@ class MatrixMultiplyOperator(Operator):
def _get_compatible_dtype(self, output_dtype):
"""Get a compatible dtype for matrix multiplication."""
# For matrix multiplication, we need to be flexible with input dtypes
# since earlier operations may have performed type promotion.
# We'll let the fuzzer generate whatever dtypes result from earlier operations
# and rely on the operation graph to ensure compatibility.
# Return the output dtype as a starting point, but this may be overridden
# by the actual tensor specs generated by the fuzzer.
return [output_dtype, output_dtype]
@ -60,9 +48,14 @@ class MMOperator(MatrixMultiplyOperator):
"""Operator for matrix multiplication (torch.mm)."""
def __init__(self):
super().__init__("mm", "torch.mm")
super().__init__("mm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.mm"
def can_produce(self, output_spec: Spec) -> bool:
"""MM requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -96,7 +89,6 @@ class MMOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [m, k]
@ -141,9 +133,14 @@ class AddmmOperator(MatrixMultiplyOperator):
"""Operator for additive matrix multiplication (torch.addmm)."""
def __init__(self):
super().__init__("addmm", "torch.addmm")
super().__init__("addmm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.addmm"
def can_produce(self, output_spec: Spec) -> bool:
"""Addmm requires exactly 2D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -177,7 +174,6 @@ class AddmmOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# Bias tensor: [m, n] (same shape as output)
@ -230,9 +226,14 @@ class BmmOperator(MatrixMultiplyOperator):
"""Operator for batch matrix multiplication (torch.bmm)."""
def __init__(self):
super().__init__("bmm", "torch.bmm")
super().__init__("bmm")
self.weight = 5.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.bmm"
def can_produce(self, output_spec: Spec) -> bool:
"""Batch matrix multiply requires 3D tensors."""
if not isinstance(output_spec, TensorSpec):
@ -266,7 +267,6 @@ class BmmOperator(MatrixMultiplyOperator):
# Choose a random inner dimension k
k = random.randint(1, 16)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
# First tensor: [b, m, k]
@ -311,9 +311,14 @@ class MatmulOperator(MatrixMultiplyOperator):
"""Operator for general matrix multiplication (torch.matmul)."""
def __init__(self):
super().__init__("matmul", "torch.matmul")
super().__init__("matmul")
self.weight = 500.0
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.matmul"
def can_produce(self, output_spec: Spec) -> bool:
"""Matmul can handle various tensor dimensions >= 1."""
if not isinstance(output_spec, TensorSpec):
@ -343,7 +348,6 @@ class MatmulOperator(MatrixMultiplyOperator):
output_size = output_spec.size
output_dims = len(output_size)
# Get compatible dtypes
dtypes = self._get_compatible_dtype(output_spec.dtype)
if output_dims == 1:

View File

@ -1,7 +1,6 @@
"""Tensor pointwise operator implementation."""
import random
from typing import Optional
import torch
@ -17,16 +16,10 @@ from torchfuzz.type_promotion import (
class PointwiseOperator(Operator):
"""Base class for element-wise pointwise operations."""
def __init__(self, name: str, torch_op: str, symbol: str):
def __init__(self, name: str, symbol: str):
super().__init__(name)
self._torch_op = torch_op
self.symbol = symbol
@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return self._torch_op
def can_produce(self, output_spec: Spec) -> bool:
"""Tensor pointwise operations can produce tensors but not scalars."""
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
@ -74,9 +67,7 @@ class PointwiseOperator(Operator):
) -> str:
"""Generate code for pointwise operation."""
if len(input_names) == 2:
return (
f"{output_name} = {self._torch_op}({input_names[0]}, {input_names[1]})"
)
return f"{output_name} = {self.torch_op_name}({input_names[0]}, {input_names[1]})"
else:
# Chain operations using symbols for readability
expr = f" {self.symbol} ".join(input_names)
@ -87,26 +78,42 @@ class AddOperator(PointwiseOperator):
"""Operator for element-wise addition."""
def __init__(self, weight: float = 1.0):
super().__init__("add", "torch.add", "+")
super().__init__("add", "+")
self.weight = float(weight)
@property
def torch_op_name(self) -> str:
return "torch.add"
class MulOperator(PointwiseOperator):
"""Operator for element-wise multiplication."""
def __init__(self):
super().__init__("mul", "torch.mul", "*")
super().__init__("mul", "*")
@property
def torch_op_name(self) -> str:
return "torch.mul"
class SubOperator(PointwiseOperator):
"""Operator for element-wise subtraction."""
def __init__(self):
super().__init__("sub", "torch.sub", "-")
super().__init__("sub", "-")
@property
def torch_op_name(self) -> str:
return "torch.sub"
class DivOperator(PointwiseOperator):
"""Operator for element-wise division."""
def __init__(self):
super().__init__("div", "torch.div", "/")
super().__init__("div", "/")
@property
def torch_op_name(self) -> str:
return "torch.div"