mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164716 Approved by: https://github.com/pianpwk ghstack dependencies: #164432, #164434, #164514, #164646, #164647, #164649, #164687, #164688, #164693, #164694, #164715
81 lines
3.0 KiB
Python
81 lines
3.0 KiB
Python
"""Nonzero operator implementation."""
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from torchfuzz.operators.base import Operator
|
|
from torchfuzz.tensor_fuzzer import Spec, TensorSpec
|
|
|
|
|
|
class NonzeroOperator(Operator):
    """Operator for finding nonzero elements in a tensor.

    ``torch.nonzero`` returns a tensor of shape ``(n_nonzero, n_dims)``, so its
    output shape normally depends on the input data. To hit a requested output
    shape deterministically, the generated code does not read the contents of
    the operator's input (only its ``device``); it synthesizes a boolean tensor
    with the required rank and number of ``True`` entries instead.
    """

    def __init__(self):
        """Register the operator under the name ``"nonzero"``."""
        super().__init__("nonzero")

    @property
    def torch_op_name(self) -> Optional[str]:
        """Return the torch operation name."""
        return "torch.nonzero"

    @staticmethod
    def _shape_is_realizable(size) -> bool:
        """Return whether ``torch.nonzero`` can yield an output of shape *size*.

        The output must be 2D ``(k, d)``. When ``d == 0`` the input is a 0-dim
        tensor holding a single element, so at most one nonzero entry exists
        and only ``k`` in ``{0, 1}`` is reachable.
        """
        if len(size) != 2:
            return False
        k, d = size
        return d > 0 or k <= 1

    def can_produce(self, output_spec: Spec) -> bool:
        """Return whether this operator can produce ``output_spec``.

        Nonzero produces an int64 tensor with shape ``(n_nonzero, n_dims)``.
        Inputs can be synthesized deterministically (no data-dependent guards)
        for any 2D int64 output shape ``(k, d)`` with ``d > 0`` by constructing
        an input with exactly ``k`` nonzero elements and ``d`` dimensions. For
        ``d == 0`` only ``k <= 1`` is possible, because a 0-dim tensor has a
        single element.
        """
        return (
            isinstance(output_spec, TensorSpec)
            # torch.long is an alias of torch.int64, so one comparison suffices.
            and output_spec.dtype == torch.int64
            and self._shape_is_realizable(output_spec.size)
        )

    def fuzz_inputs_specs(self, output_spec: Spec, num_inputs: int = 1) -> list[Spec]:
        """Generate the input spec for the nonzero operation.

        The returned spec is a placeholder: a bool tensor with
        ``d = output_spec.size[1]`` dimensions, each of size 1. The actual
        values are synthesized in codegen to achieve the target size.

        Args:
            output_spec: Requested output; must be a 2D ``TensorSpec``.
            num_inputs: Unused; the operator always takes exactly one input.

        Returns:
            A single-element list holding the placeholder input spec.

        Raises:
            ValueError: If ``output_spec`` is not a ``TensorSpec`` or is not 2D.
        """
        if not isinstance(output_spec, TensorSpec):
            raise ValueError("NonzeroOperator can only produce TensorSpec outputs")
        # Validate the rank up front (as codegen does) instead of letting
        # size[1] fail with an unrelated IndexError.
        if len(output_spec.size) != 2:
            raise ValueError("NonzeroOperator requires 2D TensorSpec output")

        # Provide a placeholder spec; codegen will ignore the actual input content
        # and synthesize a tensor with desired nonzero count and dimensionality.
        d = output_spec.size[1]
        ones = (1,) * d  # () when d == 0
        return [TensorSpec(size=ones, stride=ones, dtype=torch.bool)]

    def codegen(
        self, output_name: str, input_names: list[str], output_spec: Spec
    ) -> str:
        """Generate code for nonzero using a synthesized input to match the target size.

        No data-dependent conditionals/guards. The emitted code constructs a
        bool tensor of shape ``(k, 1, ..., 1)`` with ``d`` dimensions
        (``k = output_spec.size[0]``, ``d = output_spec.size[1]``), sets its
        first ``k`` elements to ``True``, then calls ``torch.nonzero`` on it.

        Args:
            output_name: Variable name the generated code assigns the result to.
            input_names: Names of the input variables; exactly one is required
                and only its ``.device`` is used.
            output_spec: Requested output; must be a 2D ``TensorSpec``.

        Returns:
            The generated source code as a newline-separated string.

        Raises:
            ValueError: If the number of inputs is not one, if ``output_spec``
                is not a 2D ``TensorSpec``, or if the requested shape cannot be
                produced by ``torch.nonzero`` (``(k, 0)`` with ``k > 1``).
        """
        if len(input_names) != 1:
            raise ValueError("NonzeroOperator requires exactly one input")
        if not isinstance(output_spec, TensorSpec) or len(output_spec.size) != 2:
            raise ValueError("NonzeroOperator requires 2D TensorSpec output")

        k, d = output_spec.size
        if not self._shape_is_realizable(output_spec.size):
            # A 0-dim input has one element, so the code below would silently
            # yield a (1, 0) result instead of the requested (k, 0).
            raise ValueError(
                "NonzeroOperator cannot produce output shape (k, 0) with k > 1"
            )

        # Construct concrete shape literal like (k, 1, 1, ...)
        if d > 0:
            dims = [str(k)] + ["1"] * (d - 1)
            # A one-element tuple literal needs a trailing comma: "(k,)".
            shape_literal = "(" + ", ".join(dims) + ("," if d == 1 else "") + ")"
        else:
            shape_literal = "()"

        lines = [
            f"_x_nz = torch.zeros({shape_literal}, dtype=torch.bool, device={input_names[0]}.device)",
            # Flattening the freshly allocated tensor gives a view, so the
            # in-place write on the next line lands in _x_nz itself.
            "_x_nz_flat = _x_nz.reshape(-1)",
            f"_x_nz_flat[:{k}] = True",
            f"{output_name} = torch.nonzero(_x_nz)",
        ]
        return "\n".join(lines)
|