mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
```
import torch
torch._dynamo.config.capture_scalar_outputs = True
torch.manual_seed(42)
def fuzzed_program(arg_0, arg_1, arg_2):
var_node_3 = arg_0 # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_4 = torch.full((1,), (-0.29262632146522655-0.7687848816195035j), dtype=torch.complex128) # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_2 = torch.ops.aten.add(var_node_3, var_node_4) # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_6 = arg_1 # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_7 = arg_2 # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_5 = torch.ops.aten.add(var_node_6, var_node_7) # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_1 = torch.ops.aten.add(var_node_2, var_node_5) # size=(1,), stride=(1,), dtype=complex128, device=cuda
var_node_0 = var_node_1.item() # dtype=complex128
return var_node_0
arg_0 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))
arg_1 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))
arg_2 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))
args = (arg_0, arg_1, arg_2)
result_original = fuzzed_program(*args)
print('✅ eager success')
compiled_program = torch.compile(fuzzed_program, fullgraph=False, dynamic=True)
result_compiled = compiled_program(*args)
print('✅ compile success')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163812
Approved by: https://github.com/pianpwk
ghstack dependencies: #163743
242 lines
8.3 KiB
Python
# mypy: ignore-errors
|
|
import os
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from torchfuzz.operators import get_operator
|
|
from torchfuzz.ops_fuzzer import OperationGraph
|
|
from torchfuzz.tensor_descriptor import format_tensor_descriptor
|
|
from torchfuzz.tensor_fuzzer import ScalarSpec, Spec, TensorSpec
|
|
|
|
|
|
def convert_graph_to_python_code(
    operation_graph: OperationGraph, seed: Optional[int] = None
) -> str:
    """
    Convert an operation graph to executable Python code using topological ordering.

    The graph-based approach generates code by:
    1. Getting the topological order of nodes (dependencies before dependents)
    2. Generating code for each node in that order
    3. Properly handling input dependencies through node connections

    Args:
        operation_graph: OperationGraph instance containing the operation DAG
        seed: Random seed for reproducible code generation. If None, uses
            current random state.

    Returns:
        String containing the complete Python code that executes the operations

    Raises:
        ValueError: If the graph is empty, the topological order is
            inconsistent, or the root node was never processed.
    """
    # Set seed for reproducible code generation
    if seed is not None:
        import random

        # Offset to avoid conflicts with graph generation, which uses `seed`.
        random.seed(seed + 1000)
        torch.manual_seed(seed + 1000)

    if not operation_graph.nodes:
        raise ValueError("Empty operation graph")

    # Emit the function body (one statement per node, dependencies first).
    generated_code_lines, node_variables, arg_operations = _codegen_function_body(
        operation_graph
    )

    # The final result comes from the root node.
    root_node_id = operation_graph.root_node_id
    if root_node_id not in node_variables:
        raise ValueError(f"Root node {root_node_id} was not processed")
    final_var_name, _ = node_variables[root_node_id]

    # Function signature is derived from the arg operations discovered above.
    # With no args this degenerates to "def fuzzed_program()".
    arg_names = [f"arg_{i}" for i in range(len(arg_operations))]
    function_signature = f"def fuzzed_program({', '.join(arg_names)})"

    # Build the complete code - all imports at the top.
    code_lines = [
        "import torch",
        "torch._dynamo.config.capture_scalar_outputs = True",
        "",
    ]

    # Add single seed at the top if seed is provided.
    if seed is not None:
        code_lines.append(f"torch.manual_seed({seed})")
        code_lines.append("")

    code_lines.append(function_signature + ":")
    code_lines.extend(generated_code_lines)
    code_lines.extend(
        [
            f"    return {final_var_name}",
            "",
        ]
    )

    # Argument creation, then the eager + compiled execution driver.
    code_lines.extend(_arg_creation_lines(arg_operations))
    code_lines.extend(_execution_lines(arg_names))

    return "\n".join(code_lines)


def _codegen_function_body(operation_graph: OperationGraph):
    """Generate the indented statements of fuzzed_program's body.

    Returns a tuple ``(code_lines, node_variables, arg_operations)`` where
    ``node_variables`` maps node_id -> (var_name, output_spec) and
    ``arg_operations`` lists (node_id, output_spec) for "arg" nodes in
    discovery order (which fixes the positional argument order).

    Raises:
        ValueError: If a node's dependency has not been processed yet,
            i.e. the topological order is inconsistent.
    """
    # Topological order ensures dependencies are processed before dependents.
    topo_order = operation_graph.get_topological_order()

    code_lines: list = []
    node_variables: dict = {}  # node_id -> (var_name, spec)
    arg_operations: list = []  # (node_id, spec) for arg operations

    for node_id in topo_order:
        node = operation_graph.nodes[node_id]
        op_name = node.op_name
        output_spec = node.output_spec
        output_var_name = f"var_{node_id}"

        # Resolve input variable names; every dependency must already exist.
        input_var_names = []
        for input_node_id in node.input_nodes:
            if input_node_id not in node_variables:
                raise ValueError(
                    f"Node {node_id} depends on {input_node_id}, but {input_node_id} "
                    f"was not processed yet. Topological order may be incorrect."
                )
            input_var_names.append(node_variables[input_node_id][0])

        if op_name == "arg" or op_name.startswith("arg_"):
            # Track arg operations for later function signature generation.
            arg_operations.append((node_id, output_spec))
            arg_name = f"arg_{len(arg_operations) - 1}"
            # Add tensor descriptor comment for arg operations too.
            descriptor_comment = f"# {format_tensor_descriptor(output_spec)}"
            operation_lines = [f"{output_var_name} = {arg_name} " + descriptor_comment]
        else:
            # Generate operation execution code via the operator registry.
            operation_lines = generate_simple_operation_code(
                output_var_name, input_var_names, op_name, output_spec
            )

        # Indent for the function body and record this node's variable.
        code_lines.extend("    " + line for line in operation_lines)
        node_variables[node_id] = (output_var_name, output_spec)

    return code_lines, node_variables, arg_operations


def _storage_size(spec: TensorSpec) -> int:
    """Minimum number of storage elements backing a strided view of `spec`.

    For each dimension with extent > 1, the farthest reachable element is
    (dim_size - 1) * |stride|; the storage must hold one past that index.
    Scalars / empty sizes need a single element.
    """
    storage_size = 1
    for dim_size, stride in zip(spec.size, spec.stride):
        if dim_size > 1:
            storage_size = max(storage_size, (dim_size - 1) * abs(stride) + 1)
    return storage_size


def _arg_creation_lines(arg_operations: list) -> list:
    """Code lines that materialize each positional argument of fuzzed_program.

    Scalars become `.item()` of a random 0-d tensor; tensors are built with
    torch.as_strided over a flat random buffer large enough for the view.
    Specs of any other type are silently skipped (matching prior behavior).
    """
    lines = []
    for i, (_node_id, spec) in enumerate(arg_operations):
        arg_name = f"arg_{i}"
        if isinstance(spec, ScalarSpec):
            # spec.dtype may already stringify as "torch.<dtype>"; normalize.
            dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
            lines.append(
                f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
            )
        elif isinstance(spec, TensorSpec):
            dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
            lines.append(
                f"{arg_name} = torch.as_strided("
                f"torch.randn({_storage_size(spec)}).to({dtype_str}), "
                f"{spec.size}, {spec.stride})"
            )
    return lines


def _execution_lines(arg_names: list) -> list:
    """Driver code: pack args, run eagerly, then run under torch.compile."""
    if not arg_names:
        args_tuple = "()"
    elif len(arg_names) == 1:
        # Single element tuple needs trailing comma.
        args_tuple = f"({arg_names[0]},)"
    else:
        args_tuple = f"({', '.join(arg_names)})"

    return [
        "",
        f"args = {args_tuple}",
        "result_original = fuzzed_program(*args)",
        "print('✅ eager success')",
        "compiled_program = torch.compile(fuzzed_program, fullgraph=False, dynamic=True)",
        "result_compiled = compiled_program(*args)",
        "print('✅ compile success')",
    ]


def generate_simple_operation_code(
    output_var: str,
    input_vars: list,
    op_name: str,
    output_spec,
) -> list:
    """
    Generate code lines for executing a single operation using class-based operators.

    Args:
        output_var: Name of the output variable
        input_vars: List of input variable names
        op_name: Name of the operation
        output_spec: Output specification for the operation

    Returns:
        List of code lines (currently always a single line).
    """
    op = get_operator(op_name)

    # Guard clause: operations missing from the registry become a comment
    # so the generated program still parses.
    if op is None:
        return [f"# Unknown operation: {op_name}"]

    # Delegate codegen to the operator class and append a tensor
    # descriptor comment for readability of the generated program.
    stmt = op.codegen(output_var, input_vars, output_spec)
    return [f"{stmt} # {format_tensor_descriptor(output_spec)}"]


def create_program_file(python_code: str) -> str:
    """
    Create a temporary Python file from the generated code.

    Args:
        python_code: String containing Python code to write

    Returns:
        Path to the created temporary file (a unique ``fuzz_*.py`` under a
        ``torchfuzz`` directory inside the platform temp dir — i.e.
        ``/tmp/torchfuzz`` on POSIX, matching the previous hard-coded path).
    """
    import tempfile

    # Use the platform temp dir instead of a hard-coded "/tmp" so this also
    # works on non-POSIX systems.
    tmp_dir = os.path.join(tempfile.gettempdir(), "torchfuzz")
    os.makedirs(tmp_dir, exist_ok=True)

    # mkstemp guarantees a unique, atomically-created file. The previous
    # random.randint nonce could collide and silently overwrite an earlier
    # generated program.
    fd, generated_file_path = tempfile.mkstemp(
        prefix="fuzz_", suffix=".py", dir=tmp_dir
    )

    # Write the generated code to the freshly created file.
    with os.fdopen(fd, "w") as f:
        f.write(python_code)

    return generated_file_path