Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
[torchfuzz] split, chunk, stack, cat, expand, gather, cumsum, clamp, index_select (#166221)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166221
Approved by: https://github.com/pianpwk
ghstack dependencies: #166187, #166188, #166220, #166189, #166190
This commit is contained in:
parent 7045aab143
commit 8af9ed0824
@@ -205,12 +205,23 @@ class DefaultFuzzTemplate(FuzzTemplate):
             "torch.sub",
             "torch.mul",
             "torch.div",
+            "torch.clamp",
+            "torch.cumsum",
             # Tensor shape operations
             "torch.Tensor.view",
             "torch.reshape",
             "torch.flatten",
             "torch.squeeze",
             "torch.unsqueeze",
+            "torch.split",
+            "torch.chunk",
+            "torch.expand",
+            "torch.cat",
+            "torch.stack",
+            # Indexing operations
+            "torch.gather",
+            "torch.index_select",
+            "torch.argsort",
             # Matrix operations
             "torch.mm",
             "torch.addmm",
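To make the new surface concrete, here is an illustrative sketch (not part of the diff) of the kind of program the default template can now generate, composing several of the newly listed ops; the variable names are made up:

import torch

t0 = torch.randn(4, 6)
t1 = torch.clamp(t0, min=-1.0, max=1.0)    # torch.clamp
t2 = torch.cumsum(t1, dim=1)               # torch.cumsum
t3 = torch.split(t2, 2, dim=0)[0]          # torch.split -> first chunk, shape (2, 6)
t4 = torch.cat([t3, t3], dim=0)            # torch.cat, back to shape (4, 6)
idx = torch.argsort(t4, dim=-1)            # torch.argsort -> long indices
t5 = torch.gather(t4, 1, idx)              # torch.gather along dim 1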
@@ -1,12 +1,18 @@
 """Torchfuzz operators module."""

 from torchfuzz.operators.arg import ArgOperator
+from torchfuzz.operators.argsort import ArgsortOperator
 from torchfuzz.operators.base import Operator
 from torchfuzz.operators.constant import ConstantOperator
+from torchfuzz.operators.gather import GatherOperator
+from torchfuzz.operators.index_select import IndexSelectOperator
 from torchfuzz.operators.item import ItemOperator
 from torchfuzz.operators.layout import (
+    CatOperator,
+    ExpandOperator,
     FlattenOperator,
     ReshapeOperator,
+    SplitOperator,
     SqueezeOperator,
     UnsqueezeOperator,
     ViewOperator,
@@ -45,6 +51,7 @@ from torchfuzz.operators.scalar_pointwise import (
 )
 from torchfuzz.operators.tensor_pointwise import (
     AddOperator,
+    ClampOperator,
     DivOperator,
     MulOperator,
     PointwiseOperator,
@@ -59,6 +66,7 @@ __all__ = [
     "MulOperator",
     "SubOperator",
     "DivOperator",
+    "ClampOperator",
     "ScalarPointwiseOperator",
     "ScalarAddOperator",
     "ScalarMulOperator",
@@ -67,11 +75,17 @@ __all__ = [
     "ItemOperator",
     "ConstantOperator",
     "ArgOperator",
+    "ArgsortOperator",
+    "GatherOperator",
+    "IndexSelectOperator",
     "ViewOperator",
     "ReshapeOperator",
     "FlattenOperator",
     "SqueezeOperator",
     "UnsqueezeOperator",
+    "SplitOperator",
+    "ExpandOperator",
+    "CatOperator",
     "MMOperator",
     "AddmmOperator",
     "BmmOperator",
tools/experimental/torchfuzz/operators/argsort.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+"""Argsort operator implementation."""
+
+import random
+from typing import Optional
+
+import torch
+
+from torchfuzz.operators.base import Operator
+from torchfuzz.tensor_fuzzer import fuzz_valid_stride, Spec, TensorSpec
+
+
+class ArgsortOperator(Operator):
+    """Operator for torch.argsort() operation."""
+
+    def __init__(self):
+        """Initialize ArgsortOperator."""
+        super().__init__("argsort")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.argsort"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Argsort can produce tensor outputs with integer dtype (long)."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # argsort returns indices, so it must be integer type (long)
+        return output_spec.dtype == torch.long and len(output_spec.size) > 0
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input spec for argsort operation.
+
+        torch.argsort(input, dim=-1, descending=False) returns a tensor with:
+        - Same shape as input
+        - dtype is torch.long (indices)
+        """
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ArgsortOperator can only produce TensorSpec outputs")
+
+        # Input tensor has the same shape as output but can have any numeric dtype
+        input_size = output_spec.size
+
+        # Generate a valid stride for the input
+        input_stride = fuzz_valid_stride(input_size)
+
+        # Choose a random float dtype for input (argsort works on numeric types)
+        # Using float32 as a reasonable default
+        input_dtype = torch.float32
+
+        return [TensorSpec(size=input_size, stride=input_stride, dtype=input_dtype)]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for argsort operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ArgsortOperator can only produce TensorSpec outputs")
+
+        if len(input_names) != 1:
+            raise ValueError("ArgsortOperator requires exactly one input")
+
+        # Randomly choose a dimension to sort along
+        # Default to -1 (last dimension) as it's most common
+        if len(output_spec.size) > 1:
+            dim = random.randint(-len(output_spec.size), len(output_spec.size) - 1)
+        else:
+            dim = 0
+
+        # Randomly choose ascending or descending order
+        descending = random.choice([True, False])
+
+        return f"{output_name} = torch.argsort({input_names[0]}, dim={dim}, descending={descending})"
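As a sanity check, a minimal sketch (not part of the diff) of the single line ArgsortOperator.codegen emits; t0/t1 are illustrative stand-ins for the fuzzer's generated variable names:

import torch

t0 = torch.randn(3, 4)  # float32 input, as fuzz_inputs_specs requests
t1 = torch.argsort(t0, dim=-1, descending=False)
assert t1.shape == t0.shape and t1.dtype == torch.long  # indices, same shape as input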
tools/experimental/torchfuzz/operators/gather.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import torch
+
+from torchfuzz.operators.base import Operator
+from torchfuzz.tensor_fuzzer import Spec, TensorSpec
+
+
+class GatherOperator(Operator):
+    """Operator for gathering values along an axis specified by dim using indices."""
+
+    def __init__(self):
+        super().__init__("gather")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.gather"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Gather can produce tensors of various shapes, but not 0-dimensional tensors."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # gather requires at least one dimension
+        return len(output_spec.size) > 0
+
+    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 2) -> list[Spec]:
+        """Generate input specs for gather operation.
+
+        torch.gather(input, dim, index) returns a tensor with:
+        - output.shape == index.shape
+        - output[i][j][k] = input[i][j][index[i][j][k]] (for dim=2 example)
+        """
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("GatherOperator can only produce TensorSpec outputs")
+
+        # The output shape matches the index shape
+        output_size = output_spec.size
+        dim = 0  # Gather along dimension 0 for simplicity
+
+        # Input tensor - create a shape that matches output except for the gather dimension,
+        # which can be any size >= max(indices) + 1
+        # For simplicity, make input larger in the gather dimension
+        if len(output_size) == 1:
+            # Output is 1D
+            input_size = (output_size[0] + 2,)
+            input_stride = (1,)
+        elif len(output_size) == 2:
+            # Output is 2D, make input 2D with first dim larger
+            input_size = (output_size[0] + 2, output_size[1])
+            input_stride = (output_size[1], 1)  # Contiguous
+        else:
+            # For higher dimensions
+            input_size = tuple(
+                s + 2 if i == dim else s for i, s in enumerate(output_size)
+            )
+            # Contiguous stride
+            input_stride = tuple(
+                int(torch.tensor(input_size[i + 1 :]).prod().item())
+                if i < len(input_size) - 1
+                else 1
+                for i in range(len(input_size))
+            )
+
+        input_tensor_spec = TensorSpec(
+            size=input_size,
+            stride=input_stride,
+            dtype=output_spec.dtype,
+        )
+
+        # Index tensor - same shape as output, long dtype
+        index_spec = TensorSpec(
+            size=output_size,
+            stride=tuple(
+                int(torch.tensor(output_size[i + 1 :]).prod().item())
+                if i < len(output_size) - 1
+                else 1
+                for i in range(len(output_size))
+            ),
+            dtype=torch.long,
+        )
+
+        return [input_tensor_spec, index_spec]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for gather.
+
+        Creates appropriate indices to gather from the input tensor.
+        """
+        if len(input_names) != 2:
+            raise ValueError("GatherOperator requires exactly two inputs")
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("GatherOperator requires TensorSpec output")
+
+        # Determine dimension
+        dim = 0  # Gather along dimension 0 for simplicity
+
+        # Generate code that creates valid indices within the input tensor's dimension
+        return (
+            f"_input_size_{output_name} = {input_names[0]}.size({dim})\n"
+            f"_index_{output_name} = torch.randint(0, _input_size_{output_name}, {output_spec.size}, device={input_names[0]}.device)\n"
+            f"{output_name} = torch.gather({input_names[0]}, {dim}, _index_{output_name})"
+        )
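The gather codegen emits a three-line snippet; a runnable sketch of it (names illustrative) shows why the output shape equals the index shape and why the input is made larger along the gather dimension:

import torch

t0 = torch.randn(5, 3)  # input: dim 0 is output dim + 2, per fuzz_inputs_specs
_input_size_t1 = t0.size(0)
_index_t1 = torch.randint(0, _input_size_t1, (3, 3), device=t0.device)
t1 = torch.gather(t0, 0, _index_t1)
assert t1.shape == (3, 3)  # output shape == index shape; indices stay in range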
tools/experimental/torchfuzz/operators/index_select.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+from typing import Optional
+
+import torch
+
+from torchfuzz.operators.base import Operator
+from torchfuzz.tensor_fuzzer import Spec, TensorSpec
+
+
+class IndexSelectOperator(Operator):
+    """Operator for selecting elements from a tensor along a dimension using indices."""
+
+    def __init__(self):
+        super().__init__("index_select")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.index_select"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Index select can produce tensors of various shapes, but not 0-dimensional tensors."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # index_select requires at least one dimension to select from
+        return len(output_spec.size) > 0
+
+    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 3) -> list[Spec]:
+        """Generate input specs for index_select operation.
+
+        torch.index_select(input, dim, index) returns a tensor with:
+        - output.shape[dim] = len(index)
+        - output.shape[other_dims] = input.shape[other_dims]
+        """
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("IndexSelectOperator can only produce TensorSpec outputs")
+
+        # For simplicity, we'll work with a 2D input tensor
+        # and select along dimension 0
+        dim = 0
+        output_size = output_spec.size
+
+        # Input tensor - create a shape where we can select from it
+        # If output is (k, m), input can be (n, m) where n >= k
+        if len(output_size) == 1:
+            # Output is 1D, input should be at least 1D
+            input_size = (output_size[0] + 2,)  # Make input larger
+            input_stride = (1,)
+        elif len(output_size) == 2:
+            # Output is 2D, input should be 2D with first dim >= output first dim
+            input_size = (output_size[0] + 2, output_size[1])
+            input_stride = (output_size[1], 1)  # Contiguous
+        else:
+            # For higher dimensions, keep it simple
+            input_size = tuple(
+                s + 2 if i == dim else s for i, s in enumerate(output_size)
+            )
+            # Contiguous stride
+            input_stride = tuple(
+                int(torch.tensor(input_size[i + 1 :]).prod().item())
+                if i < len(input_size) - 1
+                else 1
+                for i in range(len(input_size))
+            )
+
+        input_tensor_spec = TensorSpec(
+            size=input_size,
+            stride=input_stride,
+            dtype=output_spec.dtype,
+        )
+
+        # Index tensor - 1D tensor of long dtype with indices
+        index_spec = TensorSpec(
+            size=(output_size[dim],) if len(output_size) > 0 else (1,),
+            stride=(1,),
+            dtype=torch.long,
+        )
+
+        return [input_tensor_spec, index_spec]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for index_select.
+
+        Creates appropriate indices to select from the input tensor.
+        """
+        if len(input_names) != 2:
+            raise ValueError("IndexSelectOperator requires exactly two inputs")
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("IndexSelectOperator requires TensorSpec output")
+
+        # Determine dimension and number of indices needed
+        dim = 0  # Select along dimension 0 for simplicity
+        num_indices = output_spec.size[dim] if len(output_spec.size) > 0 else 1
+
+        # Generate code that creates valid indices within the input tensor's dimension
+        return (
+            f"_input_size_{output_name} = {input_names[0]}.size({dim})\n"
+            f"_index_{output_name} = torch.randint(0, _input_size_{output_name}, ({num_indices},), device={input_names[0]}.device)\n"
+            f"{output_name} = torch.index_select({input_names[0]}, {dim}, _index_{output_name})"
+        )
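Likewise for index_select, a sketch (names illustrative) of the emitted snippet; dim 0 of the output becomes the index length while the remaining dims follow the input:

import torch

t0 = torch.randn(4, 3)  # input made larger along dim 0, per fuzz_inputs_specs
_input_size_t1 = t0.size(0)
_index_t1 = torch.randint(0, _input_size_t1, (2,), device=t0.device)
t1 = torch.index_select(t0, 0, _index_t1)
assert t1.shape == (2, 3)  # dim 0 == len(index); other dims unchanged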
@@ -400,3 +400,423 @@ class UnsqueezeOperator(LayoutOperatorBase):
             dim = len(output_spec.size) - 1

         return f"{output_name} = torch.unsqueeze({input_names[0]}, dim={dim})"
+
+
+class SplitOperator(LayoutOperatorBase):
+    """Operator for torch.split() operation."""
+
+    def __init__(self):
+        """Initialize SplitOperator."""
+        super().__init__("split")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.split"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Split can produce any tensor output."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # Split can produce any tensor with at least one dimension
+        return len(output_spec.size) > 0
+
+    def _get_split_params(self, output_spec: TensorSpec) -> tuple[int, int]:
+        """Get consistent split parameters based on output spec.
+
+        This method uses the output_spec to deterministically choose split parameters,
+        ensuring that fuzz_inputs_specs and codegen make the same choices.
+        """
+        # Use output_spec properties to seed random choices
+        # This ensures both methods make the same choices
+        seed_value = hash((output_spec.size, output_spec.dtype))
+        rng = random.Random(seed_value)
+
+        split_dim = rng.randint(0, len(output_spec.size) - 1)
+        num_chunks = rng.randint(2, 4)
+
+        return split_dim, num_chunks
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input spec for split operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("SplitOperator can only produce TensorSpec outputs")
+
+        # torch.split() splits a tensor along a dimension
+        # We'll use split_size_or_sections as an integer (split_size)
+        # The output will be one of the chunks from the split
+        if len(output_spec.size) == 0:
+            raise ValueError("Cannot split a scalar tensor")
+
+        split_dim, num_chunks = self._get_split_params(output_spec)
+
+        # Calculate input size: input will have split_dim with size = output_size * num_chunks
+        # (or slightly larger to account for uneven splits)
+        input_size = list(output_spec.size)
+        input_size[split_dim] = output_spec.size[split_dim] * num_chunks
+
+        # Create input tensor spec
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        input_stride = fuzz_valid_stride(tuple(input_size))
+
+        return [
+            TensorSpec(
+                size=tuple(input_size), stride=input_stride, dtype=output_spec.dtype
+            )
+        ]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for split operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("SplitOperator can only produce TensorSpec outputs")
+
+        split_dim, _ = self._get_split_params(output_spec)
+
+        # Use output size along split_dim as the split_size
+        split_size = output_spec.size[split_dim]
+
+        # Generate the split and select the first chunk
+        return f"{output_name} = torch.split({input_names[0]}, {split_size}, dim={split_dim})[0]"
+
+
+class ExpandOperator(LayoutOperatorBase):
+    """Operator for torch.expand() operation."""
+
+    def __init__(self):
+        """Initialize ExpandOperator."""
+        super().__init__("expand")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.expand"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Expand can produce any tensor output."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # Expand can produce any tensor with at least one dimension
+        return len(output_spec.size) > 0
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input spec for expand operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ExpandOperator can only produce TensorSpec outputs")
+
+        # torch.expand() broadcasts a tensor to a new shape
+        # For expand to work, each dimension of the input must either:
+        # 1. Match the corresponding output dimension
+        # 2. Be 1 (to be broadcasted)
+        # 3. Not exist (input can have fewer dimensions than output)
+
+        # Generate input size with same or fewer dimensions
+        output_size = output_spec.size
+        input_ndim = random.randint(1, len(output_size))
+
+        # Create input size by choosing dimensions to broadcast
+        input_size = []
+        for i in range(input_ndim):
+            output_dim_idx = len(output_size) - input_ndim + i
+            output_dim = output_size[output_dim_idx]
+
+            # Randomly choose to either match the output dimension or use 1 for broadcasting
+            # Use 1 with higher probability to test broadcasting behavior
+            if random.random() < 0.6 and output_dim > 1:
+                input_size.append(1)
+            else:
+                input_size.append(output_dim)
+
+        input_size = tuple(input_size)
+
+        # Create input tensor spec
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        input_stride = fuzz_valid_stride(input_size)
+
+        return [
+            TensorSpec(size=input_size, stride=input_stride, dtype=output_spec.dtype)
+        ]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for expand operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ExpandOperator can only produce TensorSpec outputs")
+
+        shape_str = str(list(output_spec.size))
+        return f"{output_name} = {input_names[0]}.expand({shape_str})"
+
+
+class CatOperator(LayoutOperatorBase):
+    """Operator for torch.cat() operation."""
+
+    def __init__(self):
+        """Initialize CatOperator."""
+        super().__init__("cat")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.cat"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Cat can produce any tensor output."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # Cat can produce any tensor with at least one dimension
+        return len(output_spec.size) > 0
+
+    def _get_cat_params(self, output_spec: TensorSpec) -> tuple[int, int]:
+        """Get consistent cat parameters based on output spec.
+
+        This method uses the output_spec to deterministically choose cat parameters,
+        ensuring that fuzz_inputs_specs and codegen make the same choices.
+        """
+        # Use output_spec properties to seed random choices
+        # This ensures both methods make the same choices
+        seed_value = hash((output_spec.size, output_spec.dtype))
+        rng = random.Random(seed_value)
+
+        cat_dim = rng.randint(0, len(output_spec.size) - 1)
+        num_tensors = rng.randint(2, 4)
+
+        return cat_dim, num_tensors
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input specs for cat operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("CatOperator can only produce TensorSpec outputs")
+
+        # torch.cat() concatenates tensors along a dimension
+        # Choose a random dimension to concatenate along
+        if len(output_spec.size) == 0:
+            raise ValueError("Cannot concatenate scalar tensors")
+
+        cat_dim, num_tensors = self._get_cat_params(output_spec)
+
+        # Distribute output size along cat_dim across input tensors
+        total_size = output_spec.size[cat_dim]
+
+        # Use deterministic RNG for splitting sizes
+        seed_value = hash((output_spec.size, output_spec.dtype))
+        rng = random.Random(seed_value + 1)  # +1 to differentiate from param selection
+
+        # Generate sizes for each input tensor along cat_dim
+        input_sizes_at_cat_dim = []
+        remaining_size = total_size
+
+        for i in range(num_tensors - 1):
+            if remaining_size > 0:
+                # Randomly split the remaining size
+                max_size = max(1, remaining_size - (num_tensors - i - 1))
+                size_for_this_tensor = rng.randint(1, max_size)
+                input_sizes_at_cat_dim.append(size_for_this_tensor)
+                remaining_size -= size_for_this_tensor
+            else:
+                input_sizes_at_cat_dim.append(0)
+
+        # Last tensor gets the remaining size
+        input_sizes_at_cat_dim.append(max(0, remaining_size))
+
+        # Create input tensor specs
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        input_specs = []
+        for size_at_cat_dim in input_sizes_at_cat_dim:
+            input_size = list(output_spec.size)
+            input_size[cat_dim] = size_at_cat_dim
+            input_size = tuple(input_size)
+
+            input_stride = fuzz_valid_stride(input_size)
+
+            input_specs.append(
+                TensorSpec(
+                    size=input_size, stride=input_stride, dtype=output_spec.dtype
+                )
+            )
+
+        return input_specs
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for cat operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("CatOperator can only produce TensorSpec outputs")
+
+        # Use the same cat_dim that was used in fuzz_inputs_specs
+        cat_dim, _ = self._get_cat_params(output_spec)
+
+        # Generate the cat operation
+        tensors_str = ", ".join(input_names)
+        return f"{output_name} = torch.cat([{tensors_str}], dim={cat_dim})"
+
+
+class StackOperator(LayoutOperatorBase):
+    """Operator for torch.stack() operation."""
+
+    def __init__(self):
+        """Initialize StackOperator."""
+        super().__init__("stack")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.stack"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Stack can produce any tensor output with at least one dimension."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # Stack creates a new dimension, so output must have at least one dimension
+        # Also, no dimension can be 0 since that would require stacking 0 tensors
+        # Limit to outputs where all dimensions are <= 4 to avoid creating too large graphs
+        return (
+            len(output_spec.size) > 0
+            and 0 not in output_spec.size
+            and all(dim <= 4 for dim in output_spec.size)
+        )
+
+    def _get_stack_params(self, output_spec: TensorSpec) -> int:
+        """Get consistent stack dimension based on output spec.
+
+        This method uses the output_spec to deterministically choose stack parameters,
+        ensuring that fuzz_inputs_specs and codegen make the same choices.
+        """
+        # Use output_spec properties to seed random choices
+        # This ensures both methods make the same choices
+        seed_value = hash((output_spec.size, output_spec.dtype))
+        rng = random.Random(seed_value)
+
+        stack_dim = rng.randint(0, len(output_spec.size) - 1)
+
+        return stack_dim
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input specs for stack operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("StackOperator can only produce TensorSpec outputs")
+
+        # torch.stack() stacks tensors along a new dimension
+        # Choose a random dimension to stack along (0 to len(output_spec.size))
+        if len(output_spec.size) == 0:
+            raise ValueError("Cannot stack into a scalar tensor")
+
+        stack_dim = self._get_stack_params(output_spec)
+
+        # Number of tensors to stack equals the size of the new dimension
+        # Limit to max 4 tensors to avoid creating too large graphs
+        num_tensors = min(output_spec.size[stack_dim], 4)
+
+        # Input tensors have the output shape with the stack_dim removed
+        input_size = list(output_spec.size)
+        input_size.pop(stack_dim)
+        input_size = tuple(input_size)
+
+        # Create input tensor specs (all inputs have the same shape)
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        input_specs = []
+        for _ in range(num_tensors):
+            input_stride = fuzz_valid_stride(input_size)
+            input_specs.append(
+                TensorSpec(
+                    size=input_size, stride=input_stride, dtype=output_spec.dtype
+                )
+            )
+
+        return input_specs
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for stack operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("StackOperator can only produce TensorSpec outputs")
+
+        # Use the same stack_dim that was used in fuzz_inputs_specs
+        stack_dim = self._get_stack_params(output_spec)
+
+        # Generate the stack operation
+        tensors_str = ", ".join(input_names)
+        return f"{output_name} = torch.stack([{tensors_str}], dim={stack_dim})"
+
+
+class ChunkOperator(LayoutOperatorBase):
+    """Operator for torch.chunk() operation."""
+
+    def __init__(self):
+        """Initialize ChunkOperator."""
+        super().__init__("chunk")
+
+    @property
+    def torch_op_name(self) -> Optional[str]:
+        """Return the torch operation name."""
+        return "torch.chunk"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Chunk can produce any tensor output."""
+        if not isinstance(output_spec, TensorSpec):
+            return False
+        # Chunk can produce any tensor with at least one dimension
+        return len(output_spec.size) > 0
+
+    def _get_chunk_params(self, output_spec: TensorSpec) -> tuple[int, int]:
+        """Get consistent chunk parameters based on output spec.
+
+        This method uses the output_spec to deterministically choose chunk parameters,
+        ensuring that fuzz_inputs_specs and codegen make the same choices.
+        """
+        # Use output_spec properties to seed random choices
+        # This ensures both methods make the same choices
+        seed_value = hash((output_spec.size, output_spec.dtype))
+        rng = random.Random(seed_value)
+
+        chunk_dim = rng.randint(0, len(output_spec.size) - 1)
+        num_chunks = rng.randint(2, 4)
+
+        return chunk_dim, num_chunks
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input spec for chunk operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ChunkOperator can only produce TensorSpec outputs")
+
+        # torch.chunk() splits a tensor into chunks along a dimension
+        # The output will be one of the chunks from the split
+        if len(output_spec.size) == 0:
+            raise ValueError("Cannot chunk a scalar tensor")
+
+        chunk_dim, num_chunks = self._get_chunk_params(output_spec)
+
+        # Calculate input size: input will have chunk_dim with size = output_size * num_chunks
+        # torch.chunk() tries to split evenly, but the last chunk may be smaller
+        input_size = list(output_spec.size)
+        input_size[chunk_dim] = output_spec.size[chunk_dim] * num_chunks
+
+        # Create input tensor spec
+        from torchfuzz.tensor_fuzzer import fuzz_valid_stride
+
+        input_stride = fuzz_valid_stride(tuple(input_size))
+
+        return [
+            TensorSpec(
+                size=tuple(input_size), stride=input_stride, dtype=output_spec.dtype
+            )
+        ]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for chunk operation."""
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ChunkOperator can only produce TensorSpec outputs")
+
+        chunk_dim, num_chunks = self._get_chunk_params(output_spec)
+
+        # Generate the chunk operation and select the first chunk
+        return f"{output_name} = torch.chunk({input_names[0]}, {num_chunks}, dim={chunk_dim})[0]"
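The design choice worth noting in these layout operators: fuzz_inputs_specs and codegen run independently but must agree on the fuzzed dim and chunk/tensor counts, so both re-derive them from an RNG seeded with hash((output_spec.size, output_spec.dtype)). A standalone sketch of the trick, where size/dtype stand in for a real TensorSpec (hash() is stable within one process, which is what this relies on):

import random

size, dtype = (2, 6), "float32"

def params():
    # Re-seeding from the same spec reproduces the same draws
    rng = random.Random(hash((size, dtype)))
    return rng.randint(0, len(size) - 1), rng.randint(2, 4)

assert params() == params()  # fuzz_inputs_specs and codegen agree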
@@ -3,13 +3,19 @@
 from typing import Optional

 from torchfuzz.operators.arg import ArgOperator
+from torchfuzz.operators.argsort import ArgsortOperator
 from torchfuzz.operators.base import Operator
 from torchfuzz.operators.constant import ConstantOperator
+from torchfuzz.operators.gather import GatherOperator
+from torchfuzz.operators.index_select import IndexSelectOperator
 from torchfuzz.operators.item import ItemOperator
 from torchfuzz.operators.layout import (
+    CatOperator,
+    ChunkOperator,
     FlattenOperator,
     ReshapeOperator,
     SqueezeOperator,
+    StackOperator,
     UnsqueezeOperator,
     ViewOperator,
 )
@@ -48,6 +54,8 @@ from torchfuzz.operators.scalar_pointwise import (
 )
 from torchfuzz.operators.tensor_pointwise import (
     AddOperator,
+    ClampOperator,
+    CumsumOperator,
     DivOperator,
     MulOperator,
     SubOperator,
@@ -70,6 +78,8 @@ class OperatorRegistry:
         self.register(MulOperator())
         self.register(SubOperator())
         self.register(DivOperator())
+        self.register(ClampOperator())
+        self.register(CumsumOperator())

         # Individual scalar pointwise operators (preferred)
         self.register(ScalarAddOperator())
@@ -84,6 +94,9 @@ class OperatorRegistry:
         # # Data-dependent operators
         self.register(NonzeroOperator())
         self.register(MaskedSelectOperator())
+        self.register(GatherOperator())
+        self.register(IndexSelectOperator())
+        self.register(ArgsortOperator())
         self.register(ItemOperator())
         self.register(UniqueOperator())

@@ -93,6 +106,9 @@ class OperatorRegistry:
         self.register(FlattenOperator())
         self.register(SqueezeOperator())
         self.register(UnsqueezeOperator())
+        self.register(CatOperator())
+        self.register(StackOperator())
+        self.register(ChunkOperator())

         # Matrix multiplication operators
         self.register(MMOperator())
@@ -117,3 +117,127 @@ class DivOperator(PointwiseOperator):
     @property
     def torch_op_name(self) -> str:
         return "torch.div"
+
+
+class ClampOperator(Operator):
+    """Operator for torch.clamp (element-wise clamping)."""
+
+    def __init__(self):
+        super().__init__("clamp")
+
+    @property
+    def torch_op_name(self) -> str:
+        return "torch.clamp"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Clamp can produce tensors but not scalars."""
+        if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
+            return False
+        return isinstance(output_spec, TensorSpec)
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input specs for clamp operation.
+
+        Clamp takes:
+        - input tensor (same shape and dtype as output)
+        - optional min value (scalar or None)
+        - optional max value (scalar or None)
+        """
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("ClampOperator can only produce TensorSpec outputs")
+
+        return [
+            TensorSpec(
+                size=output_spec.size,
+                stride=output_spec.stride,
+                dtype=output_spec.dtype,
+            )
+        ]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for clamp operation."""
+        if len(input_names) != 1:
+            raise ValueError("Clamp requires exactly 1 input tensor")
+
+        input_name = input_names[0]
+
+        # Generate random min and max values for clamping
+        # We'll randomly decide whether to use min, max, or both
+        use_min = random.random() > 0.33
+        use_max = random.random() > 0.33
+
+        if not use_min and not use_max:
+            use_min = True
+
+        args = [input_name]
+        if use_min:
+            args.append("min=-1.0")
+        else:
+            args.append("min=None")
+
+        if use_max:
+            args.append("max=1.0")
+        else:
+            args.append("max=None")
+
+        return f"{output_name} = torch.clamp({', '.join(args)})"
+
+
+class CumsumOperator(Operator):
+    """Operator for torch.cumsum (cumulative sum along a dimension)."""
+
+    def __init__(self):
+        super().__init__("cumsum")
+
+    @property
+    def torch_op_name(self) -> str:
+        return "torch.cumsum"
+
+    def can_produce(self, output_spec: Spec) -> bool:
+        """Cumsum can produce tensors but not scalars."""
+        if isinstance(output_spec, TensorSpec) and output_spec.dtype == torch.bool:
+            return False
+        # Cumsum needs at least 1 dimension
+        if isinstance(output_spec, TensorSpec) and len(output_spec.size) == 0:
+            return False
+        return isinstance(output_spec, TensorSpec)
+
+    def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
+        """Generate input specs for cumsum operation.
+
+        Cumsum takes an input tensor with same shape and dtype as output.
+        """
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("CumsumOperator can only produce TensorSpec outputs")
+
+        return [
+            TensorSpec(
+                size=output_spec.size,
+                stride=output_spec.stride,
+                dtype=output_spec.dtype,
+            )
+        ]
+
+    def codegen(
+        self, output_name: str, input_names: list[str], output_spec: Spec
+    ) -> str:
+        """Generate code for cumsum operation."""
+        if len(input_names) != 1:
+            raise ValueError("Cumsum requires exactly 1 input tensor")
+
+        if not isinstance(output_spec, TensorSpec):
+            raise ValueError("Output spec must be a TensorSpec")
+
+        input_name = input_names[0]
+
+        # Choose a random valid dimension
+        num_dims = len(output_spec.size)
+        if num_dims == 0:
+            raise ValueError("Cumsum requires tensor with at least 1 dimension")
+
+        # Pick a random dimension index
+        dim = random.randint(0, num_dims - 1)
+
+        return f"{output_name} = torch.cumsum({input_name}, dim={dim})"
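A brief sketch (not part of the diff) of what the generated clamp/cumsum lines evaluate to; codegen picks min/max and dim at random, fixed here for illustration:

import torch

t0 = torch.randn(2, 5)
t1 = torch.clamp(t0, min=-1.0, max=None)  # one-sided clamp is valid
t2 = torch.cumsum(t1, dim=1)              # running sum along dim 1
assert t2.shape == t0.shape
assert t2[:, -1].allclose(t1.sum(dim=1))  # last cumulative entry equals the row sum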