mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[torchfuzz] make pointwise subclasses define torch_op_name (#166220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166220 Approved by: https://github.com/pianpwk ghstack dependencies: #166187, #166188
This commit is contained in:
parent
46d17e8871
commit
f2450798cd
|
|
@ -15,14 +15,8 @@ from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
||||||
class MatrixMultiplyOperator(Operator):
|
class MatrixMultiplyOperator(Operator):
|
||||||
"""Base class for matrix multiplication operations."""
|
"""Base class for matrix multiplication operations."""
|
||||||
|
|
||||||
def __init__(self, name: str, torch_op: str):
|
def __init__(self, name: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self._torch_op = torch_op
|
|
||||||
|
|
||||||
@property
|
|
||||||
def torch_op_name(self) -> Optional[str]:
|
|
||||||
"""Return the torch operation name."""
|
|
||||||
return self._torch_op
|
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
|
"""Matrix multiply operations can produce float/complex tensors of dimension >= 2."""
|
||||||
|
|
@ -47,12 +41,6 @@ class MatrixMultiplyOperator(Operator):
|
||||||
|
|
||||||
def _get_compatible_dtype(self, output_dtype):
|
def _get_compatible_dtype(self, output_dtype):
|
||||||
"""Get a compatible dtype for matrix multiplication."""
|
"""Get a compatible dtype for matrix multiplication."""
|
||||||
# For matrix multiplication, we need to be flexible with input dtypes
|
|
||||||
# since earlier operations may have performed type promotion.
|
|
||||||
# We'll let the fuzzer generate whatever dtypes result from earlier operations
|
|
||||||
# and rely on the operation graph to ensure compatibility.
|
|
||||||
# Return the output dtype as a starting point, but this may be overridden
|
|
||||||
# by the actual tensor specs generated by the fuzzer.
|
|
||||||
return [output_dtype, output_dtype]
|
return [output_dtype, output_dtype]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,9 +48,14 @@ class MMOperator(MatrixMultiplyOperator):
|
||||||
"""Operator for matrix multiplication (torch.mm)."""
|
"""Operator for matrix multiplication (torch.mm)."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("mm", "torch.mm")
|
super().__init__("mm")
|
||||||
self.weight = 5.0
|
self.weight = 5.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> Optional[str]:
|
||||||
|
"""Return the torch operation name."""
|
||||||
|
return "torch.mm"
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""MM requires exactly 2D tensors."""
|
"""MM requires exactly 2D tensors."""
|
||||||
if not isinstance(output_spec, TensorSpec):
|
if not isinstance(output_spec, TensorSpec):
|
||||||
|
|
@ -96,7 +89,6 @@ class MMOperator(MatrixMultiplyOperator):
|
||||||
# Choose a random inner dimension k
|
# Choose a random inner dimension k
|
||||||
k = random.randint(1, 16)
|
k = random.randint(1, 16)
|
||||||
|
|
||||||
# Get compatible dtypes
|
|
||||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||||
|
|
||||||
# First tensor: [m, k]
|
# First tensor: [m, k]
|
||||||
|
|
@ -141,9 +133,14 @@ class AddmmOperator(MatrixMultiplyOperator):
|
||||||
"""Operator for additive matrix multiplication (torch.addmm)."""
|
"""Operator for additive matrix multiplication (torch.addmm)."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("addmm", "torch.addmm")
|
super().__init__("addmm")
|
||||||
self.weight = 5.0
|
self.weight = 5.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> Optional[str]:
|
||||||
|
"""Return the torch operation name."""
|
||||||
|
return "torch.addmm"
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""Addmm requires exactly 2D tensors."""
|
"""Addmm requires exactly 2D tensors."""
|
||||||
if not isinstance(output_spec, TensorSpec):
|
if not isinstance(output_spec, TensorSpec):
|
||||||
|
|
@ -177,7 +174,6 @@ class AddmmOperator(MatrixMultiplyOperator):
|
||||||
# Choose a random inner dimension k
|
# Choose a random inner dimension k
|
||||||
k = random.randint(1, 16)
|
k = random.randint(1, 16)
|
||||||
|
|
||||||
# Get compatible dtypes
|
|
||||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||||
|
|
||||||
# Bias tensor: [m, n] (same shape as output)
|
# Bias tensor: [m, n] (same shape as output)
|
||||||
|
|
@ -230,9 +226,14 @@ class BmmOperator(MatrixMultiplyOperator):
|
||||||
"""Operator for batch matrix multiplication (torch.bmm)."""
|
"""Operator for batch matrix multiplication (torch.bmm)."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("bmm", "torch.bmm")
|
super().__init__("bmm")
|
||||||
self.weight = 5.0
|
self.weight = 5.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> Optional[str]:
|
||||||
|
"""Return the torch operation name."""
|
||||||
|
return "torch.bmm"
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""Batch matrix multiply requires 3D tensors."""
|
"""Batch matrix multiply requires 3D tensors."""
|
||||||
if not isinstance(output_spec, TensorSpec):
|
if not isinstance(output_spec, TensorSpec):
|
||||||
|
|
@ -266,7 +267,6 @@ class BmmOperator(MatrixMultiplyOperator):
|
||||||
# Choose a random inner dimension k
|
# Choose a random inner dimension k
|
||||||
k = random.randint(1, 16)
|
k = random.randint(1, 16)
|
||||||
|
|
||||||
# Get compatible dtypes
|
|
||||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||||
|
|
||||||
# First tensor: [b, m, k]
|
# First tensor: [b, m, k]
|
||||||
|
|
@ -311,9 +311,14 @@ class MatmulOperator(MatrixMultiplyOperator):
|
||||||
"""Operator for general matrix multiplication (torch.matmul)."""
|
"""Operator for general matrix multiplication (torch.matmul)."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("matmul", "torch.matmul")
|
super().__init__("matmul")
|
||||||
self.weight = 500.0
|
self.weight = 500.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> Optional[str]:
|
||||||
|
"""Return the torch operation name."""
|
||||||
|
return "torch.matmul"
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""Matmul can handle various tensor dimensions >= 1."""
|
"""Matmul can handle various tensor dimensions >= 1."""
|
||||||
if not isinstance(output_spec, TensorSpec):
|
if not isinstance(output_spec, TensorSpec):
|
||||||
|
|
@ -343,7 +348,6 @@ class MatmulOperator(MatrixMultiplyOperator):
|
||||||
output_size = output_spec.size
|
output_size = output_spec.size
|
||||||
output_dims = len(output_size)
|
output_dims = len(output_size)
|
||||||
|
|
||||||
# Get compatible dtypes
|
|
||||||
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
dtypes = self._get_compatible_dtype(output_spec.dtype)
|
||||||
|
|
||||||
if output_dims == 1:
|
if output_dims == 1:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""Tensor pointwise operator implementation."""
|
"""Tensor pointwise operator implementation."""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
@ -17,16 +16,10 @@ from torchfuzz.type_promotion import (
|
||||||
class PointwiseOperator(Operator):
|
class PointwiseOperator(Operator):
|
||||||
"""Base class for element-wise pointwise operations."""
|
"""Base class for element-wise pointwise operations."""
|
||||||
|
|
||||||
def __init__(self, name: str, torch_op: str, symbol: str):
|
def __init__(self, name: str, symbol: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self._torch_op = torch_op
|
|
||||||
self.symbol = symbol
|
self.symbol = symbol
|
||||||
|
|
||||||
@property
|
|
||||||
def torch_op_name(self) -> Optional[str]:
|
|
||||||
"""Return the torch operation name."""
|
|
||||||
return self._torch_op
|
|
||||||
|
|
||||||
def can_produce(self, output_spec: Spec) -> bool:
|
def can_produce(self, output_spec: Spec) -> bool:
|
||||||
"""Tensor pointwise operations can produce tensors but not scalars."""
|
"""Tensor pointwise operations can produce tensors but not scalars."""
|
||||||
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
|
if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
|
||||||
|
|
@ -74,9 +67,7 @@ class PointwiseOperator(Operator):
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate code for pointwise operation."""
|
"""Generate code for pointwise operation."""
|
||||||
if len(input_names) == 2:
|
if len(input_names) == 2:
|
||||||
return (
|
return f"{output_name} = {self.torch_op_name}({input_names[0]}, {input_names[1]})"
|
||||||
f"{output_name} = {self._torch_op}({input_names[0]}, {input_names[1]})"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Chain operations using symbols for readability
|
# Chain operations using symbols for readability
|
||||||
expr = f" {self.symbol} ".join(input_names)
|
expr = f" {self.symbol} ".join(input_names)
|
||||||
|
|
@ -87,26 +78,42 @@ class AddOperator(PointwiseOperator):
|
||||||
"""Operator for element-wise addition."""
|
"""Operator for element-wise addition."""
|
||||||
|
|
||||||
def __init__(self, weight: float = 1.0):
|
def __init__(self, weight: float = 1.0):
|
||||||
super().__init__("add", "torch.add", "+")
|
super().__init__("add", "+")
|
||||||
self.weight = float(weight)
|
self.weight = float(weight)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> str:
|
||||||
|
return "torch.add"
|
||||||
|
|
||||||
|
|
||||||
class MulOperator(PointwiseOperator):
|
class MulOperator(PointwiseOperator):
|
||||||
"""Operator for element-wise multiplication."""
|
"""Operator for element-wise multiplication."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("mul", "torch.mul", "*")
|
super().__init__("mul", "*")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> str:
|
||||||
|
return "torch.mul"
|
||||||
|
|
||||||
|
|
||||||
class SubOperator(PointwiseOperator):
|
class SubOperator(PointwiseOperator):
|
||||||
"""Operator for element-wise subtraction."""
|
"""Operator for element-wise subtraction."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("sub", "torch.sub", "-")
|
super().__init__("sub", "-")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> str:
|
||||||
|
return "torch.sub"
|
||||||
|
|
||||||
|
|
||||||
class DivOperator(PointwiseOperator):
|
class DivOperator(PointwiseOperator):
|
||||||
"""Operator for element-wise division."""
|
"""Operator for element-wise division."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("div", "torch.div", "/")
|
super().__init__("div", "/")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_op_name(self) -> str:
|
||||||
|
return "torch.div"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user