pytorch/tools/experimental/torchfuzz/operators/nonzero.py

81 lines
3.0 KiB
Python

"""Nonzero operator implementation."""
from typing import Optional
import torch
from torchfuzz.operators.base import Operator
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
class NonzeroOperator(Operator):
    """Operator for finding nonzero elements in a tensor.

    ``torch.nonzero`` has a data-dependent output shape ``(n_nonzero, n_dims)``.
    Instead of relying on the content of the incoming tensor, the generated
    code builds its own boolean tensor with a known number of ``True``
    entries so the result has exactly the requested shape.
    """

    def __init__(self):
        """Register this operator under the name ``"nonzero"``."""
        super().__init__("nonzero")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nonzero"

    def can_produce(self, output_spec: Spec) -> bool:
        """Return whether ``output_spec`` is a shape nonzero can be made to emit.

        Nonzero produces a tensor with shape (n_nonzero, n_dims). Inputs can be
        synthesized deterministically for a 2D int64 output shape (k, d) by
        constructing a d-dimensional input with exactly k non-zero elements.

        The one exception is d == 0: a 0-dim input holds a single element, so
        it can contribute at most one row. Shapes (k, 0) with k > 1 are
        therefore rejected.
        """
        if not isinstance(output_spec, TensorSpec):
            return False
        # torch.long is an alias of torch.int64, so one comparison suffices.
        if output_spec.dtype != torch.int64 or len(output_spec.size) != 2:
            return False
        k, d = output_spec.size
        return d > 0 or k <= 1

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Generate the input spec for the nonzero operation.

        The returned spec is a placeholder: codegen only reads the input's
        device and synthesizes the tensor that is actually passed to
        ``torch.nonzero``.

        Args:
            output_spec: Target 2D ``TensorSpec`` of shape (k, d).
            num_inputs: Unused; nonzero always takes exactly one input.

        Returns:
            A single-element list holding a bool ``TensorSpec`` with ``d``
            dimensions of size 1 (0-dim when ``d`` is 0).

        Raises:
            ValueError: If ``output_spec`` is not a 2D ``TensorSpec``.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("NonzeroOperator can only produce TensorSpec outputs")
        # Mirror codegen's validation so a malformed spec fails with a clear
        # message instead of an IndexError on size[1].
        if len(output_spec.size) != 2:
            raise ValueError("NonzeroOperator requires 2D TensorSpec output")

        d = output_spec.size[1]
        # (1,) * d is the empty tuple for d == 0, i.e. a 0-dim placeholder.
        unit_dims = (1,) * d
        return [TensorSpec(size=unit_dims, stride=unit_dims, dtype=torch.bool)]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for nonzero using a synthesized input.

        No data-dependent conditionals/guards are emitted. The generated code
        allocates a bool tensor of shape (k, 1, ..., 1) with d dimensions
        (0-dim when d == 0) on the same device as the input, marks its first k
        flattened elements ``True``, and calls ``torch.nonzero`` on it, where
        k = output_spec.size[0] and d = output_spec.size[1].

        Args:
            output_name: Variable name the result is assigned to.
            input_names: Exactly one input variable name; only its ``.device``
                is used by the generated code.
            output_spec: Target 2D ``TensorSpec`` of shape (k, d).

        Returns:
            The generated source lines joined by newlines.

        Raises:
            ValueError: If there is not exactly one input, if ``output_spec``
                is not a 2D ``TensorSpec``, or if the shape is (k, 0) with
                k > 1, which ``torch.nonzero`` cannot produce.
        """
        if len(input_names) != 1:
            raise ValueError("NonzeroOperator requires exactly one input")
        if not isinstance(output_spec, TensorSpec) or len(output_spec.size) != 2:
            raise ValueError("NonzeroOperator requires 2D TensorSpec output")

        k, d = output_spec.size
        if d == 0 and k > 1:
            # A 0-dim tensor has one element, so nonzero yields at most 1 row.
            raise ValueError(
                "NonzeroOperator cannot produce more than one row for a 0-dim input"
            )

        # Construct a concrete shape literal like (k, 1, 1, ...); a single
        # dimension needs the trailing comma to remain a tuple literal.
        if d > 0:
            dims = [str(k)] + ["1"] * (d - 1)
            shape_literal = "(" + ", ".join(dims) + ("," if d == 1 else "") + ")"
        else:
            shape_literal = "()"

        # The in-place write through _x_nz_flat relies on reshape(-1) returning
        # a view of the freshly allocated (contiguous) zeros tensor.
        return (
            f"_x_nz = torch.zeros({shape_literal}, dtype=torch.bool, device={input_names[0]}.device)\n"
            f"_x_nz_flat = _x_nz.reshape(-1)\n"
            f"_x_nz_flat[:{k}] = True\n"
            f"{output_name} = torch.nonzero(_x_nz)"
        )