This repository contains the code base for "Unsupervised Recovery of Hidden Markov Models from Transformers with Evolutionary Algorithms", a technical report from Colin Lu and Dylan Bowman for research conducted during the Computational Mechanics Hackathon organized by PIBBSS and Simplex from June 1, 2024 to June 3, 2024.
Update 6/15: This project won first prize in the hackathon.
Prior work finds that transformer neural networks trained to mimic the output of a Hidden Markov Model (HMM) embed the optimal Bayesian beliefs about the HMM's current state in their residual stream, and that these beliefs can be recovered via linear regression. In this work, we address the inverse problem: extracting information about the underlying HMM from the residual stream, without knowing the HMM's mixed-state presentation (MSP) in advance. To do so, we use evolutionary algorithms to search over candidate HMM parameters, scoring each candidate by how well a linear probe fits the transformer's activations.
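The linear-probe idea from prior work can be sketched in a few lines. This is an illustrative toy, not the repository's code: the activations are synthesized from a random linear map so that the probe fit is exact, where in practice they would come from a trained transformer's residual stream.

```python
import numpy as np

# Toy sketch: recover HMM belief states from residual-stream activations
# with an ordinary-least-squares linear probe. All data here is synthetic.
rng = np.random.default_rng(0)

d_model, n_states, n_samples = 16, 3, 500
true_W = rng.normal(size=(d_model, n_states))               # unknown linear map
beliefs = rng.dirichlet(np.ones(n_states), size=n_samples)  # points on the simplex
activations = beliefs @ np.linalg.pinv(true_W)              # stand-in residual stream

# Fit the probe: activations -> beliefs (with an intercept column).
X = np.hstack([activations, np.ones((n_samples, 1))])
W, *_ = np.linalg.lstsq(X, beliefs, rcond=None)
pred = X @ W

# R^2 of the fit; exact linear structure gives a score of ~1.0.
r2 = 1 - np.sum((beliefs - pred) ** 2) / np.sum((beliefs - beliefs.mean(0)) ** 2)
print(round(r2, 3))
```

With real transformer activations the fit is not exact, and the $R^2$ score becomes the informative quantity: how linearly decodable the belief simplex is from the residual stream.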
Assuming you already have a Python environment set up, clone the epsilon-transformers (GitHub) repository and follow its setup instructions for installation with pip. Then, clone this repository.
- The `src` folder contains two scripts, `generate_paths_and_beliefs.py` and `experiment.py`. These scripts are meant to be run sequentially.
  - `generate_paths_and_beliefs.py`: For each set of parameters in `src/msp_cfg.yaml`, generate the optimal Bayesian belief states for every possible input sequence. The output tensor is cached in a `.pt` file in `src/cached_belief_store`. `src/cached_belief_store` comes pre-populated with the beliefs from our grid search, so there is no need to run this script unless you are generating beliefs for new sets of parameters.
  - `experiment.py`: For each set of parameters in `src/msp_cfg.yaml` and for each of the pretrained models (Mess3(0.15, 0.6) and Mess3(0.05, 0.85)), train a new probe and print its $R^2$ and MSE scores. This script produces `visualization/r2.pkl`, which records how well a probe fits transformer activations to different Mess3 processes. By default, it also produces reconstruction visualizations in the `src/images` folder for each set of parameters. Flags:
    - `--no-image`: disable the reconstruction images. Each image takes over a minute to generate, so this flag speeds up probe training substantially.
    - `--device`: set the PyTorch device; probe training is substantially faster on a GPU.
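The belief states cached by `generate_paths_and_beliefs.py` are the targets the probe is trained against. Their computation is standard Bayesian filtering over an HMM; the sketch below shows the recursion for a generic 3-state HMM. The transition and emission matrices are illustrative placeholders, not the actual Mess3 matrices (those live in the epsilon-transformers package).

```python
import numpy as np

# Hedged sketch of optimal Bayesian belief updating for a generic HMM.
# T[i, j] = P(next state j | current state i); E[i, k] = P(symbol k | state i).
# These matrices are placeholders for illustration only.
T = np.array([[0.8, 0.1, 0.1],
              [0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8]])
E = np.array([[0.6, 0.2, 0.2],
              [0.2, 0.6, 0.2],
              [0.2, 0.2, 0.6]])

def belief_states(seq, T, E):
    """Return the belief state (posterior over hidden states) after each symbol."""
    belief = np.full(T.shape[0], 1.0 / T.shape[0])  # uniform prior
    out = []
    for sym in seq:
        belief = (belief @ T) * E[:, sym]  # predict with T, condition on the symbol
        belief /= belief.sum()             # normalize (Bayes' rule)
        out.append(belief.copy())
    return np.stack(out)

beliefs = belief_states([0, 0, 1, 2], T, E)
print(beliefs.shape)  # (4, 3); each row is a distribution over hidden states
```

Enumerating this recursion over every input sequence up to the context length yields the belief tensor that the script caches as a `.pt` file.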
- The `visualization` folder contains two IPython notebooks, `MSP_visualization.ipynb` and `r2_visualization.ipynb`.
  - `MSP_visualization.ipynb`: The primary section, titled Mess3 MSP: Chaos Game Fractal, plots Mess3 MSPs for various alpha and x values. See the examples for usage. The code also supports arbitrary 3-state HMMs defined by `emission_and_transition_pi`.
  - `r2_visualization.ipynb`: Plots (via color) the $R^2$ values measuring how well the linear fit maps transformer activations to Mess3 MSP geometries, for different values of alpha and x. Observe that $R^2$ is approximately unimodal, and is maximized close to the actual Mess3 parameter values that generated the transformer's training sequences.
- `evo_alg_demo.ipynb` runs evolutionary search over parameter sets, using probe $R^2$ as a reward signal.
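The evolutionary search can be sketched as a simple elitist mutation loop. This is a toy stand-in, not the notebook's implementation: a quadratic fitness peaked at hypothetical "true" parameters (0.15, 0.6) substitutes for the probe's $R^2$ score, which in the real pipeline requires generating beliefs and training a probe per candidate.

```python
import numpy as np

# Minimal sketch of evolutionary search over (alpha, x) parameter sets.
# A toy quadratic fitness stands in for the probe R^2 reward signal.
rng = np.random.default_rng(0)
TRUE = np.array([0.15, 0.6])  # hypothetical true Mess3 parameters

def fitness(params):
    return -np.sum((params - TRUE) ** 2)  # stand-in for probe R^2

pop = rng.uniform(0.0, 1.0, size=(20, 2))  # initial random population
for gen in range(50):
    scores = np.array([fitness(p) for p in pop])
    elite = pop[np.argsort(scores)[-5:]]                     # keep the best 5
    children = elite[rng.integers(0, 5, 15)]                 # sample parents
    children += rng.normal(0.0, 0.05, size=(15, 2))          # Gaussian mutation
    pop = np.vstack([elite, np.clip(children, 0.0, 1.0)])    # elitist replacement

best = pop[np.argmax([fitness(p) for p in pop])]
print(np.round(best, 2))  # converges near TRUE
```

Because the $R^2$ landscape over Mess3 parameters is approximately unimodal (see `r2_visualization.ipynb`), a local search of this kind can recover the generating parameters without gradient access to the transformer.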