mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchfuzz] make pointwise subclasses defined torch_op_name (#166220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166220 Approved by: https://github.com/pianpwk ghstack dependencies: #166187, #166188
This commit is contained in:
parent
46d17e8871
commit
f2450798cd
|
|
@ -15,14 +15,8 @@ from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
|||
class MatrixMultiplyOperator(Operator):
|
||||
"""Base class for matrix multiplication operations."""
|
||||
|
||||
def __init__(self, name: str, torch_op: str):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self._torch_op = torch_op
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return self._torch_op
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
|
||||
|
|
@ -47,12 +41,6 @@ class MatrixMultiplyOperator(Operator):
|
|||
|
||||
def _get_compatible_dtype(self, output_dtype):
|
||||
"""Get a compatible dtype for matrix multiplication."""
|
||||
# For matrix multiplication, we need to be flexible with input dtypes
|
||||
# since earlier operations may have performed type promotion.
|
||||
# We'll let the fuzzer generate whatever dtypes result from earlier operations
|
||||
# and rely on the operation graph to ensure compatibility.
|
||||
# Return the output dtype as a starting point, but this may be overridden
|
||||
# by the actual tensor specs generated by the fuzzer.
|
||||
return [output_dtype, output_dtype]
|
||||
|
||||
|
||||
|
|
@ -60,9 +48,14 @@ class MMOperator(MatrixMultiplyOperator):
|
|||
"""Operator for matrix multiplication (torch.mm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("mm", "torch.mm")
|
||||
super().__init__("mm")
|
||||
self.weight = 5.0
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return "torch.mm"
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""MM requires exactly 2D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
|
|
@ -96,7 +89,6 @@ class MMOperator(MatrixMultiplyOperator):
|
|||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# First tensor: [m, k]
|
||||
|
|
@ -141,9 +133,14 @@ class AddmmOperator(MatrixMultiplyOperator):
|
|||
"""Operator for additive matrix multiplication (torch.addmm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("addmm", "torch.addmm")
|
||||
super().__init__("addmm")
|
||||
self.weight = 5.0
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return "torch.addmm"
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Addmm requires exactly 2D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
|
|
@ -177,7 +174,6 @@ class AddmmOperator(MatrixMultiplyOperator):
|
|||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# Bias tensor: [m, n] (same shape as output)
|
||||
|
|
@ -230,9 +226,14 @@ class BmmOperator(MatrixMultiplyOperator):
|
|||
"""Operator for batch matrix multiplication (torch.bmm)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("bmm", "torch.bmm")
|
||||
super().__init__("bmm")
|
||||
self.weight = 5.0
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return "torch.bmm"
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Batch matrix multiply requires 3D tensors."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
|
|
@ -266,7 +267,6 @@ class BmmOperator(MatrixMultiplyOperator):
|
|||
# Choose a random inner dimension k
|
||||
k = random.randint(1, 16)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
# First tensor: [b, m, k]
|
||||
|
|
@ -311,9 +311,14 @@ class MatmulOperator(MatrixMultiplyOperator):
|
|||
"""Operator for general matrix multiplication (torch.matmul)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("matmul", "torch.matmul")
|
||||
super().__init__("matmul")
|
||||
self.weight = 500.0
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return "torch.matmul"
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Matmul can handle various tensor dimensions >= 1."""
|
||||
if not isinstance(output_spec, TensorSpec):
|
||||
|
|
@ -343,7 +348,6 @@ class MatmulOperator(MatrixMultiplyOperator):
|
|||
output_size = output_spec.size
|
||||
output_dims = len(output_size)
|
||||
|
||||
# Get compatible dtypes
|
||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||
|
||||
if output_dims == 1:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Tensor pointwise operator implementation."""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -17,16 +16,10 @@ from torchfuzz.type_promotion import (
|
|||
class PointwiseOperator(Operator):
|
||||
"""Base class for element-wise pointwise operations."""
|
||||
|
||||
def __init__(self, name: str, torch_op: str, symbol: str):
|
||||
def __init__(self, name: str, symbol: str):
|
||||
super().__init__(name)
|
||||
self._torch_op = torch_op
|
||||
self.symbol = symbol
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> Optional[str]:
|
||||
"""Return the torch operation name."""
|
||||
return self._torch_op
|
||||
|
||||
def can_produce(self, output_spec: Spec) -> bool:
|
||||
"""Tensor pointwise operations can produce tensors but not scalars."""
|
||||
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
|
||||
|
|
@ -74,9 +67,7 @@ class PointwiseOperator(Operator):
|
|||
) -> str:
|
||||
"""Generate code for pointwise operation."""
|
||||
if len(input_names) == 2:
|
||||
return (
|
||||
f"{output_name} = {self._torch_op}({input_names[0]}, {input_names[1]})"
|
||||
)
|
||||
return f"{output_name} = {self.torch_op_name}({input_names[0]}, {input_names[1]})"
|
||||
else:
|
||||
# Chain operations using symbols for readability
|
||||
expr = f" {self.symbol} ".join(input_names)
|
||||
|
|
@ -87,26 +78,42 @@ class AddOperator(PointwiseOperator):
|
|||
"""Operator for element-wise addition."""
|
||||
|
||||
def __init__(self, weight: float = 1.0):
|
||||
super().__init__("add", "torch.add", "+")
|
||||
super().__init__("add", "+")
|
||||
self.weight = float(weight)
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> str:
|
||||
return "torch.add"
|
||||
|
||||
|
||||
class MulOperator(PointwiseOperator):
|
||||
"""Operator for element-wise multiplication."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("mul", "torch.mul", "*")
|
||||
super().__init__("mul", "*")
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> str:
|
||||
return "torch.mul"
|
||||
|
||||
|
||||
class SubOperator(PointwiseOperator):
|
||||
"""Operator for element-wise subtraction."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("sub", "torch.sub", "-")
|
||||
super().__init__("sub", "-")
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> str:
|
||||
return "torch.sub"
|
||||
|
||||
|
||||
class DivOperator(PointwiseOperator):
|
||||
"""Operator for element-wise division."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("div", "torch.div", "/")
|
||||
super().__init__("div", "/")
|
||||
|
||||
@property
|
||||
def torch_op_name(self) -> str:
|
||||
return "torch.div"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user