[torchfuzz] make generated code much more concise and cleaner (#163812)

```
import torch

torch._dynamo.config.capture_scalar_outputs = True
torch.manual_seed(42)

def fuzzed_program(arg_0, arg_1, arg_2):
    var_node_3 = arg_0 # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_4 = torch.full((1,), (-0.29262632146522655-0.7687848816195035j), dtype=torch.complex128) # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_2 = torch.ops.aten.add(var_node_3, var_node_4) # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_6 = arg_1 # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_7 = arg_2 # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_5 = torch.ops.aten.add(var_node_6, var_node_7) # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_1 = torch.ops.aten.add(var_node_2, var_node_5) # size=(1,), stride=(1,), dtype=complex128, device=cuda
    var_node_0 = var_node_1.item() # dtype=complex128
    return var_node_0

arg_0 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))
arg_1 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))
arg_2 = torch.as_strided(torch.randn(1).to(torch.complex128), (1,), (1,))

args = (arg_0, arg_1, arg_2)
result_original = fuzzed_program(*args)
print('✅ eager success')
compiled_program = torch.compile(fuzzed_program, fullgraph=False, dynamic=True)
result_compiled = compiled_program(*args)
print('✅ compile success')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163812
Approved by: https://github.com/pianpwk
ghstack dependencies: #163743
Author: bobrenjc93
Date: 2025-09-24 21:07:38 -07:00
Committed by: PyTorch MergeBot
Commit: b14a14a662 (parent 92f7361e27)
2 changed files with 29 additions and 50 deletions


@@ -55,11 +55,6 @@ def convert_graph_to_python_code(
         op_name = node.op_name
         output_spec = node.output_spec

-        # Generate comment for this operation
-        generated_code_lines.append(
-            f"    # Node {node_id}: {op_name} (depth {node.depth})"
-        )
-
         # Generate output variable name
         output_var_name = f"var_{node_id}"
@@ -91,7 +86,6 @@ def convert_graph_to_python_code(
         # Add proper indentation for function body
         generated_code_lines.extend(["    " + line for line in operation_lines])
-        generated_code_lines.append("")

         # Track this node's variable
         node_variables[node_id] = (output_var_name, output_spec)
@@ -110,72 +104,61 @@ def convert_graph_to_python_code(
     else:
         function_signature = "def fuzzed_program()"

-    # Build the complete code
-    fuzzer_dir = os.path.dirname(os.path.abspath(__file__))
+    # Build the complete code - all imports at the top
     code_lines = [
         "import torch",
-        "import sys",
-        "import os",
-        "# Add fuzzer directory to path so we can import tensor_fuzzer",
-        f"fuzzer_dir = r'{fuzzer_dir}'",
-        "if fuzzer_dir not in sys.path:",
-        "    sys.path.insert(0, fuzzer_dir)",
-        "from tensor_fuzzer import fuzz_scalar, fuzz_tensor_simple, ScalarSpec, TensorSpec",
+        "torch._dynamo.config.capture_scalar_outputs = True",
         "",
-        "# Generated fuzzed program code (topological order from operation graph)",
-        f"# Graph has {len(operation_graph.nodes)} nodes",
-        "",
-        function_signature + ":",
     ]

+    # Add single seed at the top if seed is provided
+    if seed is not None:
+        code_lines.append(f"torch.manual_seed({seed})")
+        code_lines.append("")
+
+    code_lines.append(function_signature + ":")
+
     # Add the generated operation code
     code_lines.extend(generated_code_lines)

     # Add return statement
     code_lines.extend(
         [
-            "    # Final result from root node",
             f"    return {final_var_name}",
             "",
         ]
     )

-    # Generate argument creation code with deterministic seeds
+    # Generate argument creation code without individual seeds
     if arg_operations:
-        code_lines.append("# Create arguments for the fuzzed program")
         for i, (node_id, spec) in enumerate(arg_operations):
             arg_name = f"arg_{i}"
-            # Use a deterministic seed based on the argument index and main seed
-            arg_seed = (seed + 10000 + i) if seed is not None else None

             if isinstance(spec, ScalarSpec):
                 dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-                if arg_seed is not None:
-                    code_lines.extend(
-                        [
-                            f"scalar_spec = ScalarSpec(dtype={dtype_str})",
-                            f"{arg_name} = fuzz_scalar(scalar_spec, seed={arg_seed})",
-                        ]
-                    )
-                else:
-                    code_lines.extend(
-                        [
-                            f"scalar_spec = ScalarSpec(dtype={dtype_str})",
-                            f"{arg_name} = fuzz_scalar(scalar_spec)",
-                        ]
-                    )
+                code_lines.append(
+                    f"{arg_name} = torch.tensor(torch.randn(()), dtype={dtype_str}).item()"
+                )
             elif isinstance(spec, TensorSpec):
                 size_str = str(spec.size)
                 stride_str = str(spec.stride)
                 dtype_str = f"torch.{spec.dtype}".replace("torch.torch.", "torch.")
-                if arg_seed is not None:
-                    code_lines.append(
-                        f"{arg_name} = fuzz_tensor_simple({size_str}, {stride_str}, {dtype_str}, seed={arg_seed})"
-                    )
-                else:
-                    code_lines.append(
-                        f"{arg_name} = fuzz_tensor_simple({size_str}, {stride_str}, {dtype_str})"
-                    )
+
+                # Calculate storage size needed for the strided tensor
+                if spec.size:
+                    storage_size = 1
+                    for dim_size, stride in zip(spec.size, spec.stride):
+                        if dim_size > 1:
+                            storage_size = max(
+                                storage_size, (dim_size - 1) * abs(stride) + 1
+                            )
+                else:
+                    storage_size = 1
+
+                code_lines.append(
+                    f"{arg_name} = torch.as_strided(torch.randn({storage_size}).to({dtype_str}), {size_str}, {stride_str})"
+                )

     # Generate the final execution with both normal and compiled versions
     if arg_operations:
@@ -191,9 +174,6 @@ def convert_graph_to_python_code(
         code_lines.extend(
             [
-                "import torch",
-                "import sys",
-                "torch._dynamo.config.capture_scalar_outputs = True",
                 "",
                 f"args = {args_tuple}",
                 "result_original = fuzzed_program(*args)",
@@ -201,7 +181,6 @@ def convert_graph_to_python_code(
                 "compiled_program = torch.compile(fuzzed_program, fullgraph=False, dynamic=True)",
                 "result_compiled = compiled_program(*args)",
                 "print('✅ compile success')",
-                "",
             ]
         )
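For readers who want to try the new argument-construction scheme outside the fuzzer, here is a minimal standalone sketch of what the emitted lines do. The helper name `storage_elems` is illustrative and not part of the PR; the formula simply mirrors the per-dimension bound computed in the hunk above.

```python
import torch

def storage_elems(size, stride):
    # Mirrors the codegen above: for each dimension, the farthest element an
    # as_strided view can reach is (dim_size - 1) * |stride|, so keep the
    # largest such reach (plus one) as the backing buffer length.
    needed = 1
    for dim_size, st in zip(size, stride):
        if dim_size > 1:
            needed = max(needed, (dim_size - 1) * abs(st) + 1)
    return needed

size, stride, dtype = (1,), (1,), torch.complex128

# Tensor argument: random buffer cast to the target dtype, then re-strided.
arg = torch.as_strided(torch.randn(storage_elems(size, stride)).to(dtype), size, stride)

# Scalar argument: a 0-d random tensor cast to the dtype, unwrapped to a Python scalar.
scalar = torch.tensor(torch.randn(()), dtype=dtype).item()

print(arg.shape, arg.stride(), type(scalar))  # torch.Size([1]) (1,) <class 'complex'>
```

This is also why the generated programs no longer need to import `tensor_fuzzer`: plain `torch.randn`, `torch.as_strided`, and `.item()` calls are enough to reproduce the arguments.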


@@ -47,7 +47,7 @@ def fuzz_and_execute(
     # Generate max_depth if not provided (range 3-12)
     if max_depth is None:
         random.seed(seed + 999)  # Use seed offset for consistent depth selection
-        max_depth = random.randint(3, 12)
+        max_depth = random.randint(2, 4)
     else:
         # Clamp max_depth to valid range
         max_depth = max(1, max_depth)
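A tiny sketch of the revised depth selection shown above (the values are only an illustration of the seeded draw):

```python
import random

seed = 42
random.seed(seed + 999)           # same seed offset used for depth selection
max_depth = random.randint(2, 4)  # previously randint(3, 12); smaller graphs, shorter programs
print(max_depth)                  # deterministic for a fixed seed
```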