diff --git a/scripts/analyze_procmod_bugs.py b/scripts/analyze_procmod_bugs.py
new file mode 100644
index 00000000..1ddf7144
--- /dev/null
+++ b/scripts/analyze_procmod_bugs.py
@@ -0,0 +1,1512 @@
+#!/usr/bin/env python3
+"""
+Analyze procedurally generated bugs and their validation results.
+
+This script analyzes the bugs generated by procedural modifications and provides
+detailed statistics about:
+- Total bugs generated and validated
+- Breakdown by modifier type
+- Validation pass rates
+- Test failure statistics
+- Distribution of bugs across modifiers
+
+Usage:
+    python scripts/analyze_procmod_bugs.py [options]
+    python scripts/analyze_procmod_bugs.py --repo <repo_id>   # Analyze single repo
+    python scripts/analyze_procmod_bugs.py                    # Analyze all repos
+
+Example:
+    python scripts/analyze_procmod_bugs.py
+    python scripts/analyze_procmod_bugs.py --repo dtolnay__anyhow.1d7ef1db
+"""
+
+import argparse
+import json
+import os
+import re
+import subprocess
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from swebench.harness.constants import FAIL_TO_PASS, LOG_REPORT, PASS_TO_PASS
+
+
+def extract_modifier_name(instance_id: str) -> str:
+    """Extract the modifier name from an instance ID.
+
+    Example: Instagram__MonkeyType.70c3acf6.func_pm_remove_assign__abc123 -> func_pm_remove_assign
+    """
+    parts = instance_id.split(".")
+    if len(parts) >= 3:
+        last_part = parts[-1]
+        if "__" in last_part:
+            return last_part.split("__")[0]
+    return "unknown"
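The `{owner}__{repo}.{commit}.{modifier}__{hash}` instance-ID convention drives all of the grouping below; a quick sanity check, using the ID from the docstring example above:

```python
# Splitting on "." isolates "{modifier}__{hash}"; splitting that on "__"
# recovers the modifier name.
iid = "Instagram__MonkeyType.70c3acf6.func_pm_remove_assign__abc123"
assert extract_modifier_name(iid) == "func_pm_remove_assign"
```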
+
+
+def extract_test_count(repo_id: str) -> int:
+    """Extract total number of unit tests from test_output.txt.
+
+    Looks for lines containing 'test result: ok. X passed' and sums up all X values.
+
+    Args:
+        repo_id: Repository identifier (e.g., Instagram__MonkeyType.70c3acf6)
+
+    Returns:
+        Total number of tests, or 0 if test_output.txt not found
+    """
+    test_output_path = (
+        Path("logs/run_validation") / repo_id / f"{repo_id}.ref" / "test_output.txt"
+    )
+
+    if not test_output_path.exists():
+        return 0
+
+    total_tests = 0
+    pattern = re.compile(r"test result: ok\. (\d+) passed")
+
+    try:
+        with open(test_output_path, "r") as f:
+            for line in f:
+                match = pattern.search(line)
+                if match:
+                    total_tests += int(match.group(1))
+    except Exception as e:
+        print(f"Warning: Could not read test output for {repo_id}: {e}")
+        return 0
+
+    return total_tests
+
+
+def analyze_bugs(repo_id: str) -> Dict[str, Any]:
+    """Analyze bugs for a given repository.
+
+    Args:
+        repo_id: Repository identifier (e.g., Instagram__MonkeyType.70c3acf6)
+
+    Returns:
+        Dictionary containing analysis results
+    """
+    bug_gen_dir = Path("logs/bug_gen") / repo_id
+    validation_dir = Path("logs/run_validation") / repo_id
+
+    if not bug_gen_dir.exists():
+        raise FileNotFoundError(f"Bug generation directory not found: {bug_gen_dir}")
+
+    generated_bugs = defaultdict(list)
+    total_generated = 0
+
+    for root, _, files in os.walk(bug_gen_dir):
+        for file in files:
+            if file.startswith("bug__") and file.endswith(".diff"):
+                total_generated += 1
+                modifier_name = file.split("bug__")[1].split("__")[0]
+                instance_id = f"{repo_id}.{file.split('bug__')[1].replace('.diff', '')}"
+                generated_bugs[modifier_name].append(instance_id)
+
+    generated_bugs_len = sum(len(v) for v in generated_bugs.values())
+    assert generated_bugs_len == total_generated
+
+    validated_bugs = defaultdict(
+        lambda: {
+            "total": 0,
+            "passed": 0,
+            "failed": 0,
+            "f2p_counts": [],
+            "p2p_counts": [],
+            "instances": [],
+        }
+    )
+
+    timeout_bugs = defaultdict(list)
+    total_timeouts = 0
+    total_validated = 0
+    total_passed = 0
+    total_failed = 0
+
+    if validation_dir.exists():
+        for instance_dir in os.listdir(validation_dir):
+            # Skip reference tests
+            if instance_dir.endswith(".ref"):
+                print(f"Skipping {instance_dir} because it is a reference test")
+                continue
+
+            instance_path = validation_dir / instance_dir
+            report_path = instance_path / LOG_REPORT
+
+            # Resolve the modifier up front so that instances with a missing
+            # report can still be attributed to a modifier in the else branch
+            # below (previously this name was only bound inside the if).
+            modifier_name = extract_modifier_name(instance_dir)
+
+            if report_path.exists():
+                with open(report_path, "r") as f:
+                    report = json.load(f)
+
+                # Exclude if report timed_out is true
+                if report.get("timed_out", False):
+                    print(f"Timeout bug from timed_out == True: {instance_dir}")
+                    timeout_bugs[modifier_name].append(instance_dir)
+                    continue
+
+                total_validated += 1
+
+                f2p_count = len(report.get(FAIL_TO_PASS, []))
+                p2p_count = len(report.get(PASS_TO_PASS, []))
+
+                validated_bugs[modifier_name]["total"] += 1
+                validated_bugs[modifier_name]["f2p_counts"].append(f2p_count)
+                validated_bugs[modifier_name]["p2p_counts"].append(p2p_count)
+                validated_bugs[modifier_name]["instances"].append(
+                    {"instance_id": instance_dir, "f2p": f2p_count, "p2p": p2p_count}
+                )
+
+                if f2p_count > 0:
+                    validated_bugs[modifier_name]["passed"] += 1
+                    total_passed += 1
+                else:
+                    validated_bugs[modifier_name]["failed"] += 1
+                    total_failed += 1
+            else:
+                print(f"Timeout bug from missing report: {instance_dir}")
+                timeout_bugs[modifier_name].append(instance_dir)
+
+    total_timeouts = total_generated - total_validated
+    # Add generated bugs that are missing from the validated folder to timeout_bugs
+    for modifier_name, bug_list in generated_bugs.items():
+        for bug_id in bug_list:
+            instance_path = validation_dir / bug_id
+            # If the bug was generated but not validated (not in validation folder)
+            if not instance_path.exists():
+                print(f"Timeout bug from missing validation folder: {bug_id}")
+                timeout_bugs[modifier_name].append(bug_id)
+
+    # Extract test count
+    test_count = extract_test_count(repo_id)
+
+    return {
+        "repo_id": repo_id,
+        "total_generated": total_generated,
+        "total_validated": total_validated,
+        "total_passed": total_passed,
+        "total_failed": total_failed,
+        "total_timeouts": total_timeouts,
+        "test_count": test_count,
+        "generated_by_modifier": {k: len(v) for k, v in generated_bugs.items()},
+        "validated_by_modifier": dict(validated_bugs),
+        "timeout_by_modifier": {k: len(v) for k, v in timeout_bugs.items()},
+    }
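`analyze_bugs` consumes only three fields of each validation `report.json`. A minimal illustrative report, with invented test names (the key names match the `swebench` constants imported at the top of the script):

```python
# Hypothetical report.json contents for one validated instance; analyze_bugs
# reads the lengths of FAIL_TO_PASS / PASS_TO_PASS and the timed_out flag.
report = {
    "FAIL_TO_PASS": ["tests::parse::broken_by_bug"],  # tests the bug breaks
    "PASS_TO_PASS": ["tests::parse::still_passing"],  # tests left intact
    "timed_out": False,
}
```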
+
+
+def print_statistics(analysis: Dict[str, Any]) -> None:
+    """Print detailed statistics from the analysis."""
+
+    print("=" * 80)
+    print(f"Bug Generation and Validation Analysis for {analysis['repo_id']}")
+    print("=" * 80)
+    print()
+
+    print("OVERALL STATISTICS")
+    print("-" * 80)
+    print(f"Total bugs generated: {analysis['total_generated']}")
+    print(f"Total bugs validated: {analysis['total_validated']}")
+    print(
+        f"Bugs that passed validation: {analysis['total_passed']} ({analysis['total_passed'] / max(analysis['total_validated'], 1) * 100:.1f}%)"
+    )
+    print(
+        f"Bugs that failed validation: {analysis['total_failed']} ({analysis['total_failed'] / max(analysis['total_validated'], 1) * 100:.1f}%)"
+    )
+    print()
+
+    print("PER-MODIFIER STATISTICS")
+    print("-" * 80)
+    print(
+        f"{'Modifier':<35} {'Generated':<12} {'Validated':<12} {'Passed':<12} {'Pass Rate':<12}"
+    )
+    print("-" * 80)
+
+    sorted_modifiers = sorted(
+        analysis["generated_by_modifier"].items(), key=lambda x: x[1], reverse=True
+    )
+
+    for modifier, generated_count in sorted_modifiers:
+        validated_data = analysis["validated_by_modifier"].get(modifier, {})
+        validated_count = validated_data.get("total", 0)
+        passed_count = validated_data.get("passed", 0)
+        pass_rate = (passed_count / max(validated_count, 1)) * 100
+
+        print(
+            f"{modifier:<35} {generated_count:<12} {validated_count:<12} {passed_count:<12} {pass_rate:>10.1f}%"
+        )
+
+    print()
+
+    print("TEST FAILURE STATISTICS")
+    print("-" * 80)
+    print(
+        f"{'Modifier':<35} {'Avg F2P':<12} {'Min F2P':<12} {'Max F2P':<12} {'Avg P2P':<12}"
+    )
+    print("-" * 80)
+
+    for modifier, generated_count in sorted_modifiers:
+        validated_data = analysis["validated_by_modifier"].get(modifier, {})
+        f2p_counts = validated_data.get("f2p_counts", [])
+        p2p_counts = validated_data.get("p2p_counts", [])
+
+        if f2p_counts:
+            avg_f2p = sum(f2p_counts) / len(f2p_counts)
+            min_f2p = min(f2p_counts)
+            max_f2p = max(f2p_counts)
+            avg_p2p = sum(p2p_counts) / len(p2p_counts)
+
+            print(
+                f"{modifier:<35} {avg_f2p:<12.2f} {min_f2p:<12} {max_f2p:<12} {avg_p2p:<12.2f}"
+            )
+
+    print()
+
+    print("DISTRIBUTION SUMMARY")
+    print("-" * 80)
+
+    all_f2p = []
+    all_p2p = []
+    for validated_data in analysis["validated_by_modifier"].values():
+        all_f2p.extend(validated_data.get("f2p_counts", []))
+        all_p2p.extend(validated_data.get("p2p_counts", []))
+
+    if all_f2p:
+        print(
+            f"Average tests broken per bug (F2P): {sum(all_f2p) / len(all_f2p):.2f}"
+        )
+        # Use np.median so even-length lists get a true median rather than
+        # the upper-middle element of the sorted list.
+        print(f"Median tests broken per bug (F2P): {np.median(all_f2p):.1f}")
+        print(f"Min tests broken per bug (F2P): {min(all_f2p)}")
+        print(f"Max tests broken per bug (F2P): {max(all_f2p)}")
+        print()
+        print(
+            f"Average tests maintained per bug (P2P): {sum(all_p2p) / len(all_p2p):.2f}"
+        )
+        print(f"Median tests maintained per bug (P2P): {np.median(all_p2p):.1f}")
+
+    print()
+    print("=" * 80)
+
+
+def save_report(analysis: Dict[str, Any], output_file: str) -> None:
+    """Save the analysis report to a JSON file."""
+    with open(output_file, "w") as f:
+        json.dump(analysis, f, indent=2)
+    print(f"Detailed report saved to: {output_file}")
+
+
+def plot_bug_distribution(
+    analysis: Dict[str, Any],
+    output_path: str,
+    show_generated_bugs: bool = False,
+    show_timeout_bugs: bool = False,
+) -> None:
+    """Plot bar chart of bug distribution by modifier type.
+ + Args: + analysis: Analysis results dictionary + output_path: Path to save the plot + show_generated_bugs: Whether to show generated bugs bar + show_timeout_bugs: Whether to show timeout bugs bar stacked on validated + """ + # Extract data + generated_by_modifier = analysis.get("generated_by_modifier", {}) + validated_by_modifier = analysis.get("validated_by_modifier", {}) + timeout_by_modifier = analysis.get("timeout_by_modifier", {}) + + # If it's aggregate data, handle differently + if "aggregate_statistics" in analysis: + modifier_data = analysis["aggregate_statistics"]["by_modifier"] + generated_by_modifier = {k: v["generated"] for k, v in modifier_data.items()} + validated_by_modifier = { + k: {"total": v["validated"], "passed": v["passed"]} + for k, v in modifier_data.items() + } + timeout_by_modifier = {k: v.get("timeout", 0) for k, v in modifier_data.items()} + + if not generated_by_modifier: + print("No data to plot") + return + + # Sort modifiers by generated count (descending) + sorted_modifiers = sorted( + generated_by_modifier.items(), key=lambda x: x[1], reverse=True + ) + + modifier_keys = [m[0] for m in sorted_modifiers] + modifiers_display = [m[0].replace("func_pm_", "") for m in sorted_modifiers] + generated_counts = [m[1] for m in sorted_modifiers] + + # Get validated, passed, and timeout counts for each modifier + validated_counts = [] + passed_counts = [] + timeout_counts = [] + for modifier_key in modifier_keys: + if modifier_key in validated_by_modifier: + if isinstance(validated_by_modifier[modifier_key], dict): + validated_counts.append( + validated_by_modifier[modifier_key].get("total", 0) + ) + passed_counts.append( + validated_by_modifier[modifier_key].get("passed", 0) + ) + else: + validated_counts.append(validated_by_modifier[modifier_key]) + passed_counts.append(0) + else: + validated_counts.append(0) + passed_counts.append(0) + + # Get timeout count for this modifier + timeout_counts.append(timeout_by_modifier.get(modifier_key, 0)) + + # Filter out modifiers with zero passed bugs + filtered_data = [ + (mod, gen, val, pas, tim) + for mod, gen, val, pas, tim in zip( + modifiers_display, + generated_counts, + validated_counts, + passed_counts, + timeout_counts, + ) + if pas > 0 + ] + + if not filtered_data: + print("No modifiers with passed bugs to plot") + return + + # Unpack filtered data + ( + modifiers_display, + generated_counts, + validated_counts, + passed_counts, + timeout_counts, + ) = zip(*filtered_data) + + # Create figure and axis + fig, ax = plt.subplots(figsize=(14, 8.8)) + + # Set positions for bars + x = np.arange(len(modifiers_display)) + width = 0.6 + + # Create overlaid bars (drawn from back to front) + if show_generated_bugs: + # Back: Generated (lightgray - 10% darker than whitesmoke) + bars0 = ax.bar( + x, + generated_counts, + width, + label="Generated", + color="lightgray", + edgecolor="none", + zorder=1, + ) + # Middle: Validated (gray - 10% darker than silver) + bars1 = ax.bar( + x, + validated_counts, + width, + label="Validated", + color="gray", + edgecolor="none", + zorder=2, + ) + # Front: Passed (black) - overlay on validated + bars2 = ax.bar( + x, + passed_counts, + width, + label="Passed", + color="black", + edgecolor="none", + zorder=3, + ) + # Timeout bars stacked on top of validated (dotted pattern) + if show_timeout_bugs: + bars3 = ax.bar( + x, + timeout_counts, + width, + bottom=validated_counts, + label="Timeout", + color="gray", + edgecolor="black", + linewidth=0, + hatch="...", + zorder=4, + ) + else: + # Back: 
Validated (light grey) + bars1 = ax.bar( + x, + validated_counts, + width, + label="Validated", + color="lightgrey", + edgecolor="none", + zorder=1, + ) + # Front: Passed (black) - overlay on validated + bars2 = ax.bar( + x, + passed_counts, + width, + label="Passed", + color="black", + edgecolor="none", + zorder=2, + ) + # Timeout bars stacked on top of validated (dotted pattern) + if show_timeout_bugs: + bars3 = ax.bar( + x, + timeout_counts, + width, + bottom=validated_counts, + label="Timeout", + color="lightgrey", + edgecolor="black", + linewidth=0, + hatch="...", + zorder=3, + ) + + # Customize plot + ax.set_xlabel("Modifier Type", fontsize=22, fontweight="bold") + ax.set_ylabel("Number of Bugs", fontsize=22, fontweight="bold") + ax.set_title( + "Bug Distribution by Modifier Type", fontsize=24, fontweight="bold", pad=20 + ) + ax.set_xticks(x) + ax.set_xticklabels(modifiers_display, rotation=45, ha="right", fontsize=20) + ax.tick_params(axis="y", labelsize=20) + ax.legend(fontsize=20, loc="upper right") + ax.grid(axis="y", alpha=0.3, linestyle="--") + + # Add value labels on bars + if show_generated_bugs: + for i, (gen, val, pas, tim) in enumerate( + zip(generated_counts, validated_counts, passed_counts, timeout_counts) + ): + # Label for generated (at the top of generated bar) + ax.text( + x[i], + gen, + f"{int(gen)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="dimgrey", + ) + # Label for validated (at the top of validated bar) + if not show_timeout_bugs: + ax.text( + x[i], + val, + f"{int(val)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="dimgrey", + ) + # Label for passed (at the top of passed bar) + ax.text( + x[i], + pas, + f"{int(pas)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="white", + ) + # Label for timeout (at the top of timeout bar) + if show_timeout_bugs and tim > 0: + ax.text( + x[i], + val + tim, + f"{int(gen)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="dimgrey", + ) + else: + for i, (gen, val, pas, tim) in enumerate( + zip(generated_counts, validated_counts, passed_counts, timeout_counts) + ): + # Label for validated (at the top of validated bar) + if not show_timeout_bugs: + ax.text( + x[i], + val, + f"{int(val)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="dimgrey", + ) + # Label for passed (at the top of passed bar) + ax.text( + x[i], + pas, + f"{int(pas)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="black", + ) + # Label for timeout (at the top of timeout bar) + if show_timeout_bugs and tim > 0: + ax.text( + x[i], + val + tim, + f"{int(gen)}", + ha="center", + va="bottom", + fontsize=16, + fontweight="bold", + color="dimgrey", + ) + + # Tight layout to prevent label cutoff + plt.tight_layout() + + # Ensure output directory exists + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"Bug distribution plot saved to: {output_path}") + + +def plot_per_repo_distribution( + all_analyses: list[Dict[str, Any]], output_path: str, show_repo_owner: bool = False +) -> None: + """Plot per-repo breakdown of validated, passed, and timeout bugs. 
+ + Args: + all_analyses: List of analysis results for each repo + output_path: Path to save the plot + show_repo_owner: Whether to show repo owner in labels + """ + if not all_analyses: + print("No data to plot") + return + + # Extract data per repo + repos = [a["repo_id"] for a in all_analyses] + validated = [a["total_validated"] for a in all_analyses] + passed = [a["total_passed"] for a in all_analyses] + timeout = [a.get("total_timeouts", 0) for a in all_analyses] + + # Truncate commit_id from repo names (remove part after last dot) + repos_display = [r.rsplit(".", 1)[0] for r in repos] + + # Replace __ with / and optionally hide owner + if show_repo_owner: + repos_display = [r.replace("__", "/") for r in repos_display] + else: + # Hide owner (everything before and including __) + repos_display = [ + r.split("__", 1)[-1] if "__" in r else r for r in repos_display + ] + + # Create figure + fig, ax = plt.subplots(figsize=(16, 10)) + + x = np.arange(len(repos)) + width = 0.25 + + # Create grouped bars + ax.bar(x - width, validated, width, label="Validated", color="lightgrey") + ax.bar(x, passed, width, label="Passed", color="black") + ax.bar(x + width, timeout, width, label="Timeout", color="lightgrey", hatch="...") + + # Customize plot + ax.set_xlabel("Repository", fontsize=22, fontweight="bold") + ax.set_ylabel("Number of Bugs", fontsize=22, fontweight="bold") + ax.set_title( + "Per-Repository Bug Distribution", fontsize=24, fontweight="bold", pad=20 + ) + ax.set_xticks(x) + ax.set_xticklabels(repos_display, rotation=45, ha="right", fontsize=14) + ax.tick_params(axis="y", labelsize=20) + ax.legend(fontsize=20, loc="upper right") + ax.grid(axis="y", alpha=0.3, linestyle="--") + + plt.tight_layout() + + # Ensure output directory exists + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"Per-repo distribution plot saved to: {output_path}") + + +def get_repo_info(owner: str, repo: str) -> dict: + """Get repository info from GitHub API using curl. + + Args: + owner: Repository owner + repo: Repository name + + Returns: + Dictionary with 'size' (in KB) and 'stars' (stargazers_count), or empty dict if request fails + """ + url = f"https://api.github.com/repos/{owner}/{repo}" + + # Build curl command with authentication if GITHUB_TOKEN is available + curl_cmd = ["curl", "-s"] + + # Check for GITHUB_TOKEN environment variable + github_token = os.environ.get("GITHUB_TOKEN") + if github_token: + curl_cmd.extend(["-H", f"Authorization: Bearer {github_token}"]) + + curl_cmd.extend(["-H", "Accept: application/vnd.github+json"]) + curl_cmd.append(url) + + try: + result = subprocess.run(curl_cmd, capture_output=True, text=True, timeout=10) + + if result.returncode == 0 and result.stdout.strip(): + try: + data = json.loads(result.stdout) + + # Check for GitHub API errors (rate limit, not found, etc.) + if "message" in data: + if "rate limit" in data["message"].lower(): + print( + f"Warning: GitHub API rate limit exceeded. 
Message: {data['message']}" + ) + else: + print( + f"Warning: GitHub API error for {owner}/{repo}: {data['message']}" + ) + return {} + + # Successfully got data + return { + "size": data.get("size", 0), + "stars": data.get("stargazers_count", 0), + } + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse JSON response for {owner}/{repo}: {e}") + return {} + else: + print( + f"Warning: Failed to get repo info for {owner}/{repo} (curl returned {result.returncode})" + ) + return {} + except Exception as e: + print(f"Warning: Error getting repo info for {owner}/{repo}: {e}") + return {} + + +def plot_timeout_vs_tests_correlation( + all_analyses: list[Dict[str, Any]], output_path: str +) -> None: + """Plot correlation between percent of timeout bugs and total number of tests. + + Args: + all_analyses: List of analysis results for each repo + output_path: Path to save the plot + """ + if not all_analyses: + print("No data to plot") + return + + # Extract data per repo + test_counts = [] + timeout_percentages = [] + repo_names = [] + + for analysis in all_analyses: + test_count = analysis.get("test_count", 0) + total_generated = analysis.get("total_generated", 0) + total_timeouts = analysis.get("total_timeouts", 0) + + # Skip repos with no tests or no bugs + if test_count == 0 or total_generated == 0: + continue + + timeout_pct = (total_timeouts / total_generated) * 100 + test_counts.append(test_count) + timeout_percentages.append(timeout_pct) + + # Truncate commit_id from repo name + repo_name = analysis["repo_id"].rsplit(".", 1)[0] + # Hide owner (everything before and including __) + repo_name = repo_name.split("__", 1)[-1] if "__" in repo_name else repo_name + repo_names.append(repo_name) + + if not test_counts: + print("No data with valid test counts to plot") + return + + # Convert to numpy arrays for easier manipulation + test_counts = np.array(test_counts) + timeout_percentages = np.array(timeout_percentages) + repo_names = np.array(repo_names) + + # Identify outliers using IQR method on both axes + q1_x, q3_x = np.percentile(test_counts, [25, 75]) + iqr_x = q3_x - q1_x + lower_x, upper_x = q1_x - 1.5 * iqr_x, q3_x + 1.5 * iqr_x + + q1_y, q3_y = np.percentile(timeout_percentages, [25, 75]) + iqr_y = q3_y - q1_y + lower_y, upper_y = q1_y - 1.5 * iqr_y, q3_y + 1.5 * iqr_y + + # Create mask for non-outliers + mask_x = (test_counts >= lower_x) & (test_counts <= upper_x) + mask_y = (timeout_percentages >= lower_y) & (timeout_percentages <= upper_y) + mask = mask_x & mask_y + + test_counts_filtered = test_counts[mask] + timeout_percentages_filtered = timeout_percentages[mask] + outliers_x = test_counts[~mask] + outliers_y = timeout_percentages[~mask] + + # Create figure + fig, ax = plt.subplots(figsize=(12, 8)) + + # Scatter plot for non-outliers only + ax.scatter( + test_counts_filtered, + timeout_percentages_filtered, + alpha=0.6, + s=100, + color="black", + label="Data", + ) + + # Print outliers if any (but don't plot them) + if len(outliers_x) > 0: + print(f"\nExcluded {len(outliers_x)} outlier(s) from correlation analysis:") + outlier_repos = repo_names[~mask] + for i, (repo, tests, timeout_pct) in enumerate( + zip(outlier_repos, outliers_x, outliers_y) + ): + print(f" {i + 1}. 
{repo}: {int(tests)} tests, {timeout_pct:.1f}% timeout") + print() + + # Add linear regression line using filtered data + if len(test_counts_filtered) > 1: + z = np.polyfit(test_counts_filtered, timeout_percentages_filtered, 1) + p = np.poly1d(z) + x_line = np.linspace(min(test_counts_filtered), max(test_counts_filtered), 100) + ax.plot( + x_line, + p(x_line), + "r--", + alpha=0.8, + linewidth=2, + label=f"y={z[0]:.4f}x+{z[1]:.2f}", + ) + + # Calculate correlation coefficient on filtered data + correlation = np.corrcoef(test_counts_filtered, timeout_percentages_filtered)[ + 0, 1 + ] + ax.text( + 0.05, + 0.95, + f"Correlation: {correlation:.3f}", + transform=ax.transAxes, + fontsize=16, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + ) + + # Customize plot + ax.set_xlabel("Total Number of Unit Tests", fontsize=18, fontweight="bold") + ax.set_ylabel("Timeout Bugs (%)", fontsize=18, fontweight="bold") + ax.set_title( + "Correlation: Timeout Bugs vs Number of Tests", + fontsize=20, + fontweight="bold", + pad=20, + ) + ax.tick_params(axis="both", labelsize=14) + ax.grid(alpha=0.3, linestyle="--") + ax.legend(fontsize=14, loc="upper right") + + plt.tight_layout() + + # Ensure output directory exists + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"Timeout vs tests correlation plot saved to: {output_path}") + + +def plot_num_tests_repo_size_correlation( + all_analyses: list[Dict[str, Any]], output_path: str +) -> None: + """Plot correlation between number of tests and repository size. + + Args: + all_analyses: List of analysis results for each repo + output_path: Path to save the plot + """ + if not all_analyses: + print("No data to plot") + return + + # Extract data per repo + test_counts = [] + repo_sizes = [] + repo_names = [] + + github_token = os.environ.get("GITHUB_TOKEN") + if github_token: + print("\nFetching repository sizes from GitHub API (authenticated)...") + else: + print( + "\nFetching repository sizes from GitHub API (unauthenticated - rate limited to 60 requests/hour)..." 
+ ) + print( + "Tip: Set GITHUB_TOKEN environment variable to increase rate limit to 5000 requests/hour" + ) + + for analysis in all_analyses: + test_count = analysis.get("test_count", 0) + repo_id = analysis["repo_id"] + + # Skip repos with no tests + if test_count == 0: + continue + + # Parse repo_id to extract owner and repo name + # Format: owner__repo.commit_hash + repo_full = repo_id.rsplit(".", 1)[0] # Remove commit hash + if "__" in repo_full: + owner, repo = repo_full.split("__", 1) + + # Get repo info from GitHub API + repo_info = get_repo_info(owner, repo) + repo_size = repo_info.get("size", 0) + + if repo_size > 0: + test_counts.append(test_count) + repo_sizes.append(repo_size) + repo_names.append(repo) + + # Small delay to avoid rate limiting + time.sleep(0.1) + + if not test_counts: + print("No data with valid test counts and repo sizes to plot") + return + + print(f"Successfully fetched sizes for {len(test_counts)} repositories\n") + + # Convert to numpy arrays for easier manipulation + test_counts = np.array(test_counts) + repo_sizes = np.array(repo_sizes) + repo_names = np.array(repo_names) + + # Identify outliers using IQR method on both axes + q1_x, q3_x = np.percentile(repo_sizes, [25, 75]) + iqr_x = q3_x - q1_x + lower_x, upper_x = q1_x - 1.5 * iqr_x, q3_x + 1.5 * iqr_x + + q1_y, q3_y = np.percentile(test_counts, [25, 75]) + iqr_y = q3_y - q1_y + lower_y, upper_y = q1_y - 1.5 * iqr_y, q3_y + 1.5 * iqr_y + + # Create mask for non-outliers + mask_x = (repo_sizes >= lower_x) & (repo_sizes <= upper_x) + mask_y = (test_counts >= lower_y) & (test_counts <= upper_y) + mask = mask_x & mask_y + + repo_sizes_filtered = repo_sizes[mask] + test_counts_filtered = test_counts[mask] + outliers_x = repo_sizes[~mask] + outliers_y = test_counts[~mask] + + # Create figure + fig, ax = plt.subplots(figsize=(12, 8)) + + # Scatter plot for non-outliers only + ax.scatter( + repo_sizes_filtered, + test_counts_filtered, + alpha=0.6, + s=100, + color="black", + label="Data", + ) + + # Print outliers if any (but don't plot them) + if len(outliers_x) > 0: + print(f"Excluded {len(outliers_x)} outlier(s) from correlation analysis:") + outlier_repos = repo_names[~mask] + for i, (repo, size, tests) in enumerate( + zip(outlier_repos, outliers_x, outliers_y) + ): + print(f" {i + 1}. 
{repo}: {int(size)} KB, {int(tests)} tests") + print() + + # Add linear regression line using filtered data + if len(repo_sizes_filtered) > 1: + z = np.polyfit(repo_sizes_filtered, test_counts_filtered, 1) + p = np.poly1d(z) + x_line = np.linspace(min(repo_sizes_filtered), max(repo_sizes_filtered), 100) + ax.plot( + x_line, + p(x_line), + "r--", + alpha=0.8, + linewidth=2, + label=f"y={z[0]:.4f}x+{z[1]:.2f}", + ) + + # Calculate correlation coefficient on filtered data + correlation = np.corrcoef(repo_sizes_filtered, test_counts_filtered)[0, 1] + ax.text( + 0.05, + 0.95, + f"Correlation: {correlation:.3f}", + transform=ax.transAxes, + fontsize=16, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + ) + + # Customize plot + ax.set_xlabel("Repository Size (KB)", fontsize=18, fontweight="bold") + ax.set_ylabel("Total Number of Unit Tests", fontsize=18, fontweight="bold") + ax.set_title( + "Correlation: Number of Tests vs Repository Size", + fontsize=20, + fontweight="bold", + pad=20, + ) + ax.tick_params(axis="both", labelsize=14) + ax.grid(alpha=0.3, linestyle="--") + ax.legend(fontsize=14, loc="upper right") + + plt.tight_layout() + + # Ensure output directory exists + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"Number of tests vs repo size correlation plot saved to: {output_path}") + + +def plot_num_tests_repo_star_correlation( + all_analyses: list[Dict[str, Any]], output_path: str +) -> None: + """Plot correlation between number of tests and repository stars. + + Args: + all_analyses: List of analysis results for each repo + output_path: Path to save the plot + """ + if not all_analyses: + print("No data to plot") + return + + # Extract data per repo + test_counts = [] + repo_stars = [] + repo_names = [] + + github_token = os.environ.get("GITHUB_TOKEN") + if github_token: + print("\nFetching repository stars from GitHub API (authenticated)...") + else: + print( + "\nFetching repository stars from GitHub API (unauthenticated - rate limited to 60 requests/hour)..." 
+ ) + print( + "Tip: Set GITHUB_TOKEN environment variable to increase rate limit to 5000 requests/hour" + ) + + for analysis in all_analyses: + test_count = analysis.get("test_count", 0) + repo_id = analysis["repo_id"] + + # Skip repos with no tests + if test_count == 0: + continue + + # Parse repo_id to extract owner and repo name + # Format: owner__repo.commit_hash + repo_full = repo_id.rsplit(".", 1)[0] # Remove commit hash + if "__" in repo_full: + owner, repo = repo_full.split("__", 1) + + # Get repo info from GitHub API + repo_info = get_repo_info(owner, repo) + stars = repo_info.get("stars", 0) + + if stars > 0: + test_counts.append(test_count) + repo_stars.append(stars) + repo_names.append(repo) + + # Small delay to avoid rate limiting + time.sleep(0.1) + + if not test_counts: + print("No data with valid test counts and repo stars to plot") + return + + print(f"Successfully fetched stars for {len(test_counts)} repositories\n") + + # Convert to numpy arrays for easier manipulation + test_counts = np.array(test_counts) + repo_stars = np.array(repo_stars) + repo_names = np.array(repo_names) + + # Identify outliers using IQR method on both axes + q1_x, q3_x = np.percentile(repo_stars, [25, 75]) + iqr_x = q3_x - q1_x + lower_x, upper_x = q1_x - 1.5 * iqr_x, q3_x + 1.5 * iqr_x + + q1_y, q3_y = np.percentile(test_counts, [25, 75]) + iqr_y = q3_y - q1_y + lower_y, upper_y = q1_y - 1.5 * iqr_y, q3_y + 1.5 * iqr_y + + # Create mask for non-outliers + mask_x = (repo_stars >= lower_x) & (repo_stars <= upper_x) + mask_y = (test_counts >= lower_y) & (test_counts <= upper_y) + mask = mask_x & mask_y + + repo_stars_filtered = repo_stars[mask] + test_counts_filtered = test_counts[mask] + outliers_x = repo_stars[~mask] + outliers_y = test_counts[~mask] + + # Create figure + fig, ax = plt.subplots(figsize=(12, 8)) + + # Scatter plot for non-outliers only + ax.scatter( + repo_stars_filtered, + test_counts_filtered, + alpha=0.6, + s=100, + color="black", + label="Data", + ) + + # Print outliers if any (but don't plot them) + if len(outliers_x) > 0: + print(f"Excluded {len(outliers_x)} outlier(s) from correlation analysis:") + outlier_repos = repo_names[~mask] + for i, (repo, stars, tests) in enumerate( + zip(outlier_repos, outliers_x, outliers_y) + ): + print(f" {i + 1}. 
{repo}: {int(stars)} stars, {int(tests)} tests") + print() + + # Add linear regression line using filtered data + if len(repo_stars_filtered) > 1: + z = np.polyfit(repo_stars_filtered, test_counts_filtered, 1) + p = np.poly1d(z) + x_line = np.linspace(min(repo_stars_filtered), max(repo_stars_filtered), 100) + ax.plot( + x_line, + p(x_line), + "r--", + alpha=0.8, + linewidth=2, + label=f"y={z[0]:.4f}x+{z[1]:.2f}", + ) + + # Calculate correlation coefficient on filtered data + correlation = np.corrcoef(repo_stars_filtered, test_counts_filtered)[0, 1] + ax.text( + 0.05, + 0.95, + f"Correlation: {correlation:.3f}", + transform=ax.transAxes, + fontsize=16, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + ) + + # Customize plot + ax.set_xlabel("Repository Stars", fontsize=18, fontweight="bold") + ax.set_ylabel("Total Number of Unit Tests", fontsize=18, fontweight="bold") + ax.set_title( + "Correlation: Number of Tests vs Repository Stars", + fontsize=20, + fontweight="bold", + pad=20, + ) + ax.tick_params(axis="both", labelsize=14) + ax.grid(alpha=0.3, linestyle="--") + ax.legend(fontsize=14, loc="upper right") + + plt.tight_layout() + + # Ensure output directory exists + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Save plot + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"Number of tests vs repo stars correlation plot saved to: {output_path}") + + +def discover_repos() -> list[str]: + """Discover all repos under logs/run_validation. + + Returns: + List of repo IDs found in the validation directory + """ + validation_base = Path("logs/run_validation") + if not validation_base.exists(): + return [] + + repos = [] + for item in validation_base.iterdir(): + if item.is_dir(): + repos.append(item.name) + + return sorted(repos) + + +def print_aggregate_statistics(all_analyses: list[Dict[str, Any]]) -> None: + """Print aggregate statistics across all repos.""" + + total_repos = len(all_analyses) + total_generated = sum(a["total_generated"] for a in all_analyses) + total_validated = sum(a["total_validated"] for a in all_analyses) + total_passed = sum(a["total_passed"] for a in all_analyses) + total_failed = sum(a["total_failed"] for a in all_analyses) + total_timeouts = sum(a.get("total_timeouts", 0) for a in all_analyses) + + # Aggregate by modifier across all repos + modifier_stats = defaultdict( + lambda: { + "generated": 0, + "validated": 0, + "passed": 0, + "failed": 0, + "timeout": 0, + "f2p_counts": [], + "p2p_counts": [], + } + ) + + for analysis in all_analyses: + for modifier, count in analysis["generated_by_modifier"].items(): + modifier_stats[modifier]["generated"] += count + + for modifier, data in analysis["validated_by_modifier"].items(): + modifier_stats[modifier]["validated"] += data["total"] + modifier_stats[modifier]["passed"] += data["passed"] + modifier_stats[modifier]["failed"] += data["failed"] + modifier_stats[modifier]["f2p_counts"].extend(data["f2p_counts"]) + modifier_stats[modifier]["p2p_counts"].extend(data["p2p_counts"]) + + for modifier, count in analysis.get("timeout_by_modifier", {}).items(): + modifier_stats[modifier]["timeout"] += count + + print("\n") + print("=" * 80) + print("AGGREGATE STATISTICS ACROSS ALL REPOS") + print("=" * 80) + print() + + print("OVERALL STATISTICS") + print("-" * 80) + print(f"Total repositories analyzed: {total_repos}") + print(f"Total bugs generated: {total_generated}") + print(f"Total bugs validated: 
{total_validated}") + if total_validated > 0: + print( + f"Bugs that passed validation: {total_passed} ({total_passed / total_validated * 100:.1f}%)" + ) + print( + f"Bugs that failed validation: {total_failed} ({total_failed / total_validated * 100:.1f}%)" + ) + print() + + print("PER-MODIFIER STATISTICS (AGGREGATED)") + print("-" * 80) + print( + f"{'Modifier':<35} {'Generated':<12} {'Validated':<12} {'Passed':<12} {'Pass Rate':<12}" + ) + print("-" * 80) + + sorted_modifiers = sorted( + modifier_stats.items(), key=lambda x: x[1]["generated"], reverse=True + ) + + for modifier, stats in sorted_modifiers: + validated_count = stats["validated"] + passed_count = stats["passed"] + pass_rate = (passed_count / max(validated_count, 1)) * 100 + + print( + f"{modifier:<35} {stats['generated']:<12} {validated_count:<12} {passed_count:<12} {pass_rate:>10.1f}%" + ) + + print() + + print("TEST FAILURE STATISTICS (AGGREGATED)") + print("-" * 80) + print( + f"{'Modifier':<35} {'Avg F2P':<12} {'Min F2P':<12} {'Max F2P':<12} {'Avg P2P':<12}" + ) + print("-" * 80) + + for modifier, stats in sorted_modifiers: + f2p_counts = stats["f2p_counts"] + p2p_counts = stats["p2p_counts"] + + if f2p_counts: + avg_f2p = sum(f2p_counts) / len(f2p_counts) + min_f2p = min(f2p_counts) + max_f2p = max(f2p_counts) + avg_p2p = sum(p2p_counts) / len(p2p_counts) + + print( + f"{modifier:<35} {avg_f2p:<12.2f} {min_f2p:<12} {max_f2p:<12} {avg_p2p:<12.2f}" + ) + + print() + print("=" * 80) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze procedurally generated bugs and validation results" + ) + parser.add_argument( + "--repo", + "-r", + type=str, + default=None, + help="Repository identifier (e.g., Instagram__MonkeyType.70c3acf6). If not provided, analyzes all repos.", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output file for detailed JSON report (default: logs/analysis/_analysis.json or logs/analysis/aggregate_analysis.json)", + ) + parser.add_argument( + "--show-generated-bugs", + action="store_true", + default=False, + help="Show generated bugs as another bar behind validated and passed. 
If enabled, validated bar shows in grey and generated in light grey.", + ) + parser.add_argument( + "--show-timeout-bugs", + action="store_true", + default=False, + help="Show timeout bugs as a dotted bar stacked on top of validated bugs.", + ) + parser.add_argument( + "--show-repo-owner", + action="store_true", + default=False, + help="Show repository owner in per-repo plot labels.", + ) + + args = parser.parse_args() + + if args.repo: + # Analyze single repo + analysis = analyze_bugs(args.repo) + print_statistics(analysis) + + if args.output is None: + output_dir = Path("logs/analysis") + output_dir.mkdir(parents=True, exist_ok=True) + args.output = str(output_dir / f"{args.repo}_analysis.json") + + save_report(analysis, args.output) + + # Plot bug distribution + plot_output = Path("logs/analysis") / "bug_distribution.png" + plot_bug_distribution( + analysis, str(plot_output), args.show_generated_bugs, args.show_timeout_bugs + ) + else: + # Analyze all repos + repos = discover_repos() + + if not repos: + print("No repositories found in logs/run_validation/") + return + + print(f"Found {len(repos)} repositories to analyze") + print() + + all_analyses = [] + + for repo in repos: + try: + analysis = analyze_bugs(repo) + all_analyses.append(analysis) + print_statistics(analysis) + print() + except FileNotFoundError as e: + print(f"Skipping {repo}: {e}") + print() + + if all_analyses: + print_aggregate_statistics(all_analyses) + + # Save aggregate report + if args.output is None: + output_dir = Path("logs/analysis") + output_dir.mkdir(parents=True, exist_ok=True) + args.output = str(output_dir / "aggregate_analysis.json") + + # Calculate aggregate statistics for JSON + total_generated = sum(a["total_generated"] for a in all_analyses) + total_validated = sum(a["total_validated"] for a in all_analyses) + total_passed = sum(a["total_passed"] for a in all_analyses) + total_failed = sum(a["total_failed"] for a in all_analyses) + total_timeouts = sum(a.get("total_timeouts", 0) for a in all_analyses) + + modifier_stats = defaultdict( + lambda: { + "generated": 0, + "validated": 0, + "passed": 0, + "failed": 0, + "timeout": 0, + "f2p_counts": [], + "p2p_counts": [], + } + ) + + for analysis in all_analyses: + for modifier, count in analysis["generated_by_modifier"].items(): + modifier_stats[modifier]["generated"] += count + + for modifier, data in analysis["validated_by_modifier"].items(): + modifier_stats[modifier]["validated"] += data["total"] + modifier_stats[modifier]["passed"] += data["passed"] + modifier_stats[modifier]["failed"] += data["failed"] + modifier_stats[modifier]["f2p_counts"].extend(data["f2p_counts"]) + modifier_stats[modifier]["p2p_counts"].extend(data["p2p_counts"]) + + for modifier, count in analysis.get("timeout_by_modifier", {}).items(): + modifier_stats[modifier]["timeout"] += count + + # Calculate summary statistics for each modifier + modifier_summaries = {} + for modifier, stats in modifier_stats.items(): + summary = { + "generated": stats["generated"], + "validated": stats["validated"], + "passed": stats["passed"], + "failed": stats["failed"], + "timeout": stats["timeout"], + "pass_rate": (stats["passed"] / max(stats["validated"], 1)) * 100, + } + + if stats["f2p_counts"]: + summary["f2p_avg"] = sum(stats["f2p_counts"]) / len( + stats["f2p_counts"] + ) + summary["f2p_min"] = min(stats["f2p_counts"]) + summary["f2p_max"] = max(stats["f2p_counts"]) + summary["p2p_avg"] = sum(stats["p2p_counts"]) / len( + stats["p2p_counts"] + ) + + modifier_summaries[modifier] = summary + + 
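+            # The aggregate dict below carries the same per-modifier fields as
+            # the per-repo reports: plot_bug_distribution detects the
+            # "aggregate_statistics" key to consume this shape, and the raw
+            # per-repo results stay under individual_analyses for drill-down.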
aggregate_data = {
+                "total_repos": len(all_analyses),
+                "repos": [a["repo_id"] for a in all_analyses],
+                "aggregate_statistics": {
+                    "total_generated": total_generated,
+                    "total_validated": total_validated,
+                    "total_passed": total_passed,
+                    "total_failed": total_failed,
+                    "total_timeouts": total_timeouts,
+                    "pass_rate": (total_passed / max(total_validated, 1)) * 100,
+                    "by_modifier": modifier_summaries,
+                },
+                "individual_analyses": all_analyses,
+            }
+            save_report(aggregate_data, args.output)
+
+            # Plot aggregate bug distribution
+            plot_output = Path("logs/analysis") / "bug_distribution.png"
+            plot_bug_distribution(
+                aggregate_data,
+                str(plot_output),
+                args.show_generated_bugs,
+                args.show_timeout_bugs,
+            )
+
+            # Plot per-repo distribution
+            per_repo_output = Path("logs/analysis") / "per_repo_bug_distribution.png"
+            plot_per_repo_distribution(
+                all_analyses, str(per_repo_output), args.show_repo_owner
+            )
+
+            # Plot timeout vs tests correlation
+            correlation_output = (
+                Path("logs/analysis") / "num_tests_timeout_correlation.png"
+            )
+            plot_timeout_vs_tests_correlation(all_analyses, str(correlation_output))
+
+            # Plot num_tests vs repo_size correlation
+            repo_size_output = (
+                Path("logs/analysis") / "num_tests_repo_size_correlation.png"
+            )
+            plot_num_tests_repo_size_correlation(all_analyses, str(repo_size_output))
+
+            # Plot num_tests vs repo_stars correlation
+            repo_star_output = (
+                Path("logs/analysis") / "num_tests_repo_star_correlation.png"
+            )
+            plot_num_tests_repo_star_correlation(all_analyses, str(repo_star_output))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/procmod_bugs.py b/scripts/procmod_bugs.py
new file mode 100644
index 00000000..9046cac3
--- /dev/null
+++ b/scripts/procmod_bugs.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python3
+"""
+Procedural Bug Generation for SWE-smith
+Converts the procedural_bug_gen.sh script to Python
+"""
+
+import argparse
+import inspect
+import json
+import os
+import platform
+import subprocess
+import sys
+from pathlib import Path
+from typing import List, Tuple
+
+
+def run_command(cmd, shell=False, capture_output=False, check=True):
+    """Run a shell command and handle errors."""
+    try:
+        if capture_output:
+            result = subprocess.run(
+                cmd, shell=shell, capture_output=True, text=True, check=check
+            )
+            return result
+        else:
+            subprocess.run(cmd, shell=shell, check=check)
+            return None
+    except subprocess.CalledProcessError as e:
+        if check:
+            raise
+        return e
+
+
+def cleanup_containers():
+    """Clean up stale containers from previous run."""
+    try:
+        # Get container IDs that match swesmith.val
+        result = subprocess.run(
+            "docker ps -a | grep swesmith.val | awk '{print $1}'",
+            shell=True,
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+        # Collapse the newline-separated IDs onto one line; embedded newlines
+        # would otherwise split the shell command below into two commands.
+        container_ids = " ".join(result.stdout.split())
+
+        if container_ids:
+            subprocess.run(
+                f"echo {container_ids} | xargs docker rm -f",
+                shell=True,
+                check=False,
+                stderr=subprocess.DEVNULL,
+            )
+    except Exception:
+        # Ignore cleanup errors
+        pass
+
+
+def check_docker_image(image_name):
+    """Check if Docker image exists, pull if not."""
+    print("[Step 1/4] Verifying Docker image...")
+
+    # Check if image exists
+    result = subprocess.run(
+        ["docker", "image", "inspect", image_name], capture_output=True, check=False
+    )
+
+    if result.returncode == 0:
+        print(f"✓ Docker image found: {image_name}")
+        return True
+
+    print(f"✗ Docker image not found: {image_name}")
+    print("Attempting to pull the image...")
+
+    try:
+        subprocess.run(["docker", "pull", image_name], check=True)
+        return True
+    except subprocess.CalledProcessError:
+        print("Error: Failed to pull Docker image. Please ensure the image exists.")
+        raise
+
+
+def generate_bugs(repo_id, max_bugs, interleave=False):
+    """Generate bugs procedurally."""
+    print("\n[Step 2/4] Generating bugs procedurally...")
+    cmd_parts = [
+        "python",
+        "-m",
+        "swesmith.bug_gen.procedural.generate",
+        repo_id,
+        "--max_bugs",
+        str(max_bugs),
+    ]
+    if interleave:
+        cmd_parts.append("--interleave")
+
+    print(f"Running: {' '.join(cmd_parts)}")
+
+    try:
+        subprocess.run(cmd_parts, check=True)
+    except subprocess.CalledProcessError:
+        print("Error: Bug generation failed.")
+        raise
+
+
+def collect_patches(repo_id):
+    """Collect all patches into a single file."""
+    print("\n[Step 3/4] Collecting all patches...")
+    patches_file = f"logs/bug_gen/{repo_id}_all_patches.json"
+    print(f"Running: python -m swesmith.bug_gen.collect_patches logs/bug_gen/{repo_id}")
+
+    try:
+        subprocess.run(
+            [
+                "python",
+                "-m",
+                "swesmith.bug_gen.collect_patches",
+                f"logs/bug_gen/{repo_id}",
+            ],
+            check=True,
+        )
+    except subprocess.CalledProcessError:
+        print("Error: Patch collection failed.")
+        raise
+
+    # Verify patches file was created
+    if Path(patches_file).exists():
+        with open(patches_file, "r") as f:
+            patches = json.load(f)
+        num_patches = len(patches)
+        print(f"✓ Collected {num_patches} patches to {patches_file}")
+    else:
+        print(f"✗ Patches file not found: {patches_file}")
+        # A bare `raise` here would itself fail (no active exception), so
+        # raise an explicit error instead.
+        raise FileNotFoundError(f"Patches file not found: {patches_file}")
+
+    return patches_file
+
+
+def get_num_cores():
+    """Determine number of CPU cores for parallel validation."""
+    try:
+        if platform.system() == "Darwin":  # macOS
+            result = subprocess.run(
+                ["sysctl", "-n", "hw.ncpu"], capture_output=True, text=True, check=False
+            )
+            if result.returncode == 0:
+                return int(result.stdout.strip())
+        else:  # Linux
+            result = subprocess.run(
+                ["nproc"], capture_output=True, text=True, check=False
+            )
+            if result.returncode == 0:
+                return int(result.stdout.strip())
+    except Exception:
+        pass
+
+    # Default to 8 if detection fails
+    return 8
+
+
+def run_validation(patches_file, num_cores, timeout_seconds):
+    """Run validation on generated patches with a configurable timeout.
+
+    Args:
+        patches_file: Path to patches JSON file
+        num_cores: Number of cores for parallel validation
+        timeout_seconds: Timeout in seconds for validation
+    """
+    print("\n[Step 4/4] Running validation...")
+    print(f"Running: python -m swesmith.harness.valid {patches_file} -w {num_cores}")
+    print(f"Timeout: {timeout_seconds} seconds ({timeout_seconds / 60:.1f} minutes)")
+
+    try:
+        subprocess.run(
+            [
+                "python",
+                "-m",
+                "swesmith.harness.valid",
+                patches_file,
+                "-w",
+                str(num_cores),
+            ],
+            check=True,
+            timeout=timeout_seconds,
+        )
+    except subprocess.TimeoutExpired:
+        print(f"\n⚠️ Warning: Validation timed out after {timeout_seconds} seconds.")
+        print("Partial results may be available.")
+    except subprocess.CalledProcessError:
+        print("Warning: Validation encountered errors but may have partial results.")
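The function that follows discovers Rust profiles by reflecting over `swesmith.profiles.rust`. A minimal standalone sketch of the same pattern, generalized to any profile module (the `discover_profiles` helper and its assumptions, namely default-constructible profiles exposing `owner`, `repo`, and `commit` attributes, are illustrative rather than part of the script):

```python
import inspect
from types import ModuleType


def discover_profiles(module: ModuleType, base_cls: type) -> list[tuple[str, str, str]]:
    """Collect (owner, repo, short_commit) from every profile subclass in module."""
    found = []
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, base_cls) and obj is not base_cls:
            inst = obj()  # assumes profiles are default-constructible
            found.append((inst.owner, inst.repo, inst.commit[:8]))
    return found
```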
+
+
+def get_rust_repos() -> List[Tuple[str, str, str]]:
+    """Extract all Rust repository profiles.
+
+    Returns:
+        List of tuples (owner, repo, commit)
+    """
+    from swesmith.profiles.rust import RustProfile
+    import swesmith.profiles.rust as rust_module
+
+    repos = []
+    for name, obj in inspect.getmembers(rust_module):
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, RustProfile)
+            and obj.__name__ != "RustProfile"
+        ):
+            # Instantiate to get the values
+            instance = obj()
+            repos.append((instance.owner, instance.repo, instance.commit[:8]))
+
+    return repos
+
+
+def get_repos_for_language(language: str) -> List[Tuple[str, str, str]]:
+    """Get all repositories for a given language.
+
+    Args:
+        language: Programming language (e.g., 'rust', 'python', 'go')
+
+    Returns:
+        List of tuples (owner, repo, commit)
+    """
+    if language.lower() == "rust":
+        return get_rust_repos()
+    else:
+        raise ValueError(f"Language '{language}' is not supported yet.")
+
+
+def print_summary(repo_id, patches_file):
+    """Print completion summary."""
+    print("\n" + "=" * 42)
+    print("Bug Generation Complete!")
+    print("=" * 42)
+    print(f"Generated patches: {patches_file}")
+    print(f"Validation results: logs/run_validation/{repo_id}/")
+    print("\nNext steps:")
+    print(f"  1. Review validation results in logs/run_validation/{repo_id}/")
+    print(
+        f"  2. Analyze bugs with: python scripts/analyze_procmod_bugs.py --repo {repo_id}"
+    )
+    print(
+        f"  3. Collect validated instances: python -m swesmith.harness.gather logs/run_validation/{repo_id}"
+    )
+    print("=" * 42)
+
+
+def process_repo(
+    repo_owner: str,
+    repo_name_only: str,
+    repo_commit: str,
+    max_bugs: int,
+    validation_timeout: int,
+    interleave: bool = False,
+):
+    """Process a single repository through the bug generation pipeline.
+
+    Args:
+        repo_owner: Repository owner
+        repo_name_only: Repository name
+        repo_commit: Commit hash (short form)
+        max_bugs: Maximum bugs per modifier
+        validation_timeout: Timeout in seconds for validation step
+        interleave: Randomize and interleave modifiers instead of running them sequentially
+    """
+    repo_name = f"{repo_owner}/{repo_name_only}"
+    repo_id = f"{repo_owner}__{repo_name_only}.{repo_commit}"
+    docker_image = f"jyangballin/swesmith.x86_64.{repo_owner.lower()}_1776_{repo_name_only.lower()}.{repo_commit}"
+
+    # Print header
+    print("\n" + "=" * 42)
+    print("Procedural Bug Generation for SWE-smith")
+    print("=" * 42)
+    print(f"Repository: {repo_name}")
+    print(f"Repository ID: {repo_id}")
+    print(f"Max bugs per modifier: {max_bugs}")
+    print(f"Docker image: {docker_image}")
+    print("=" * 42)
+    print()
+
+    # Execute pipeline
+    check_docker_image(docker_image)
+    generate_bugs(repo_id, max_bugs, interleave)
+    patches_file = collect_patches(repo_id)
+    num_cores = get_num_cores()
+    run_validation(patches_file, num_cores, validation_timeout)
+    print_summary(repo_id, patches_file)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Procedural Bug Generation for SWE-smith"
+    )
+    parser.add_argument(
+        "--language",
+        "-l",
+        default="rust",
+        help="Programming language to process (default: rust)",
+    )
+    parser.add_argument(
+        "--max-bugs",
+        "-m",
+        type=int,
+        default=200,
+        help="Maximum number of bugs per modifier (default: 200)",
+    )
+    parser.add_argument(
+        "--repo", "-r", help="Process only a specific repository (format: owner/repo)"
+    )
+    parser.add_argument(
+        "--validation-timeout",
+        "-t",
+        type=int,
+        default=1200,
+        help="Timeout in seconds for validation step (default: 1200)",
+    )
+    parser.add_argument(
+        "--sequential",
+        action="store_true",
+        help="Process modifiers sequentially instead of randomized interleaving (default: interleave)",
+    )
+
+    args = parser.parse_args()
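+
+    # Example invocations (flags defined above; repository name illustrative):
+    #   python scripts/procmod_bugs.py --repo dtolnay/anyhow --max-bugs 50
+    #   python scripts/procmod_bugs.py --language rust --sequential
+
+    # Set Docker 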
host for macOS + if platform.system() == "Darwin": + home = os.path.expanduser("~") + os.environ["DOCKER_HOST"] = f"unix://{home}/.docker/run/docker.sock" + + # Clean up stale containers + cleanup_containers() + + # Get repositories to process + if args.repo: + # Single repository mode + repos = get_repos_for_language(args.language) + repo_owner, repo_name_only = args.repo.split("/") + + # Find matching repo with commit + matching_repo = None + for owner, repo, commit in repos: + if owner == repo_owner and repo == repo_name_only: + matching_repo = (owner, repo, commit) + break + + if not matching_repo: + print( + f"Error: Repository {args.repo} not found in {args.language} profiles" + ) + sys.exit(1) + + repos = [matching_repo] + else: + # All repositories mode + repos = get_repos_for_language(args.language) + + # Print overall summary + print("=" * 60) + print(f"Processing {len(repos)} {args.language.upper()} repositories") + print("=" * 60) + for i, (owner, repo, commit) in enumerate(repos, 1): + print(f"{i:2d}. {owner}/{repo} @ {commit}") + print("=" * 60) + + # Process each repository + for i, (repo_owner, repo_name_only, repo_commit) in enumerate(repos, 1): + print(f"\n\n{'=' * 60}") + print(f"Processing repository {i}/{len(repos)}: {repo_owner}/{repo_name_only}") + print(f"{'=' * 60}") + + try: + process_repo( + repo_owner, + repo_name_only, + repo_commit, + args.max_bugs, + args.validation_timeout, + not args.sequential, # interleave by default + ) + except Exception as e: + print(f"\n⚠️ Error processing {repo_owner}/{repo_name_only}: {e}") + print("Continuing to next repository...") + continue + + # Final summary + print("\n\n" + "=" * 60) + print("All repositories processed!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/swesmith/bug_gen/adapters/rust.py b/swesmith/bug_gen/adapters/rust.py index 5fa4ae20..0f37183e 100644 --- a/swesmith/bug_gen/adapters/rust.py +++ b/swesmith/bug_gen/adapters/rust.py @@ -2,13 +2,95 @@ import tree_sitter_rust as tsrs import warnings -from swesmith.constants import TODO_REWRITE, CodeEntity +from swesmith.constants import TODO_REWRITE, CodeEntity, CodeProperty from tree_sitter import Language, Parser, Query, QueryCursor RUST_LANGUAGE = Language(tsrs.language()) class RustEntity(CodeEntity): + def _analyze_properties(self): + """Analyze Rust code properties.""" + node = self.node + + if node.type == "function_item": + self._tags.add(CodeProperty.IS_FUNCTION) + + self._walk_for_properties(node) + + def _walk_for_properties(self, n): + """Walk the AST and analyze properties.""" + self._check_control_flow(n) + self._check_operations(n) + self._check_expressions(n) + + for child in n.children: + self._walk_for_properties(child) + + def _check_control_flow(self, n): + """Check for control flow patterns.""" + if n.type in ["for_expression", "while_expression", "loop_expression"]: + self._tags.add(CodeProperty.HAS_LOOP) + if n.type == "if_expression": + self._tags.add(CodeProperty.HAS_IF) + for child in n.children: + if child.type == "else_clause": + self._tags.add(CodeProperty.HAS_IF_ELSE) + break + if n.type == "match_expression": + self._tags.add(CodeProperty.HAS_SWITCH) + + def _check_operations(self, n): + """Check for various operations.""" + if n.type == "index_expression": + self._tags.add(CodeProperty.HAS_LIST_INDEXING) + if n.type == "call_expression": + self._tags.add(CodeProperty.HAS_FUNCTION_CALL) + if n.type == "return_expression": + self._tags.add(CodeProperty.HAS_RETURN) + if n.type in ["let_declaration", 
"const_item", "static_item"]: + self._tags.add(CodeProperty.HAS_ASSIGNMENT) + + def _check_expressions(self, n): + """Check for expression patterns.""" + if n.type == "binary_expression": + self._tags.add(CodeProperty.HAS_BINARY_OP) + if n.type == "unary_expression": + self._tags.add(CodeProperty.HAS_UNARY_OP) + if n.type == "closure_expression": + self._tags.add(CodeProperty.HAS_LAMBDA) + + @property + def complexity(self) -> int: + """Calculate cyclomatic complexity for Rust code.""" + + def walk(node): + score = 0 + if node.type in [ + "!=", + "&&", + "<", + "<=", + "==", + ">", + ">=", + "||", + "match_arm", + "else_clause", + "for_expression", + "while_expression", + "loop_expression", + "if_expression", + ]: + score += 1 + + for child in node.children: + score += walk(child) + + return score + + return 1 + walk(self.node) + @property def name(self) -> str: func_query = Query(RUST_LANGUAGE, "(function_item name: (identifier) @name)") diff --git a/swesmith/bug_gen/procedural/__init__.py b/swesmith/bug_gen/procedural/__init__.py index fe088a76..2f0f9393 100644 --- a/swesmith/bug_gen/procedural/__init__.py +++ b/swesmith/bug_gen/procedural/__init__.py @@ -9,8 +9,10 @@ # For backward compatibility, expose Python-specific classes from swesmith.bug_gen.procedural.golang import MODIFIERS_GOLANG from swesmith.bug_gen.procedural.python import MODIFIERS_PYTHON +from swesmith.bug_gen.procedural.rust import MODIFIERS_RUST MAP_EXT_TO_MODIFIERS = { ".go": MODIFIERS_GOLANG, ".py": MODIFIERS_PYTHON, + ".rs": MODIFIERS_RUST, } diff --git a/swesmith/bug_gen/procedural/generate.py b/swesmith/bug_gen/procedural/generate.py index 36297e1e..6d8d539c 100644 --- a/swesmith/bug_gen/procedural/generate.py +++ b/swesmith/bug_gen/procedural/generate.py @@ -65,6 +65,7 @@ def main( repo: str, max_bugs: int, seed: int, + interleave: bool = False, ): random.seed(seed) total = 0 @@ -73,26 +74,58 @@ def main( entities = rp.extract_entities() print(f"Found {len(entities)} entities in {repo}.") - for ext, pm_list in MAP_EXT_TO_MODIFIERS.items(): - for pm in pm_list: - candidates = [ - x - for x in entities - if Path(x.file_path).suffix == ext and pm.can_change(x) - ] - if not candidates: - continue - print(f"[{repo}] Found {len(candidates)} candidates for {pm.name}.") - - log_dir = LOG_DIR_BUG_GEN / repo - log_dir.mkdir(parents=True, exist_ok=True) - print(f"Logging bugs to {log_dir}") - - if max_bugs > 0 and len(candidates) > max_bugs: - candidates = random.sample(candidates, max_bugs) - - for candidate in tqdm(candidates): - total += _process_candidate(candidate, pm, log_dir, repo) + log_dir = LOG_DIR_BUG_GEN / repo + log_dir.mkdir(parents=True, exist_ok=True) + print(f"Logging bugs to {log_dir}") + + if interleave: + # Build all (candidate, modifier) pairs upfront + pairs = [] + for ext, pm_list in MAP_EXT_TO_MODIFIERS.items(): + for pm in pm_list: + candidates = [ + x + for x in entities + if Path(x.file_path).suffix == ext and pm.can_change(x) + ] + if not candidates: + continue + print(f"[{repo}] Found {len(candidates)} candidates for {pm.name}.") + + if max_bugs > 0 and len(candidates) > max_bugs: + candidates = random.sample(candidates, max_bugs) + + # Add all pairs for this modifier + for candidate in candidates: + pairs.append((candidate, pm)) + + # Shuffle all pairs to interleave modifiers + random.shuffle(pairs) + print( + f"[{repo}] Processing {len(pairs)} (candidate, modifier) pairs in randomized order." 
+ ) + + # Process in randomized order + for candidate, pm in tqdm(pairs): + total += _process_candidate(candidate, pm, log_dir, repo) + else: + # Sequential processing (original behavior) + for ext, pm_list in MAP_EXT_TO_MODIFIERS.items(): + for pm in pm_list: + candidates = [ + x + for x in entities + if Path(x.file_path).suffix == ext and pm.can_change(x) + ] + if not candidates: + continue + print(f"[{repo}] Found {len(candidates)} candidates for {pm.name}.") + + if max_bugs > 0 and len(candidates) > max_bugs: + candidates = random.sample(candidates, max_bugs) + + for candidate in tqdm(candidates): + total += _process_candidate(candidate, pm, log_dir, repo) shutil.rmtree(repo) print(f"Generated {total} bugs for {repo}.") @@ -119,6 +152,11 @@ def main( default=-1, help="Maximum number of bugs to generate.", ) + parser.add_argument( + "--interleave", + action="store_true", + help="Randomize and interleave modifiers instead of processing sequentially.", + ) args = parser.parse_args() main(**vars(args)) diff --git a/swesmith/bug_gen/procedural/rust/__init__.py b/swesmith/bug_gen/procedural/rust/__init__.py index e69de29b..73f0f0b6 100644 --- a/swesmith/bug_gen/procedural/rust/__init__.py +++ b/swesmith/bug_gen/procedural/rust/__init__.py @@ -0,0 +1,30 @@ +from swesmith.bug_gen.procedural.base import ProceduralModifier +from swesmith.bug_gen.procedural.rust.control_flow import ( + ControlIfElseInvertModifier, + ControlShuffleLinesModifier, +) +from swesmith.bug_gen.procedural.rust.operations import ( + OperationBreakChainsModifier, + OperationChangeConstantsModifier, + OperationChangeModifier, + OperationFlipOperatorModifier, + OperationSwapOperandsModifier, +) +from swesmith.bug_gen.procedural.rust.remove import ( + RemoveAssignModifier, + RemoveConditionalModifier, + RemoveLoopModifier, +) + +MODIFIERS_RUST: list[ProceduralModifier] = [ + ControlIfElseInvertModifier(likelihood=0.25), + ControlShuffleLinesModifier(likelihood=0.25), + RemoveAssignModifier(likelihood=0.25), + RemoveConditionalModifier(likelihood=0.25), + RemoveLoopModifier(likelihood=0.25), + OperationBreakChainsModifier(likelihood=0.25), + OperationChangeConstantsModifier(likelihood=0.25), + OperationChangeModifier(likelihood=0.25), + OperationFlipOperatorModifier(likelihood=0.25), + OperationSwapOperandsModifier(likelihood=0.25), +] diff --git a/swesmith/bug_gen/procedural/rust/base.py b/swesmith/bug_gen/procedural/rust/base.py new file mode 100644 index 00000000..892cd5c3 --- /dev/null +++ b/swesmith/bug_gen/procedural/rust/base.py @@ -0,0 +1,6 @@ +from abc import ABC +from swesmith.bug_gen.procedural.base import ProceduralModifier + + +class RustProceduralModifier(ProceduralModifier, ABC): + """Base class for Rust-specific procedural modifications.""" diff --git a/swesmith/bug_gen/procedural/rust/control_flow.py b/swesmith/bug_gen/procedural/rust/control_flow.py new file mode 100644 index 00000000..52a30708 --- /dev/null +++ b/swesmith/bug_gen/procedural/rust/control_flow.py @@ -0,0 +1,202 @@ +import tree_sitter_rust as tsrs + +from swesmith.bug_gen.procedural.base import CommonPMs +from swesmith.bug_gen.procedural.rust.base import RustProceduralModifier +from swesmith.constants import BugRewrite, CodeEntity +from tree_sitter import Language, Parser + +RUST_LANGUAGE = Language(tsrs.language()) + + +class ControlIfElseInvertModifier(RustProceduralModifier): + explanation: str = CommonPMs.CONTROL_IF_ELSE_INVERT.explanation + name: str = CommonPMs.CONTROL_IF_ELSE_INVERT.name + conditions: list = 
CommonPMs.CONTROL_IF_ELSE_INVERT.conditions + min_complexity: int = 5 + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply if-else inversion to the Rust code.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + changed = False + + for _ in range(self.max_attempts): + modified_code = self._invert_if_else_statements( + code_entity.src_code, tree.root_node + ) + + if modified_code != code_entity.src_code: + changed = True + break + + if not changed: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _invert_if_else_statements(self, source_code: str, node) -> str: + """Recursively find and invert if-else statements by swapping the bodies.""" + modifications = [] + + def collect_if_statements(n): + if n.type == "if_expression": + if_condition = None + if_body = None + else_clause = None + else_body = None + + for i, child in enumerate(n.children): + if child.type == "if": + continue + elif if_condition is None and child.type in [ + "binary_expression", + "identifier", + "call_expression", + "field_expression", + "unary_expression", + ]: + if_condition = child + elif child.type == "block" and if_body is None: + if_body = child + elif child.type == "else_clause": + else_clause = child + for else_child in child.children: + if else_child.type == "block": + else_body = else_child + break + break + + if ( + if_condition + and if_body + and else_clause + and else_body + and self.flip() + ): + modifications.append((n, if_condition, if_body, else_body)) + + for child in n.children: + collect_if_statements(child) + + collect_if_statements(node) + + if not modifications: + return source_code + + modified_source = source_code + for if_node, condition, if_body, else_body in reversed(modifications): + if_start = if_node.start_byte + if_body_start = if_body.start_byte + + prefix = source_code[if_start:if_body_start].strip() + + if_body_text = source_code[if_body.start_byte : if_body.end_byte] + else_body_text = source_code[else_body.start_byte : else_body.end_byte] + + new_if_else = f"{prefix} {else_body_text} else {if_body_text}" + + start_byte = if_node.start_byte + end_byte = if_node.end_byte + + modified_source = ( + modified_source[:start_byte] + new_if_else + modified_source[end_byte:] + ) + + return modified_source + + +class ControlShuffleLinesModifier(RustProceduralModifier): + explanation: str = CommonPMs.CONTROL_SHUFFLE_LINES.explanation + name: str = CommonPMs.CONTROL_SHUFFLE_LINES.name + conditions: list = CommonPMs.CONTROL_SHUFFLE_LINES.conditions + max_complexity: int = 10 + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply line shuffling to the Rust function body.""" + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._shuffle_function_statements( + code_entity.src_code, tree.root_node + ) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _shuffle_function_statements(self, source_code: str, node) -> str: + """Recursively find function declarations and shuffle their statements.""" + modifications = [] + + def collect_function_declarations(n): + if n.type == "function_item": + body_block = None + for child in n.children: + if child.type == "block": + body_block = child + break + + if body_block: + statements = [] 
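+                        # Keep only the block's real statements; the "{" and "}"
+                        # brace tokens are structural and must not be shuffled.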
+                        for child in body_block.children:
+                            if child.type not in ["{", "}"]:
+                                statements.append(child)
+
+                        if len(statements) >= 2:
+                            modifications.append((body_block, statements))
+
+            for child in n.children:
+                collect_function_declarations(child)
+
+        collect_function_declarations(node)
+
+        if not modifications:
+            return source_code
+
+        modified_source = source_code
+        for body_block, statements in reversed(modifications):
+            shuffled_indices = list(range(len(statements)))
+            self.rand.shuffle(shuffled_indices)
+
+            if shuffled_indices == list(range(len(statements))):
+                if len(statements) >= 2:
+                    shuffled_indices[0], shuffled_indices[1] = (
+                        shuffled_indices[1],
+                        shuffled_indices[0],
+                    )
+
+            statement_texts = []
+            for stmt in statements:
+                stmt_text = source_code[stmt.start_byte : stmt.end_byte]
+                statement_texts.append(stmt_text)
+
+            shuffled_texts = [statement_texts[i] for i in shuffled_indices]
+
+            first_stmt_start = statements[0].start_byte
+            last_stmt_end = statements[-1].end_byte
+
+            line_start = source_code.rfind("\n", 0, first_stmt_start) + 1
+            indent = source_code[line_start:first_stmt_start]
+
+            new_content = ("\n" + indent).join(shuffled_texts)
+
+            modified_source = (
+                modified_source[:first_stmt_start]
+                + new_content
+                + modified_source[last_stmt_end:]
+            )
+
+        return modified_source
diff --git a/swesmith/bug_gen/procedural/rust/operations.py b/swesmith/bug_gen/procedural/rust/operations.py
index e69de29b..c460e287 100644
--- a/swesmith/bug_gen/procedural/rust/operations.py
+++ b/swesmith/bug_gen/procedural/rust/operations.py
@@ -0,0 +1,425 @@
+import tree_sitter_rust as tsrs
+
+from swesmith.bug_gen.procedural.base import CommonPMs
+from swesmith.bug_gen.procedural.rust.base import RustProceduralModifier
+from swesmith.constants import BugRewrite, CodeEntity
+from tree_sitter import Language, Parser
+
+RUST_LANGUAGE = Language(tsrs.language())
+
+ALL_BINARY_OPERATORS = [
+    "+",
+    "-",
+    "*",
+    "/",
+    "%",
+    "<<",
+    ">>",
+    "&",
+    "|",
+    "^",
+    "==",
+    "!=",
+    "<",
+    "<=",
+    ">",
+    ">=",
+    "&&",
+    "||",
+]
+
+FLIPPED_OPERATORS = {
+    "+": "-",
+    "-": "+",
+    "*": "/",
+    "/": "*",
+    "%": "*",
+    "<<": ">>",
+    ">>": "<<",
+    "&": "|",
+    "|": "&",
+    "^": "&",
+    "==": "!=",
+    "!=": "==",
+    "<": ">",
+    "<=": ">=",
+    ">": "<",
+    ">=": "<=",
+    "&&": "||",
+    "||": "&&",
+}
+
+# Operator groups for systematic changes
+ARITHMETIC_OPS = ["+", "-", "*", "/", "%"]
+BITWISE_OPS = ["&", "|", "^", "<<", ">>"]
+COMPARISON_OPS = ["==", "!=", "<", "<=", ">", ">="]
+LOGICAL_OPS = ["&&", "||"]
+
+
+class OperationChangeModifier(RustProceduralModifier):
+    explanation: str = CommonPMs.OPERATION_CHANGE.explanation
+    name: str = CommonPMs.OPERATION_CHANGE.name
+    conditions: list = CommonPMs.OPERATION_CHANGE.conditions
+
+    def modify(self, code_entity: CodeEntity) -> BugRewrite:
+        """Apply operation changes to Rust binary expressions."""
+        if not self.flip():
+            return None
+
+        parser = Parser(RUST_LANGUAGE)
+        tree = parser.parse(bytes(code_entity.src_code, "utf8"))
+
+        modified_code = self._change_operations(code_entity.src_code, tree.root_node)
+
+        if modified_code == code_entity.src_code:
+            return None
+
+        return BugRewrite(
+            rewrite=modified_code,
+            explanation=self.explanation,
+            strategy=self.name,
+        )
+
+    def _change_operations(self, source_code: str, node) -> str:
+        """Recursively find and change binary operations."""
+        modifications = []
+
+        
def collect_binary_ops(n): + if n.type == "binary_expression": + operator_node = None + for child in n.children: + if child.type in ALL_BINARY_OPERATORS: + operator_node = child + break + + if operator_node and self.flip(): + op = operator_node.text.decode("utf-8") + new_op = self._get_alternative_operator(op) + if new_op != op: + modifications.append((operator_node, new_op)) + + for child in n.children: + collect_binary_ops(child) + + collect_binary_ops(node) + + modified_code = source_code + for operator_node, new_op in sorted( + modifications, key=lambda x: x[0].start_byte, reverse=True + ): + start_byte = operator_node.start_byte + end_byte = operator_node.end_byte + modified_code = ( + modified_code[:start_byte] + new_op + modified_code[end_byte:] + ) + + return modified_code + + def _get_alternative_operator(self, op: str) -> str: + """Get an alternative operator from the same category.""" + if op in ARITHMETIC_OPS: + return self.rand.choice(ARITHMETIC_OPS) + elif op in BITWISE_OPS: + return self.rand.choice(BITWISE_OPS) + elif op in COMPARISON_OPS: + return self.rand.choice(COMPARISON_OPS) + elif op in LOGICAL_OPS: + return self.rand.choice(LOGICAL_OPS) + return op + + +class OperationFlipOperatorModifier(RustProceduralModifier): + explanation: str = CommonPMs.OPERATION_FLIP_OPERATOR.explanation + name: str = CommonPMs.OPERATION_FLIP_OPERATOR.name + conditions: list = CommonPMs.OPERATION_FLIP_OPERATOR.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply operator flipping to Rust binary expressions.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._flip_operators(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _flip_operators(self, source_code: str, node) -> str: + """Recursively find and flip binary operations.""" + modifications = [] + + def collect_binary_ops(n): + if n.type == "binary_expression": + operator_node = None + left_operand = None + + for i, child in enumerate(n.children): + if child.type in FLIPPED_OPERATORS: + operator_node = child + if i > 0: + left_operand = n.children[0] + break + + if operator_node and self.flip(): + op = operator_node.text.decode("utf-8") + if op in FLIPPED_OPERATORS: + if ( + op == "*" + and left_operand + and left_operand.type == "range_expression" + ): + pass # Skip this - it's a dereference, not multiplication + else: + modifications.append((operator_node, FLIPPED_OPERATORS[op])) + + for child in n.children: + collect_binary_ops(child) + + collect_binary_ops(node) + + modified_code = source_code + for operator_node, new_op in sorted( + modifications, key=lambda x: x[0].start_byte, reverse=True + ): + start_byte = operator_node.start_byte + end_byte = operator_node.end_byte + modified_code = ( + modified_code[:start_byte] + new_op + modified_code[end_byte:] + ) + + return modified_code + + +class OperationSwapOperandsModifier(RustProceduralModifier): + explanation: str = CommonPMs.OPERATION_SWAP_OPERANDS.explanation + name: str = CommonPMs.OPERATION_SWAP_OPERANDS.name + conditions: list = CommonPMs.OPERATION_SWAP_OPERANDS.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply operand swapping to Rust binary expressions.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = 
parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._swap_operands(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _swap_operands(self, source_code: str, node) -> str: + """Recursively find and swap operands in binary operations.""" + modifications = [] + + def collect_binary_ops(n): + if n.type == "binary_expression" and len(n.children) >= 3: + if self.flip(): + left_operand = n.children[0] + operator = None + right_operand = None + + for i, child in enumerate(n.children[1:], 1): + if child.type in ALL_BINARY_OPERATORS: + operator = child + if i + 1 < len(n.children): + right_operand = n.children[i + 1] + break + + if left_operand and operator and right_operand: + modifications.append((n, left_operand, operator, right_operand)) + + for child in n.children: + collect_binary_ops(child) + + collect_binary_ops(node) + + modified_code = source_code + for expr_node, left, op, right in sorted( + modifications, key=lambda x: x[0].start_byte, reverse=True + ): + start_byte = expr_node.start_byte + end_byte = expr_node.end_byte + + left_text = left.text.decode("utf-8") + op_text = op.text.decode("utf-8") + right_text = right.text.decode("utf-8") + + new_expr = f"{right_text} {op_text} {left_text}" + modified_code = ( + modified_code[:start_byte] + new_expr + modified_code[end_byte:] + ) + + return modified_code + + +class OperationBreakChainsModifier(RustProceduralModifier): + explanation: str = CommonPMs.OPERATION_BREAK_CHAINS.explanation + name: str = CommonPMs.OPERATION_BREAK_CHAINS.name + conditions: list = CommonPMs.OPERATION_BREAK_CHAINS.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply chain breaking to Rust binary expressions.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._break_chains(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _break_chains(self, source_code: str, node) -> str: + """Recursively find and break chains in binary operations.""" + modifications = [] + + def collect_binary_ops(n): + if n.type == "binary_expression" and self.flip(): + left_operand = n.children[0] if n.children else None + right_operand = None + + for i, child in enumerate(n.children[1:], 1): + if child.type not in ALL_BINARY_OPERATORS: + right_operand = child + break + + if left_operand and left_operand.type == "binary_expression": + inner_left = ( + left_operand.children[0] if left_operand.children else None + ) + if inner_left: + modifications.append((n, inner_left)) + elif right_operand and right_operand.type == "binary_expression": + inner_right = None + for child in reversed(right_operand.children): + if child.type not in ALL_BINARY_OPERATORS: + inner_right = child + break + if inner_right: + modifications.append((n, inner_right)) + + for child in n.children: + collect_binary_ops(child) + + collect_binary_ops(node) + + modified_code = source_code + for expr_node, replacement in sorted( + modifications, key=lambda x: x[0].start_byte, reverse=True + ): + start_byte = expr_node.start_byte + end_byte = expr_node.end_byte + replacement_text = replacement.text.decode("utf-8") + modified_code = ( + 
modified_code[:start_byte] + replacement_text + modified_code[end_byte:] + ) + + return modified_code + + +class OperationChangeConstantsModifier(RustProceduralModifier): + explanation: str = CommonPMs.OPERATION_CHANGE_CONSTANTS.explanation + name: str = CommonPMs.OPERATION_CHANGE_CONSTANTS.name + conditions: list = CommonPMs.OPERATION_CHANGE_CONSTANTS.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Apply constant changes to Rust binary expressions.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._change_constants(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _change_constants(self, source_code: str, node) -> str: + """Recursively find and modify constants in binary operations.""" + modifications = [] + + def collect_binary_ops(n): + if n.type == "binary_expression": + for child in n.children: + if child.type == "integer_literal" and self.flip(): + try: + value = int(child.text.decode("utf-8")) + new_value = value + self.rand.choice([-1, 1]) + modifications.append((child, str(new_value))) + except ValueError: + pass + elif child.type == "float_literal" and self.flip(): + try: + value = float(child.text.decode("utf-8")) + delta = self.rand.choice([-0.1, 0.1, -1.0, 1.0]) + new_value = value + delta + modifications.append((child, str(new_value))) + except ValueError: + pass + + for child in n.children: + collect_binary_ops(child) + + collect_binary_ops(node) + + modified_code = source_code + for const_node, new_value in sorted( + modifications, key=lambda x: x[0].start_byte, reverse=True + ): + start_byte = const_node.start_byte + end_byte = const_node.end_byte + modified_code = ( + modified_code[:start_byte] + new_value + modified_code[end_byte:] + ) + + return modified_code diff --git a/swesmith/bug_gen/procedural/rust/remove.py b/swesmith/bug_gen/procedural/rust/remove.py new file mode 100644 index 00000000..69c4ee00 --- /dev/null +++ b/swesmith/bug_gen/procedural/rust/remove.py @@ -0,0 +1,158 @@ +import tree_sitter_rust as tsrs + +from swesmith.bug_gen.procedural.base import CommonPMs +from swesmith.bug_gen.procedural.rust.base import RustProceduralModifier +from swesmith.constants import BugRewrite, CodeEntity +from tree_sitter import Language, Parser + +RUST_LANGUAGE = Language(tsrs.language()) + + +class RemoveLoopModifier(RustProceduralModifier): + explanation: str = CommonPMs.REMOVE_LOOP.explanation + name: str = CommonPMs.REMOVE_LOOP.name + conditions: list = CommonPMs.REMOVE_LOOP.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Remove loop statements from the Rust code.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._remove_loops(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _remove_loops(self, source_code: str, node) -> str: + """Recursively find and remove loop statements.""" + removals = [] + + def collect_loops(n): + if n.type in ["for_expression", "while_expression", "loop_expression"]: + if self.flip(): + removals.append(n) + for child in n.children: + collect_loops(child) + + 
collect_loops(node) + + if not removals: + return source_code + + modified_source = source_code + for loop_node in reversed(removals): + start_byte = loop_node.start_byte + end_byte = loop_node.end_byte + + modified_source = modified_source[:start_byte] + modified_source[end_byte:] + + return modified_source + + +class RemoveConditionalModifier(RustProceduralModifier): + explanation: str = CommonPMs.REMOVE_CONDITIONAL.explanation + name: str = CommonPMs.REMOVE_CONDITIONAL.name + conditions: list = CommonPMs.REMOVE_CONDITIONAL.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Remove conditional statements from the Rust code.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._remove_conditionals(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _remove_conditionals(self, source_code: str, node) -> str: + """Recursively find and remove conditional statements.""" + removals = [] + + def collect_conditionals(n): + if n.type == "if_expression": + if self.flip(): + removals.append(n) + for child in n.children: + collect_conditionals(child) + + collect_conditionals(node) + + if not removals: + return source_code + + modified_source = source_code + for if_node in reversed(removals): + start_byte = if_node.start_byte + end_byte = if_node.end_byte + + modified_source = modified_source[:start_byte] + modified_source[end_byte:] + + return modified_source + + +class RemoveAssignModifier(RustProceduralModifier): + explanation: str = CommonPMs.REMOVE_ASSIGNMENT.explanation + name: str = CommonPMs.REMOVE_ASSIGNMENT.name + conditions: list = CommonPMs.REMOVE_ASSIGNMENT.conditions + + def modify(self, code_entity: CodeEntity) -> BugRewrite: + """Remove assignment statements from the Rust code.""" + if not self.flip(): + return None + + parser = Parser(RUST_LANGUAGE) + tree = parser.parse(bytes(code_entity.src_code, "utf8")) + + modified_code = self._remove_assignments(code_entity.src_code, tree.root_node) + + if modified_code == code_entity.src_code: + return None + + return BugRewrite( + rewrite=modified_code, + explanation=self.explanation, + strategy=self.name, + ) + + def _remove_assignments(self, source_code: str, node) -> str: + """Recursively find and remove assignment statements.""" + removals = [] + + def collect_assignments(n): + if n.type in ["let_declaration", "assignment_expression"]: + if self.flip(): + removals.append(n) + for child in n.children: + collect_assignments(child) + + collect_assignments(node) + + if not removals: + return source_code + + modified_source = source_code + for assign_node in reversed(removals): + start_byte = assign_node.start_byte + end_byte = assign_node.end_byte + + modified_source = modified_source[:start_byte] + modified_source[end_byte:] + + return modified_source diff --git a/tests/bug_gen/procedural/rust/__init__.py b/tests/bug_gen/procedural/rust/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bug_gen/procedural/rust/test_rust_control_flow.py b/tests/bug_gen/procedural/rust/test_rust_control_flow.py new file mode 100644 index 00000000..9751d059 --- /dev/null +++ b/tests/bug_gen/procedural/rust/test_rust_control_flow.py @@ -0,0 +1,144 @@ +import pytest +import tempfile +import os +from swesmith.bug_gen.adapters.rust import 
get_entities_from_file_rs +from swesmith.bug_gen.procedural.rust.control_flow import ( + ControlIfElseInvertModifier, + ControlShuffleLinesModifier, +) +import random + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo(x: i32) -> i32 { + if x > 0 { + return 1; + } else { + return -1; + } +}""", + """fn foo(x: i32) -> i32 { + if x > 0 { + return -1; + } else { + return 1; + } +}""", + ), + ( + """fn bar(condition: bool) -> &str { + if condition { + "true" + } else { + "false" + } +}""", + """fn bar(condition: bool) -> &str { + if condition { + "false" + } else { + "true" + } +}""", + ), + ( + """fn baz(x: i32) -> i32 { + if x == 0 { + let y = 1; + y + 2 + } else { + let z = 3; + z + 4 + } +}""", + """fn baz(x: i32) -> i32 { + if x == 0 { + let z = 3; + z + 4 + } else { + let y = 1; + y + 2 + } +}""", + ), + ], +) +def test_control_if_else_invert_modifier(src, expected): + """Test that ControlIfElseInvertModifier inverts if-else branches.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = ControlIfElseInvertModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected_variants", + [ + ( + """fn foo() { + let a = 1; + let b = 2; +}""", + [ + "fn foo() {\n let a = 1;\n let b = 2;\n}", + "fn foo() {\n let b = 2;\n let a = 1;\n}", + ], + ), + ( + """fn bar() { + let x = 1; + let y = 2; + let z = 3; +}""", + [ + "fn bar() {\n let x = 1;\n let y = 2;\n let z = 3;\n}", + "fn bar() {\n let x = 1;\n let z = 3;\n let y = 2;\n}", + "fn bar() {\n let y = 2;\n let x = 1;\n let z = 3;\n}", + "fn bar() {\n let y = 2;\n let z = 3;\n let x = 1;\n}", + "fn bar() {\n let z = 3;\n let x = 1;\n let y = 2;\n}", + "fn bar() {\n let z = 3;\n let y = 2;\n let x = 1;\n}", + ], + ), + ], +) +def test_control_shuffle_lines_modifier(src, expected_variants): + """Test that ControlShuffleLinesModifier shuffles independent lines.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = ControlShuffleLinesModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert any( + result.rewrite.strip() == variant.strip() for variant in expected_variants + ), f"Expected one of {expected_variants}, got {result.rewrite}" + finally: + os.unlink(temp_path) diff --git a/tests/bug_gen/procedural/rust/test_rust_operations.py b/tests/bug_gen/procedural/rust/test_rust_operations.py new file mode 100644 index 00000000..5e2a06c9 --- /dev/null +++ b/tests/bug_gen/procedural/rust/test_rust_operations.py @@ -0,0 +1,320 @@ +import pytest +import tempfile +import os +from swesmith.bug_gen.adapters.rust import get_entities_from_file_rs +from swesmith.bug_gen.procedural.rust.operations import ( + OperationChangeModifier, + OperationFlipOperatorModifier, + OperationSwapOperandsModifier, + OperationBreakChainsModifier, + OperationChangeConstantsModifier, + FLIPPED_OPERATORS, +) +import random + + 
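+# These tests pin likelihood=1.0 and reseed modifier.rand with a fixed seed, so
+# the modifiers always attempt a change and each rewrite is reproducible.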
+@pytest.mark.parametrize( + "src,expected_variants", + [ + ( + """fn foo(a: i32, b: i32) -> i32 { + a + b +}""", + [ + "fn foo(a: i32, b: i32) -> i32 {\n a - b\n}", + "fn foo(a: i32, b: i32) -> i32 {\n a * b\n}", + "fn foo(a: i32, b: i32) -> i32 {\n a / b\n}", + "fn foo(a: i32, b: i32) -> i32 {\n a % b\n}", + ], + ), + ( + """fn bar(x: i32, y: i32) -> bool { + x == y +}""", + [ + "fn bar(x: i32, y: i32) -> bool {\n x != y\n}", + "fn bar(x: i32, y: i32) -> bool {\n x < y\n}", + "fn bar(x: i32, y: i32) -> bool {\n x <= y\n}", + "fn bar(x: i32, y: i32) -> bool {\n x > y\n}", + "fn bar(x: i32, y: i32) -> bool {\n x >= y\n}", + ], + ), + ( + """fn baz(a: u32, b: u32) -> u32 { + a & b +}""", + [ + "fn baz(a: u32, b: u32) -> u32 {\n a | b\n}", + "fn baz(a: u32, b: u32) -> u32 {\n a ^ b\n}", + ], + ), + ], +) +def test_operation_change_modifier(src, expected_variants): + """Test that OperationChangeModifier changes operators within the same category.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = OperationChangeModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + + found_variant = False + for _ in range(20): + result = modifier.modify(entities[0]) + if ( + result + and result.rewrite != src + and any( + result.rewrite.strip() == variant.strip() + for variant in expected_variants + ) + ): + found_variant = True + break + + assert found_variant, ( + f"Expected one of {expected_variants}, but got {result.rewrite if result else 'None'}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo(a: i32, b: i32) -> i32 { + a + b +}""", + """fn foo(a: i32, b: i32) -> i32 { + a - b +}""", + ), + ( + """fn bar(x: i32, y: i32) -> bool { + x == y +}""", + """fn bar(x: i32, y: i32) -> bool { + x != y +}""", + ), + ( + """fn baz(a: i32, b: i32) -> bool { + a < b +}""", + """fn baz(a: i32, b: i32) -> bool { + a > b +}""", + ), + ( + """fn qux(x: bool, y: bool) -> bool { + x && y +}""", + """fn qux(x: bool, y: bool) -> bool { + x || y +}""", + ), + ], +) +def test_operation_flip_operator_modifier(src, expected): + """Test that OperationFlipOperatorModifier flips operators to their opposites.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = OperationFlipOperatorModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo(a: i32, b: i32) -> i32 { + a + b +}""", + """fn foo(a: i32, b: i32) -> i32 { + b + a +}""", + ), + ( + """fn bar(x: i32, y: i32) -> bool { + x < y +}""", + """fn bar(x: i32, y: i32) -> bool { + y < x +}""", + ), + ( + """fn baz(a: i32, b: i32) -> i32 { + a - b +}""", + """fn baz(a: i32, b: i32) -> i32 { + b - a +}""", + ), + ], +) +def test_operation_swap_operands_modifier(src, expected): + """Test that OperationSwapOperandsModifier swaps operands in binary expressions.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + 
f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = OperationSwapOperandsModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected_variants", + [ + ( + """fn foo(a: i32, b: i32, c: i32) -> i32 { + a + b + c +}""", + [ + "fn foo(a: i32, b: i32, c: i32) -> i32 {\n a\n}", + "fn foo(a: i32, b: i32, c: i32) -> i32 {\n c\n}", + ], + ), + ], +) +def test_operation_break_chains_modifier(src, expected_variants): + """Test that OperationBreakChainsModifier breaks operation chains.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = OperationBreakChainsModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert any( + result.rewrite.strip() == variant.strip() for variant in expected_variants + ), f"Expected one of {expected_variants}, got {result.rewrite}" + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected_variants", + [ + ( + """fn foo() -> i32 { + 2 + x +}""", + [ + "fn foo() -> i32 {\n 1 + x\n}", + "fn foo() -> i32 {\n 3 + x\n}", + ], + ), + ( + """fn bar() -> i32 { + y - 5 +}""", + [ + "fn bar() -> i32 {\n y - 4\n}", + "fn bar() -> i32 {\n y - 6\n}", + ], + ), + ( + """fn baz() -> i32 { + 10 * 20 +}""", + [ + "fn baz() -> i32 {\n 9 * 20\n}", + "fn baz() -> i32 {\n 11 * 20\n}", + "fn baz() -> i32 {\n 10 * 19\n}", + "fn baz() -> i32 {\n 10 * 21\n}", + "fn baz() -> i32 {\n 9 * 19\n}", + "fn baz() -> i32 {\n 9 * 21\n}", + "fn baz() -> i32 {\n 11 * 19\n}", + "fn baz() -> i32 {\n 11 * 21\n}", + ], + ), + ], +) +def test_operation_change_constants_modifier(src, expected_variants): + """Test that OperationChangeConstantsModifier changes integer constants.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = OperationChangeConstantsModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert any( + result.rewrite.strip() == variant.strip() for variant in expected_variants + ), f"Expected one of {expected_variants}, got {result.rewrite}" + finally: + os.unlink(temp_path) + + +def test_operation_flip_operator_mappings(): + """Test that OperationFlipOperatorModifier uses correct operator mappings.""" + assert FLIPPED_OPERATORS["+"] == "-" + assert FLIPPED_OPERATORS["-"] == "+" + assert FLIPPED_OPERATORS["*"] == "/" + assert FLIPPED_OPERATORS["/"] == "*" + assert FLIPPED_OPERATORS["=="] == "!=" + assert FLIPPED_OPERATORS["!="] == "==" + assert FLIPPED_OPERATORS["<"] == ">" + assert FLIPPED_OPERATORS[">"] == "<" + assert FLIPPED_OPERATORS["<="] == ">=" + assert FLIPPED_OPERATORS[">="] == "<=" + assert FLIPPED_OPERATORS["&&"] == "||" + assert FLIPPED_OPERATORS["||"] == "&&" + assert FLIPPED_OPERATORS["&"] == "|" + assert FLIPPED_OPERATORS["|"] == "&" + 
assert FLIPPED_OPERATORS["<<"] == ">>" + assert FLIPPED_OPERATORS[">>"] == "<<" diff --git a/tests/bug_gen/procedural/rust/test_rust_remove.py b/tests/bug_gen/procedural/rust/test_rust_remove.py new file mode 100644 index 00000000..0de4d7ac --- /dev/null +++ b/tests/bug_gen/procedural/rust/test_rust_remove.py @@ -0,0 +1,217 @@ +import pytest +import tempfile +import os +from swesmith.bug_gen.adapters.rust import get_entities_from_file_rs +from swesmith.bug_gen.procedural.rust.remove import ( + RemoveLoopModifier, + RemoveConditionalModifier, + RemoveAssignModifier, +) +import random + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo() -> i32 { + for i in 0..3 { + println!("{}", i); + } + return 1; +}""", + """fn foo() -> i32 { + + return 1; +}""", + ), + ( + """fn bar() -> i32 { + while true { + break; + } + return 2; +}""", + """fn bar() -> i32 { + + return 2; +}""", + ), + ( + """fn baz() -> i32 { + let mut sum = 0; + for i in 0..10 { + sum += i; + } + sum +}""", + """fn baz() -> i32 { + let mut sum = 0; + + sum +}""", + ), + ], +) +def test_remove_loop_modifier(src, expected): + """Test that RemoveLoopModifier removes loop statements.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = RemoveLoopModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo(x: i32) -> i32 { + if x > 0 { + return x; + } + return 0; +}""", + """fn foo(x: i32) -> i32 { + + return 0; +}""", + ), + ( + """fn bar(x: i32) -> i32 { + if x < 0 { + return -1; + } else { + return 1; + } +}""", + """fn bar(x: i32) -> i32 { + +}""", + ), + ( + """fn baz(x: i32) -> i32 { + let mut result = 0; + if x > 10 { + result = x * 2; + } + result +}""", + """fn baz(x: i32) -> i32 { + let mut result = 0; + + result +}""", + ), + ], +) +def test_remove_conditional_modifier(src, expected): + """Test that RemoveConditionalModifier removes conditional statements.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = RemoveConditionalModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path) + + +@pytest.mark.parametrize( + "src,expected", + [ + ( + """fn foo() -> i32 { + let x = 1; + return x; +}""", + """fn foo() -> i32 { + + return x; +}""", + ), + ( + """fn bar() -> i32 { + let mut y = 2; + y += 3; + return y; +}""", + """fn bar() -> i32 { + + y += 3; + return y; +}""", + ), + ( + """fn baz() -> i32 { + let z: i32 = 10; + z * 2 +}""", + """fn baz() -> i32 { + + z * 2 +}""", + ), + ( + """fn qux() -> i32 { + let mut a = 5; + a *= 2; + a +}""", + """fn qux() -> i32 { + + a *= 2; + a +}""", + ), + ], +) +def test_remove_assign_modifier(src, expected): + """Test that RemoveAssignModifier removes assignment statements.""" + with 
tempfile.NamedTemporaryFile(mode="w", suffix=".rs", delete=False) as f: + f.write(src) + f.flush() + temp_path = f.name + + try: + entities = [] + get_entities_from_file_rs(entities, temp_path) + assert len(entities) == 1 + + modifier = RemoveAssignModifier(likelihood=1.0, seed=42) + modifier.rand = random.Random(42) + result = modifier.modify(entities[0]) + + assert result is not None + assert result.rewrite.strip() == expected.strip(), ( + f"Expected {expected}, got {result.rewrite}" + ) + finally: + os.unlink(temp_path)
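
# Usage sketch (assumes generate.py's CLI mirrors main()'s signature shown above:
# a positional repo plus --max_bugs/--seed/--interleave flags):
#
#   python swesmith/bug_gen/procedural/generate.py <repo> --max_bugs 25 --interleave
#
# With --interleave, every (candidate, modifier) pair is built up front and
# shuffled, so a capped run yields bugs from all Rust modifiers; omitting the
# flag keeps the original per-modifier sequential order.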