[License](LICENSE) · [Python](https://www.python.org/) · [PyTorch](https://pytorch.org/)

Generate Explanations for BERT Predictions on Structured Electronic Health Record Data
======================================================================

Recent breakthroughs in large language models are increasingly being applied to structured electronic health records (EHR).

**To explain clinical BERT model predictions, we present an approach that leverages integrated gradients to attribute an outcome prediction to the events in a patient's medical record that drive it.**
<div style="display: flex; justify-content: space-around; align-items: center;">
  <img src="figures/patient_explained.png" alt="Patient Explained" style="width: 30%;"/>
  <img src="figures/lab_markers_explained.png" alt="Lab Markers Explained" style="width: 30%;"/>
  <img src="figures/clinical_tokens_explained.png" alt="Clinical Tokens Explained" style="width: 30%;"/>
</div>

The explainability approach we have developed can be applied to many diseases and prediction tasks using language models trained on structured electronic health records.
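
For orientation, the core attribution step can be sketched with Captum's `LayerIntegratedGradients`, the same machinery that underlies `transformers-interpret`. Everything below (the checkpoint path, the token strings, the all-padding baseline, and the positive-class logit target) is an illustrative assumption rather than the exact code in this repository:

```
import torch
from captum.attr import LayerIntegratedGradients
from transformers import BertForSequenceClassification, BertTokenizerFast

# Hypothetical fine-tuned checkpoint; the real model is trained on EHR event tokens.
model = BertForSequenceClassification.from_pretrained("./bert_finetuning_asthma_model")
tokenizer = BertTokenizerFast.from_pretrained("./bert_finetuning_asthma_model")
model.eval()

def forward_logits(input_ids, attention_mask):
    # Attribute the logit of the positive class (index 1).
    return model(input_ids=input_ids, attention_mask=attention_mask).logits[:, 1]

# Encode one patient's event sequence (made-up tokens for illustration).
encoding = tokenizer(
    ["AGE_65 GENDER_F DX_J45 LAB_EOSINOPHILS_P90"],
    return_tensors="pt", padding="max_length", truncation=True, max_length=512,
)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

# A simple all-[PAD] reference sequence serves as the baseline in this sketch.
baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)

lig = LayerIntegratedGradients(forward_logits, model.bert.embeddings)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_mask,),
    internal_batch_size=16,
    return_convergence_delta=True,
)
# Sum over the embedding dimension to get one attribution score per token position.
token_scores = attributions.sum(dim=-1).squeeze(0)
```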

_ℹ️ This repository was created to complement the manuscript "Predicting Progression and Key Drivers of Asthma with a Clinical BERT model and Integrated Gradients", which is available here: [coming soon]()_

## [Pre-requisite] Training a MEDBERT Model
The explainability pipeline requires a BERT-based model trained on structured EHR data and fine-tuned for the specific disease prediction task.
The pre-training procedure most closely follows the method described in the [TransformEHR](https://www.nature.com/articles/s41467-023-43715-z) paper,
and the fine-tuning is primarily based on the approach used in the [Med-BERT](https://doi.org/10.1038/s41746-021-00455-y) model.
We also provide a sample input file and a sample output file in `data/dummy_data.parquet` and `output/output.parquet`, respectively; these demonstrate the format you can expect if you choose to adopt this code. The `output.parquet` dataset can be further post-processed to aggregate the top tokens for explainability.
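
For rough orientation only, fine-tuning a pre-trained EHR BERT checkpoint for a binary prediction head could look like the following with the Hugging Face `transformers` API. The checkpoint path and the tiny synthetic dataset are placeholders; the actual pre-training and fine-tuning code for the manuscript is not part of this repository:

```
from datasets import Dataset
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

# Placeholder: a MEDBERT-style checkpoint pre-trained on structured EHR token sequences.
model = BertForSequenceClassification.from_pretrained("./medbert_pretrained", num_labels=2)

# Tiny synthetic stand-in for a tokenized EHR dataset (input_ids / attention_mask / labels).
dummy = {"input_ids": [[2, 5, 7, 3] * 8], "attention_mask": [[1] * 32], "labels": [1]}
train_ds = eval_ds = Dataset.from_dict(dummy)

args = TrainingArguments(
    output_dir="./bert_finetuning_asthma_model",
    per_device_train_batch_size=32,
    num_train_epochs=3,
    learning_rate=2e-5,
)

trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
trainer.train()
trainer.save_model()
```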

## Install
```
# Install conda environment
conda create -n bert-explainability python=3.10 -y
conda activate bert-explainability

# Install dependencies
pip install -r requirements.txt
export PYTHONPATH=$PYTHONPATH:./src:./transformers_interpret
```

## Sample Script for Running the Pipeline
```
python3 -m src.explainability.explainability './config/explainability_asthma.yaml' './bert_finetuning_asthma_model.tar.gz' './data/' './output/'
```
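
The four positional arguments are, in order, the explainability config, the fine-tuned model archive, the input data directory, and the output directory.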

## Input data format
A small dummy dataset is provided in `data/*`, and a walkthrough is available in `notebooks/example_walkthrough.ipynb`. The input data is expected to be Parquet files, stored locally or on S3, with the following schema: person_id (int), sorted_event_tokens (array<string>), day_position_tokens (array<int>), plus a label column. A sketch of writing a conforming file follows the field descriptions below.
- person_id: A unique identifier for each individual.
- day_position_tokens: An array representing the relative time (in days) of events, with 0 indicating demographic tokens.
- sorted_event_tokens: A list of event codes associated with the individual. Each event corresponds to the relative date indicated by its index in the day_position_tokens array.
  - The first five tokens are always assumed to be demographic tokens, in the order of age, ethnicity, gender, race, and region.
- label: Label for the patient for a specific prediction task.
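
As a sketch of this schema, a conforming Parquet file could be written with pandas as shown here; the token strings and naming conventions are made up for illustration:

```
import pandas as pd

# One synthetic patient. The first five tokens are demographic tokens
# (age, ethnicity, gender, race, region), followed by dated clinical events.
row = {
    "person_id": 1,
    "sorted_event_tokens": [
        "AGE_65", "ETHNICITY_NOT_HISPANIC", "GENDER_F", "RACE_WHITE", "REGION_NE",
        "DX_J45", "LAB_EOSINOPHILS_P90", "RX_ALBUTEROL",
    ],
    # 0 marks demographic tokens; later values are days relative to the start of the record.
    "day_position_tokens": [0, 0, 0, 0, 0, 12, 12, 40],
    "label": 1,
}

df = pd.DataFrame([row])
df.to_parquet("data/dummy_data.parquet", index=False)
```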

## Config
The config can be found in `config/explainability_asthma.yaml` and can be modified for multiclass prediction.
It contains the following parameters; a minimal loading sketch follows the list:
- model_max_len: Maximum token length for the model (e.g. 512).
- training_label: Name of the label column.
- internal_batch_size: Batch size used internally when computing attributions.
- demographic_token_starters: Prefixes of tokens that belong to demographic categories.
- avg_token_type_baseline: A boolean flag that determines whether the baseline for lab test tokens with percentile information is the average percentile (True) or the 5th percentile (False).

If doing multiclass predictions:
- num_labels: Number of classes.
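
A minimal sketch of what such a config might contain and how it can be loaded; the values shown are illustrative, not the shipped defaults:

```
import yaml

example_config = """
model_max_len: 512
training_label: label
internal_batch_size: 16
demographic_token_starters: ["AGE_", "ETHNICITY_", "GENDER_", "RACE_", "REGION_"]
avg_token_type_baseline: true
# Only needed for multiclass prediction:
# num_labels: 3
"""

config = yaml.safe_load(example_config)
print(config["model_max_len"], config["training_label"])
```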

## Output data format
The output is stored as an `output.parquet` file in the output directory specified on the command line.
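
The exact output schema depends on the pipeline configuration, so the column names used below (`token`, `attribution`) are assumptions for illustration only. A typical post-processing step aggregates per-token attributions into the top drivers across patients:

```
import pandas as pd

df = pd.read_parquet("output/output.parquet")

# Hypothetical columns: one attribution score per (person_id, token) pair.
top_tokens = (
    df.groupby("token")["attribution"]
    .mean()
    .sort_values(ascending=False)
    .head(20)
)
print(top_tokens)
```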

## Contacts
For any inquiries, please raise a GitHub issue and we will try to follow up in a timely manner.

## License
This work is available for academic research and non-commercial use only. See the _LICENSE_ file for details.

## Acknowledgements
This package utilizes functions from [transformers-interpret](https://github.com/cdpierse/transformers-interpret). All utilized functions are located in the `transformers_interpret/` subdirectory and are licensed under the Apache License Version 2.0.