Commit 29a3acf

initial repo setup - happy open sourcing!
1 parent 60bcf12 commit 29a3acf

29 files changed

+4725 -1 lines changed

LICENSE.txt

+11
@@ -0,0 +1,11 @@
+Non Commercial License Notice:
+
+Copyright Sanofi 2024
+
+Permission is hereby granted, free of charge, for academic research purposes only and for non-commercial uses only, to any person from academic research or non-profit organizations obtaining a copy of this software and associated documentation files (the "Software"), to use, copy, modify, or merge the Software, subject to the following conditions: this permission notice shall be included in all copies of the Software or of substantial portions of the Software.
+
+For purposes of this license, “non-commercial use” excludes uses foreseeably resulting in a commercial benefit. To use this software for other purposes (such as the development of a commercial product, including but not limited to software, service, or pharmaceuticals, or in a collaboration with a private company), please contact SANOFI at [email protected].
+
+All other rights are reserved. The Software is provided “as is”, without warranty of any kind, express or implied, including the warranties of noninfringement.
+
+The Software is registered.

README.md

+74-1
@@ -1 +1,74 @@
-# ExplainClinicalBERT
+[![License](https://img.shields.io/badge/License-Academic%20Non--Commercial-blue.svg)](LICENSE)
+[![Python](https://img.shields.io/badge/Python-3.10%2B-blue.svg)](https://www.python.org/)
+[![PyTorch](https://img.shields.io/badge/Pytorch-2.2-orange.svg)](https://pytorch.org/)
+
+Generate Explanations for BERT Predictions on Structured Electronic Health Record Data
+======================================================================
+[](LICENSE)
+
+Recent breakthroughs in large language models have been reapplied to structured electronic health records (EHR).
+
+**To explain clinical BERT model predictions, we present an approach that uses integrated gradients to attribute an outcome prediction to the events in a medical record that lead to it.**
+<div style="display: flex; justify-content: space-around; align-items: center;">
+<img src="figures/patient_explained.png" alt="Patient Explained" style="width: 30%;"/>
+<img src="figures/lab_markers_explained.png" alt="Lab Markers Explained" style="width: 30%;"/>
+<img src="figures/clinical_tokens_explained.png" alt="Clinical Tokens Explained" style="width: 30%;"/>
+</div>
+
+The explainability approach we have developed can be applied to many diseases and prediction tasks using language models trained on structured electronic health records.
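For readers unfamiliar with integrated gradients, here is a minimal sketch of the idea on a toy logistic model. It is not the repository's implementation (which, per the Acknowledgements, builds on `transformers-interpret`); the weights, inputs, zero baseline, and all function names below are hypothetical. Each feature's attribution is its distance from the baseline multiplied by the average gradient along the straight path from baseline to input.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, w, b):
    """Toy stand-in for a classifier: logistic regression on a feature vector."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def model_grad(x, w, b):
    """Analytic gradient of the toy model output with respect to x."""
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, b, steps=200):
    """Approximate the path integral with a midpoint Riemann sum."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, baseline)]
        grad = model_grad(point, w, b)
        total = [t + g for t, g in zip(total, grad)]
    return [(xi - bi) * t / steps for xi, bi, t in zip(x, baseline, total)]


# Hypothetical weights, input, and baseline (assumptions for illustration only).
w = [1.5, -2.0, 0.5]
b = 0.1
x = [1.0, 0.5, 2.0]
baseline = [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline, w, b)
gap = model(x, w, b) - model(baseline, w, b)
print(attributions, sum(attributions), gap)
```

The attributions should sum to approximately the difference between the model output at the input and at the baseline (the completeness property), which gives a quick sanity check on any implementation.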
+
+_ℹ️ This repository was created to complement the manuscript "Predicting Progression and Key Drivers of Asthma with a Clinical BERT model and Integrated Gradients" which is available here: [coming soon]()_
+
+## [Pre-requisite] Training a MEDBERT Model
+The explainability pipeline requires a BERT-based model trained on structured EHR data and fine-tuned for the specific disease prediction task.
+The pre-training procedure most closely follows the method described in the [TransformEHR](https://www.nature.com/articles/s41467-023-43715-z) paper,
+and the fine-tuning is primarily based on the approach used in the [Med-BERT](https://doi.org/10.1038/s41746-021-00455-y) model.
+Sample input and output data are provided in `data/dummy_data.parquet` and `output/output.parquet`, respectively. They show the format to expect if you choose to adopt this code. The `output.parquet` dataset can be further post-processed to aggregate the top tokens for explainability.
+
+## Install
+```
+# Install conda environment
+conda create -n bert-explainability python=3.10 -y
+conda activate bert-explainability
+
+# Install dependencies
+pip install -r requirements.txt
+export PYTHONPATH=$PYTHONPATH:./src:./transformers_interpret
+```
+
+## Sample Script for Running the Pipeline
+```
+python3 -m src.explainability.explainability './config/explainability_asthma.yaml' './bert_finetuning_asthma_model.tar.gz' './data/' './output/'
+```
+
+## Input data format
+A small dummy dataset is provided in `data/*` and `notebooks/example_walkthrough.ipynb`. The pipeline expects parquet files, stored locally or on S3, with the following schema: person_id (int), sorted_event_tokens (array<string>), day_position_tokens (array<int>).
+- person_id: A unique identifier for each individual.
+- day_position_tokens: An array representing the relative time (in days) of events, with 0 indicating demographic tokens.
+- sorted_event_tokens: A list of event codes associated with the individual. The event at a given index occurred on the relative day stored at the same index of the day_position_tokens array.
+- The first five tokens are always assumed to be demographic tokens, in the order of age, ethnicity, gender, race, and region.
+- label: Label for the patient for a specific prediction task.
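The schema rules above can be checked before running the pipeline. The following is a minimal sketch using plain dictionaries in place of parquet rows; the function name, the example token strings (such as `AGE:40-49` or `DX:J45`), and the assumption that demographic tokens start with the prefixes listed under `demographic_token_starters` are illustrative and not taken from the repository code.

```python
# Expected prefix order of the first five tokens, as described in the README.
DEMOGRAPHIC_ORDER = ["AGE", "ETHNICITY", "GENDER", "RACE", "REGION"]


def validate_record(record):
    """Return a list of problems found in one input row (empty if none)."""
    problems = []
    tokens = record.get("sorted_event_tokens", [])
    days = record.get("day_position_tokens", [])
    if not isinstance(record.get("person_id"), int):
        problems.append("person_id must be an int")
    if len(tokens) != len(days):
        problems.append("token and day arrays differ in length")
    if len(tokens) < len(DEMOGRAPHIC_ORDER):
        problems.append("fewer than five tokens")
        return problems
    for i, prefix in enumerate(DEMOGRAPHIC_ORDER):
        if not tokens[i].startswith(prefix):
            problems.append(f"token {i} should start with {prefix}")
        if i < len(days) and days[i] != 0:
            problems.append(f"demographic token {i} should have day position 0")
    return problems


# Hypothetical rows; token spellings are made up for illustration.
good_record = {
    "person_id": 1,
    "sorted_event_tokens": [
        "AGE:40-49", "ETHNICITY:unknown", "GENDER:F", "RACE:unknown",
        "REGION:south", "DX:J45", "RX:albuterol",
    ],
    "day_position_tokens": [0, 0, 0, 0, 0, 1, 15],
    "label": 1,
}
bad_record = dict(good_record, day_position_tokens=[0, 0, 0, 0, 0, 1])

print(validate_record(good_record))
print(validate_record(bad_record))
```

A check like this catches misaligned arrays early, since the two arrays are only meaningful when read index by index.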
+
+## Config
+The config can be found in `config/explainability_asthma.yaml`. It can be modified for multiclass prediction purposes.
+It contains the following parameters:
+- model_max_len: Maximum token length for the model (e.g. 512).
+- training_label: Name of the label column.
+- internal_batch_size: Batch size used during processing.
+- demographic_token_starters: Prefixes of tokens that belong to demographic categories.
+- avg_token_type_baseline: A boolean flag that determines if the baseline for lab test tokens with percentile information is the average percentile (True) or the 5th percentile (False).
+
+If doing multiclass predictions:
+- num_labels: Number of classes.
+
+## Output data format
+The output is stored as an `output.parquet` file in the directory specified.
+
+## Contacts
+For any inquiries, please raise an issue in this repository and we will try to follow up in a timely manner.
+
+## License
+This work is available for academic research and non-commercial use only. See the _LICENSE_ file for details.
+
+## Acknowledgements
+This package utilizes functions from [transformers-interpret](https://github.com/cdpierse/transformers-interpret). All utilized functions are located in the `transformers_interpret/` subdirectory and are licensed under the Apache License Version 2.0.

config/config.py

+2
@@ -0,0 +1,2 @@
+import os
+PROJECT_DIR = os.path.split(os.path.split(__file__)[0])[0]

config/explainability_asthma.yaml

+13
@@ -0,0 +1,13 @@
+model_max_len: 512
+training_label: "label"
+internal_batch_size: 16
+demographic_token_starters:
+- REGION
+- GENDER
+- ETHNICITY
+- AGE
+- RACE
+avg_token_type_baseline: false
+
+# if doing multiclass predictions:
+# num_labels: 3
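One way to consume this config safely is to load it into a typed object that rejects obviously invalid values. The sketch below is an assumption about how a caller might do that, not code from this commit: the `ExplainabilityConfig` class and its `is_demographic` helper are hypothetical, and the `raw` dictionary stands in for what a YAML loader (for example PyYAML's `yaml.safe_load`) would return for the file above.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ExplainabilityConfig:
    """Hypothetical typed view of config/explainability_asthma.yaml."""

    model_max_len: int
    training_label: str
    internal_batch_size: int
    demographic_token_starters: List[str]
    avg_token_type_baseline: bool = False
    num_labels: Optional[int] = None  # only set for multiclass prediction

    def __post_init__(self):
        # Basic sanity checks; the exact constraints are assumptions.
        if self.model_max_len <= 0 or self.internal_batch_size <= 0:
            raise ValueError("model_max_len and internal_batch_size must be positive")
        if not self.demographic_token_starters:
            raise ValueError("demographic_token_starters must not be empty")

    def is_demographic(self, token: str) -> bool:
        """True if the token starts with one of the configured prefixes."""
        return token.startswith(tuple(self.demographic_token_starters))


# Values copied from the YAML above, written as the dict a loader would produce.
raw = {
    "model_max_len": 512,
    "training_label": "label",
    "internal_batch_size": 16,
    "demographic_token_starters": ["REGION", "GENDER", "ETHNICITY", "AGE", "RACE"],
    "avg_token_type_baseline": False,
}
config = ExplainabilityConfig(**raw)
print(config)
```

Because `num_labels` defaults to `None`, the same class covers both the binary setup shown here and the commented-out multiclass variant.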

data/dummy_data.parquet

4.56 KB
Binary file not shown.

figures/clinical_tokens_explained.png

341 KB

figures/lab_markers_explained.png

352 KB

figures/patient_explained.png

100 KB
