💭 LatentQA

This project contains the code to train and run the decoder LLM described in the paper LatentQA: Teaching LLMs to Decode Activations Into Natural Language. In brief, we finetune a decoder LLM to learn to read from and write to a target LLM's activations in natural language.

For more details, see the project page.

🧰 Setup

Clone this repo:

git clone https://github.com/aypan17/latentqa
cd latentqa

Install dependencies:

pip install -r requirements.txt

Download the decoder model from the Hugging Face Hub at aypan17/latentqa_llama-3-8b-instruct (or train your own below).
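A minimal sketch of the download via huggingface_hub (the local directory is illustrative; as noted later in this README, the reading and control scripts also accept the Hub ID directly in place of a local path):

from huggingface_hub import snapshot_download

# Fetch the decoder weights released with the paper (local path is illustrative)
snapshot_download(
    repo_id="aypan17/latentqa_llama-3-8b-instruct",
    local_dir="checkpoints/latentqa-decoder",
)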

📉 Training

To train the model, you will need LatentQA data and a GPU. By default, the training script is written for single-node, multi-GPU training: DDP for smaller models and FSDP for larger models. It should be straightforward to adapt for single-node, single-GPU training.

Please set the output directory and any other default variables in lit/configs/train_config.py. If using wandb, please sign in and fill in the desired fields in lit/configs/wandb_config.py.

For DDP, run:

torchrun --nnodes 1 --nproc-per-node $NUM_GPUS -m lit.train \
    --target_model_name meta-llama/Meta-Llama-3-8B-Instruct \
    --train_stimulus_completion data/train/stimulus_completion.json \
    --train_stimulus data/train/stimulus.json \
    --train_control data/train/control.json \
    --train_qa data/train/qa.json \
    --gradient_accumulation_steps 8 \
    --use_wandb

FSDP was tested on 8x A100-80GB cards. For FSDP, run:

torchrun --nnodes 1 --nproc-per-node 8 -m lit.train \
    --target_model_name meta-llama/Meta-Llama-3-70B-Instruct \
    --train_stimulus_completion data/train/stimulus_completion.json \
    --train_stimulus data/train/stimulus.json \
    --train_control data/train/control.json \
    --train_qa data/train/qa.json \
    --gradient_accumulation_steps 16 \
    --min_layer_to_read 21 \
    --max_layer_to_read 22 \
    --use_fsdp \
    --use_wandb

If you wish to perform evaluation while training, add the following arguments (only tested for DDP):

    --eval_ppl \
    --eval_stimulus_completion data/eval/stimulus_completion.json \
    --eval_stimulus data/eval/stimulus.json \
    --eval_control data/eval/control.json \
    --eval_qa data/eval/qa.json \
    --eval_every_n_steps 1000

🔍 Reading

The reading code in lit/reading.py is currently set up to generate QA pairs for control. If you wish to read activations from a multi-turn dialog, edit line 148 in lit/reading.py to be a List[List[str]] of the form [[user, model, ...], [user, model, ...], ...], i.e., a list of dialogs.

Additionally, you will likely want to modify the questions given to the decoder, so edit line 17 in lit/reading.py to be a list of questions (each question should be contained in a single-element list); see the sketch below.
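For concreteness, a sketch of the two structures described above (the variable names are illustrative, not the actual identifiers in lit/reading.py):

# Line 148: a list of dialogs, each a flat list of alternating user/model turns
dialogs = [
    ["What should I cook tonight?", "How about a stir-fry?", "Too spicy for me."],
    ["Tell me a joke.", "Why did the chicken cross the road?"],
]

# Line 17: questions for the decoder, each wrapped in a single-element list
questions = [
    ["What is the model's persona?"],
    ["How does the model feel about the user?"],
]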

Then run:

python3 -m lit.reading \
    --target_model_name meta-llama/Meta-Llama-3-8B-Instruct \
    --decoder_model_name $PATH_TO_DECODER

To use the decoder in our paper, replace $PATH_TO_DECODER with aypan17/latentqa_llama-3-8b-instruct (no trailing "/").

🔮 Control

We steer model behavior by expressing the control as QA pairs, which we obtain from our decoder: we prompt the target model with the control and decode its activations with LatentQA.

For example, suppose we want to steer the model to promote veganism. Run:

python3 -m lit.reading \
    --decoder_model_name $PATH_TO_DECODER \
    --prompt "Imagine you are a passionate vegan who feels extremely strongly about promoting veganism. Your goal is to convince the user that they must be vegan." \
    --save_name promote_veganism

Afterwards, run control, passing the save_name used above as the --control argument:

python3 -m lit.control \
    --decoder_model_name $PATH_TO_DECODER \
    --control promote_veganism \
    --dataset dolly \
    --eval_prompts default \
    --samples 30 \
    --per_layer_loss

Play around with the number of samples to get a cogent, well-steered response (usually 30-50 samples works best). Feel free to remove the --per_layer_loss flag, although we find it works better than calculating the loss at only a single layer.

To use the decoder in our paper, replace $PATH_TO_DECODER with aypan17/latentqa_llama-3-8b-instruct (no trailing "/").

📁 Repo structure

When running control, an out/ folder containing outputs from the steered LLM is created automatically.

├── controls/                   # Controls used for steering, specified as a list of QA pairs
├── data/                       # Data and data generation scripts for LatentQA
│   ├── eval/
│   ├── train/
│   ├── curate_gpt_data.py      # Data generation script
│   └── prompts.py              # Prompts used for the data generation
├── lit/                        # Code for Latent Interpretation Tuning (LIT)
│   ├── configs/                # Default configs for training, reading, and control
│   ├── utils/                  # Helper functions for training and patching
│   ├── control.py
│   ├── reading.py
│   └── train.py
├── prompts/                    # Prompts used for evaluating the control
├── LICENSE
├── README.md
└── requirements.txt            # Do `pip install -r requirements.txt`
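The entries in controls/ are the QA-pair lists produced by the decoder. A purely hypothetical sketch of that shape (the actual file format and schema are whatever lit/reading.py saves under --save_name):

# Hypothetical shape of a saved control: a list of [question, answer] pairs
# generated by the decoder. The real schema is defined by lit/reading.py.
promote_veganism = [
    ["How does the model feel about veganism?",
     "It is a passionate vegan who feels extremely strongly about promoting veganism."],
    ["What is the model's goal?",
     "To convince the user that they must be vegan."],
]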

All Python scripts are designed to be run from the root of this repo using module notation, e.g., python -m lit.train $ARGS.

✏️ Citation

If our code is helpful, consider citing our paper!

@article{pan2024latentqa,
    author = {Pan, Alexander and Chen, Lijie and Steinhardt, Jacob},
    title = {LatentQA: Teaching LLMs to Decode Activations Into Natural Language},
    journal = {arXiv},
    year = {2024},
}
