
Official implementation of "Learning to Better Search with Language Models via Guided Reinforced Self-Training" (NeurIPS 2025)


# Guided-ReST

This is the official implementation of the paper *Learning to Better Search with Language Models via Guided Reinforced Self-Training*, accepted to NeurIPS 2025.

## Setup

```bash
# conda
conda create --name guided-rest python=3.12
conda activate guided-rest

# uv
pip install uv

# vllm
uv pip install "vllm[flashinfer]" --torch-backend=cu128

# verl
uv pip install -e ".[gpu]" --no-build-isolation --torch-backend=cu128

# fix packages
uv pip uninstall pynvml
```

## Countdown

1. Download the SFT and RL datasets:

```bash
python -m recipe.countdown.download_data
```

2. Download the model and tokenizer:

```bash
python -m recipe.countdown.download_model
```

3. Remove `trim` in the chat template to avoid inconsistent encoding.
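Step 3 amounts to editing the chat template shipped with the tokenizer. A minimal sketch of the idea, assuming the template is a Jinja string (in practice it lives in the downloaded `tokenizer_config.json`; the example template below is hypothetical, not the actual Llama template):

```python
import re

# Hypothetical example: strip the "| trim" Jinja filter from a chat-template
# string so that encoding is consistent between training and generation.
template = "{{ '<|user|>' + message['content'] | trim + '<|end|>' }}"
cleaned = re.sub(r"\s*\|\s*trim\b", "", template)
print(cleaned)
```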

4. Train the base model:

```bash
sh recipe/countdown/scripts/llama_3.2_1b/base/run_sft.sh
sh recipe/countdown/scripts/run_merge.sh model_name=llama_3.2_1b_base_sft/global_step_3906
```
5. Run Guided-ReST:

```bash
# Generate trajectories
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_base_sft/global_step_3906 temperature=1.0 num_iters=3 split=train start=0 num_examples=200000
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_base_sft/global_step_3906 temperature=1.0 num_iters=3 split=valid start=0 num_examples=1000

# Prepare data
sh recipe/countdown/scripts/run_data.sh model_name=llama_3.2_1b_base_sft/global_step_3906 temperature=1.0 num_iters=3 split=train
sh recipe/countdown/scripts/run_data.sh model_name=llama_3.2_1b_base_sft/global_step_3906 temperature=1.0 num_iters=3 split=valid

# Run SFT
sh recipe/countdown/scripts/llama_3.2_1b/guided_rest/run_sft_1.sh
sh recipe/countdown/scripts/run_merge.sh model_name=llama_3.2_1b_guided_rest_sft_1/global_step_1546

# Repeat the above steps for 3 iterations
```
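The iteration structure can be sketched as below. The script names for iterations 2-3 and the per-iteration checkpoint step (`STEP`) are assumptions extrapolated from the iteration-1 commands; trajectory generation and data preparation, which each iteration reruns with the latest merged model, are elided for brevity:

```python
# Sketch of the three Guided-ReST SFT/merge rounds (names for rounds 2-3 and
# STEP are hypothetical; fill in the actual global step of each checkpoint).
cmds = []
for i in (1, 2, 3):
    cmds.append(f"sh recipe/countdown/scripts/llama_3.2_1b/guided_rest/run_sft_{i}.sh")
    cmds.append(
        f"sh recipe/countdown/scripts/run_merge.sh "
        f"model_name=llama_3.2_1b_guided_rest_sft_{i}/global_step_STEP"
    )
print("\n".join(cmds))
```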
6. Run PPO:

```bash
sh recipe/countdown/scripts/llama_3.2_1b/guided_rest/run_rl.sh
sh recipe/countdown/scripts/run_merge.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor
```
7. Run evaluation:

- Greedy sampling

```bash
# Generate trajectories
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=0.0 num_iters=0 split=test_seen start=0 num_examples=10000
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=0.0 num_iters=0 split=test_unseen start=0 num_examples=10000

# Compute accuracy
sh recipe/countdown/scripts/run_eval.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=0.0 num_iters=0 split=test_seen
sh recipe/countdown/scripts/run_eval.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=0.0 num_iters=0 split=test_unseen
```
- Random sampling

```bash
# Generate trajectories with seeds from 0 to 32
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=1.0 num_iters=0 split=test_seen start=0 num_examples=10000 seed=[seed]
sh recipe/countdown/scripts/run_gen.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=1.0 num_iters=0 split=test_unseen start=0 num_examples=10000 seed=[seed]

# Compute accuracy
sh recipe/countdown/scripts/run_eval.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=1.0 num_iters=0 split=test_seen
sh recipe/countdown/scripts/run_eval.sh model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor temperature=1.0 num_iters=0 split=test_unseen
```
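The per-seed generation runs can be enumerated with a small loop rather than typed by hand. A sketch that prints one `run_gen.sh` command per seed and split (assuming "seeds from 0 to 32" is inclusive, i.e. 33 seeds; pipe the output to `sh` or dispatch it across machines):

```python
# Enumerate the random-sampling evaluation commands, one per (split, seed).
base = (
    "sh recipe/countdown/scripts/run_gen.sh "
    "model_name=llama_3.2_1b_guided_rest_rl/global_step_390/actor "
    "temperature=1.0 num_iters=0 split={split} start=0 num_examples=10000 seed={seed}"
)
commands = [
    base.format(split=split, seed=seed)
    for split in ("test_seen", "test_unseen")
    for seed in range(33)  # assumes inclusive seed range 0..32
]
print("\n".join(commands))
```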

## Code self-repair

1. Download the SFT and RL datasets:

```bash
python -m recipe.code_repair.download_data
```

2. Download the model and tokenizer:

```bash
python -m recipe.code_repair.download_model
```
3. Run Guided-ReST:

```bash
# Generate trajectories with seeds 0 to 8
sh recipe/code_repair/scripts/run_gen.sh model_name=qwen2.5_7b temperature=1.0 num_turns=4 num_iters=3 split=train start=0 num_examples=16000 seed=[seed]
sh recipe/code_repair/scripts/run_gen.sh model_name=qwen2.5_7b temperature=1.0 num_turns=4 num_iters=3 split=valid start=0 num_examples=300 seed=[seed]

# Prepare data
sh recipe/code_repair/scripts/run_data.sh model_name=qwen2.5_7b temperature=1.0 num_iters=3 split=train
sh recipe/code_repair/scripts/run_data.sh model_name=qwen2.5_7b temperature=1.0 num_iters=3 split=valid

# Run SFT
sh recipe/code_repair/scripts/qwen2.5_7b/guided_rest/run_sft_1.sh
sh recipe/code_repair/scripts/run_merge.sh model_name=qwen2.5_7b_guided_rest_sft_1/global_step_348

# Repeat the above steps for 3 iterations
```
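As with Countdown, the per-seed generation commands can be enumerated programmatically. A sketch for this task's parameters (assuming "seeds 0 to 8" is inclusive, i.e. 9 seeds):

```python
# Enumerate the code self-repair trajectory-generation commands, one per
# (split, seed), with the example counts used above.
base = (
    "sh recipe/code_repair/scripts/run_gen.sh model_name=qwen2.5_7b "
    "temperature=1.0 num_turns=4 num_iters=3 split={split} start=0 "
    "num_examples={n} seed={seed}"
)
commands = [
    base.format(split=split, n=n, seed=seed)
    for split, n in (("train", 16000), ("valid", 300))
    for seed in range(9)  # assumes inclusive seed range 0..8
]
print("\n".join(commands))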
4. Run evaluation:

```bash
# Generate trajectories with seeds from 0 to 128
sh recipe/code_repair/scripts/run_gen.sh model_name=qwen2.5_7b_guided_rest_sft_3/global_step_618 temperature=1.0 num_turns=4 num_iters=0 split=test_cc start=0 num_examples=200 seed=[seed]
sh recipe/code_repair/scripts/run_gen.sh model_name=qwen2.5_7b_guided_rest_sft_3/global_step_618 temperature=1.0 num_turns=4 num_iters=0 split=test_cf start=0 num_examples=500 seed=[seed]

# Compute accuracy
sh recipe/code_repair/scripts/run_eval.sh model_name=qwen2.5_7b_guided_rest_sft_3/global_step_618 temperature=1.0 num_turns=4 num_iters=0 split=test_cc
sh recipe/code_repair/scripts/run_eval.sh model_name=qwen2.5_7b_guided_rest_sft_3/global_step_618 temperature=1.0 num_turns=4 num_iters=0 split=test_cf
```

## Note

- We recommend generating trajectories in parallel across multiple GPUs.
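One way to parallelize, sketched below: shard the example range across GPUs by splitting the `start`/`num_examples` arguments, pinning each process to one GPU via `CUDA_VISIBLE_DEVICES`. The 4-GPU count is an assumption; adapt it to your machine and verify that one `run_gen.sh` process per GPU fits your setup:

```python
# Split the 200k Countdown training examples into equal shards, one backgrounded
# run_gen.sh process per GPU (counts and GPU assignment are illustrative).
num_gpus, total = 4, 200_000
shard = total // num_gpus
cmds = [
    f"CUDA_VISIBLE_DEVICES={gpu} sh recipe/countdown/scripts/run_gen.sh "
    f"model_name=llama_3.2_1b_base_sft/global_step_3906 temperature=1.0 "
    f"num_iters=3 split=train start={gpu * shard} num_examples={shard} &"
    for gpu in range(num_gpus)
]
print("\n".join(cmds))
```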

## Citation

If you find this code useful, please consider citing this work.

```bibtex
@inproceedings{moon2025learning,
    title={Learning to Better Search with Language Models via Guided Reinforced Self-Training},
    author={Seungyong Moon and Bumsoo Park and Hyun Oh Song},
    booktitle={Neural Information Processing Systems},
    year={2025}
}
```
