Feature: CheXpert example [WIP] #34

Open · wants to merge 11 commits into `master`
65 changes: 65 additions & 0 deletions chexpert/README.md
@@ -0,0 +1,65 @@
# MLCube: CheXpert Example
This example demonstrates how to use MLCube to work with a computer vision model trained on the CheXpert dataset.

CheXpert is a large dataset of chest X-rays and a competition for automated chest X-ray interpretation; it features uncertainty labels and radiologist-labeled reference-standard evaluation sets.

The model used here is based on the top-1 solution of the CheXpert challenge, available [here](https://github.com/jfhealthcare/Chexpert).

### Project setup
```bash
# Create Python environment
virtualenv -p python3 ./env && source ./env/bin/activate

# Install MLCube and MLCube docker runner from GitHub repository (normally, users will just run `pip install mlcube mlcube_docker`)
git clone https://github.com/sergey-serebryakov/mlbox.git && cd mlbox && git checkout feature/configV2
cd ./mlcube && python setup.py bdist_wheel && pip install --force-reinstall ./dist/mlcube-* && cd ..
cd ./runners/mlcube_docker && python setup.py bdist_wheel && pip install --force-reinstall --no-deps ./dist/mlcube_docker-* && cd ../../..
```

## Clone the MLCube examples repository and go to the `chexpert` directory
```
git clone https://github.com/mlperf/mlcube_examples.git && cd ./mlcube_examples
git fetch origin pull/34/head:chest-xray-example && git checkout chest-xray-example
cd ./chexpert
```

## Get the data
Because the CheXpert dataset contains sensitive information, a user agreement must be signed before the data can be obtained, so the download cannot be automated. To obtain the dataset:

1. Sign up at the [CheXpert Dataset Download Agreement](https://stanfordmlgroup.github.io/competitions/chexpert/#agreement) and download the small dataset from the link sent to your email.
2. Unzip it and place the `CheXpert-v1.0-small` folder inside the `mlcube/workspace/data` folder. Your folder structure should look like this:

```
.
├── mlcube
│   └── workspace
│       └── data
│           └── CheXpert-v1.0-small
│               ├── valid
│               └── valid.csv
└── project
```

## Run the CheXpert MLCube on a local machine with the Docker runner
```
# Run the CheXpert tasks: download the model, preprocess the data, and generate predictions
mlcube run --task download_model
mlcube run --task preprocess
mlcube run --task infer
```

Parameters defined in **mlcube.yaml** can be overridden on the command line using `parameter=value`, for example:

```
mlcube run --task download_model model_dir=path_to_custom_dir
```

We are targeting pull-type installation, so MLCube images should be available on Docker Hub. If the image is not available there, build it locally:

```
mlcube run ... -Pdocker.build_strategy=auto
```

By default, at the end of the `download_model` task, the CheXpert model will be saved in `workspace/model`.

By default, at the end of the `infer` task, results will be saved in `workspace/inferences.csv`.
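
Once the `infer` task finishes, the predictions can be inspected directly. The sketch below is illustrative only: it assumes `pandas` is installed in your local environment and that the output landed at `mlcube/workspace/inferences.csv`. The CSV contains a `Path` column plus one probability per competition observation.

```python
# Sketch: inspect the predictions produced by the infer task.
# Assumes pandas is installed and the path below matches your workspace layout.
import pandas as pd

preds = pd.read_csv("mlcube/workspace/inferences.csv")

# Columns written by the model: Path, Cardiomegaly, Edema, Consolidation,
# Atelectasis, Pleural Effusion (probabilities in [0, 1]).
print(preds.columns.tolist())
print(preds.head())

# Example: studies where the predicted probability of Pleural Effusion exceeds 0.5.
print(preds[preds["Pleural Effusion"] > 0.5]["Path"])
```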
29 changes: 29 additions & 0 deletions chexpert/mlcube/mlcube.yaml
@@ -0,0 +1,29 @@
name: MLCommons Chexpert
description: MLCommons Chexpert example for inference with the Chexpert model.
authors:
  - {name: "MLCommons Best Practices Working Group"}

platform:
  accelerator_count: 0

docker:
  # Image name.
  image: mlcommons/chexpert:0.0.1
  # Docker build context relative to $MLCUBE_ROOT. Default is `build`.
  build_context: "../project"
  # Docker file name within docker build context, default is `Dockerfile`.
  build_file: "Dockerfile"

tasks:
  download_model:
    # Download model files.
    parameters:
      outputs: {model_dir: model/}
  preprocess:
    parameters:
      inputs: {data_dir: data/CheXpert-v1.0-small}
  infer:
    # Predict on data.
    parameters:
      inputs: {data_dir: data/CheXpert-v1.0-small, model_dir: model/}
      outputs: {log_dir: inference_logs/, out_dir: ./}
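
For orientation, the task definitions above can be enumerated programmatically. This is a small illustrative sketch, not part of the MLCube API; it assumes PyYAML is installed and that it is run from the `chexpert/mlcube` directory.

```python
# Illustrative only: list the tasks declared in mlcube.yaml and their parameters.
import yaml

with open("mlcube.yaml") as f:
    config = yaml.safe_load(f)

for task_name, task in config["tasks"].items():
    params = task.get("parameters", {})
    print(f"task: {task_name}")
    print(f"  inputs:  {params.get('inputs', {})}")
    print(f"  outputs: {params.get('outputs', {})}")
```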
31 changes: 31 additions & 0 deletions chexpert/project/Dockerfile
@@ -0,0 +1,31 @@
FROM ubuntu:18.04
MAINTAINER MLPerf MLBox Working Group

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        software-properties-common \
        python3-dev \
        curl \
        wget \
        libsm6 libxext6 libxrender-dev && \
    rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa -y && apt-get update

RUN apt-get install python3.7 -y

RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
    python3.7 get-pip.py && \
    rm get-pip.py

COPY ./requirements.txt project/requirements.txt

RUN python3.7 -m pip install --upgrade pip

RUN python3.7 -m pip install --no-cache-dir -r project/requirements.txt

COPY . /project

WORKDIR /project

ENTRYPOINT ["python3.7", "mlcube.py"]
179 changes: 179 additions & 0 deletions chexpert/project/chexpert.py
@@ -0,0 +1,179 @@
import os
import yaml
import sys
import argparse
import logging
import logging.config
import json
import time
from tqdm import tqdm
from enum import Enum
from typing import List
from easydict import EasyDict as edict
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn import DataParallel
import torch.nn.functional as F

# sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from data.dataset import ImageDataset  # noqa
from model.classifier import Classifier  # noqa

logger = logging.getLogger(__name__)


class Task(str, Enum):
    DownloadData = "download_data"
    DownloadCkpt = "download_ckpt"
    Infer = "infer"


def create_directory(path: str) -> None:
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def infer(task_args: List[str]) -> None:
    """ Task: infer

    Input parameters:
        --data_dir, --model_dir, --out_dir
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", "--data-dir", type=str, default=None, help="Dataset path."
    )
    parser.add_argument(
        "--model_dir", "--model-dir", type=str, default=None, help="Model location."
    )
    parser.add_argument(
        "--out_dir", "--out-dir", type=str, default=None, help="Model output directory."
    )

    args = parser.parse_args(args=task_args)
    run(args)


def get_pred(output, cfg):
    # Each head is binary (sigmoid) for BCE/focal loss, or a softmax over >= 2 classes for CE.
    if cfg.criterion == "BCE" or cfg.criterion == "FL":
        for num_class in cfg.num_classes:
            assert num_class == 1
        pred = torch.sigmoid(output.view(-1)).cpu().detach().numpy()
    elif cfg.criterion == "CE":
        for num_class in cfg.num_classes:
            assert num_class >= 2
        prob = F.softmax(output)
        pred = prob[:, 1].cpu().detach().numpy()
    else:
        raise Exception("Unknown criterion : {}".format(cfg.criterion))

    return pred


def test_epoch(cfg, model, device, dataloader, out_csv_path):
    torch.set_grad_enabled(False)
    steps = len(dataloader)
    dataiter = iter(dataloader)
    num_tasks = len(cfg.num_classes)

    test_header = [
        "Path",
        "Cardiomegaly",
        "Edema",
        "Consolidation",
        "Atelectasis",
        "Pleural Effusion",
    ]

    with open(out_csv_path, "w") as f:
        f.write(",".join(test_header) + "\n")
        for step in tqdm(range(steps)):
            image, path = next(dataiter)
            image = image.to(device)
            output, __ = model(image)
            batch_size = len(path)
            pred = np.zeros((num_tasks, batch_size))

            for i in range(num_tasks):
                pred[i] = get_pred(output[i], cfg)

            for i in range(batch_size):
                batch = ",".join(map(lambda x: "{}".format(x), pred[:, i]))
                result = path[i] + "," + batch
                f.write(result + "\n")
                logging.info(
                    "{}, Image : {}, Prob : {}".format(
                        time.strftime("%Y-%m-%d %H:%M:%S"), path[i], batch
                    )
                )


def run(args):
    ckpt_path = os.path.join(args.model_dir, "model.pth")
    config_path = os.path.join(args.model_dir, "config.json")
    print(config_path)
    with open(config_path) as f:
        cfg = edict(json.load(f))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt = torch.load(ckpt_path, map_location=device)
    model = Classifier(cfg).to(device).eval()
    model.load_state_dict(ckpt)

    out_csv_path = os.path.join(args.out_dir, "inferences.csv")
    in_csv_path = os.path.join(args.data_dir, "valid.csv")

    dataloader_test = DataLoader(
        ImageDataset(in_csv_path, cfg, args.data_dir, mode="test"),
        batch_size=cfg.dev_batch_size,
        drop_last=False,
        shuffle=False,
    )

    test_epoch(cfg, model, device, dataloader_test, out_csv_path)


def main():
    """
    chexpert.py task task_specific_parameters...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log_dir", "--log-dir", type=str, required=True, help="Logging directory."
    )
    mlcube_args, task_args = parser.parse_known_args()

    os.makedirs(mlcube_args.log_dir, exist_ok=True)
    logger_config = {
        "version": 1,
        "disable_existing_loggers": True,
        "formatters": {
            "standard": {
                "format": "%(asctime)s - %(name)s - %(threadName)s - %(levelname)s - %(message)s"
            },
        },
        "handlers": {
            "file_handler": {
                "class": "logging.FileHandler",
                "level": "INFO",
                "formatter": "standard",
                "filename": os.path.join(
                    mlcube_args.log_dir, "mlcube_chexpert_infer.log"
                ),
            }
        },
        "loggers": {
            "": {"level": "INFO", "handlers": ["file_handler"]},
            "__main__": {"level": "NOTSET", "propagate": "yes"},
            "tensorflow": {"level": "NOTSET", "propagate": "yes"},
        },
    }
    logging.config.dictConfig(logger_config)
    infer(task_args)


if __name__ == "__main__":
    main()
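
For readers who want to see what `get_pred` computes without loading the full model, the following standalone sketch mirrors its BCE branch. The configuration values and logits are made up for illustration; only `torch` and `easydict` are needed to run it.

```python
# Standalone sketch mirroring get_pred's BCE branch: one sigmoid head per observation.
# The cfg values here are hypothetical; a real config.json ships with the downloaded model.
import torch
from easydict import EasyDict as edict

cfg = edict({"criterion": "BCE", "num_classes": [1, 1, 1, 1, 1]})

# Fake logits for one observation head over a batch of three images.
output = torch.tensor([[0.3], [-1.2], [2.0]])

if cfg.criterion in ("BCE", "FL"):
    # Each head predicts a single class, so a sigmoid turns the logits into probabilities.
    pred = torch.sigmoid(output.view(-1)).cpu().detach().numpy()
    print(pred)  # approximately [0.574, 0.231, 0.881]
```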
88 changes: 88 additions & 0 deletions chexpert/project/data/dataset.py
@@ -0,0 +1,88 @@
import numpy as np
from torch.utils.data import Dataset
import cv2
import os
from PIL import Image
from data.imgaug import GetTransforms
from data.utils import transform

np.random.seed(0)


class ImageDataset(Dataset):
    def __init__(self, label_path, cfg, data_path, mode="train"):
        self.cfg = cfg
        self.data_path = data_path
        self._label_header = None
        self._image_paths = []
        self._labels = []
        self._mode = mode
        # Two policies for mapping CheXpert CSV values to binary labels:
        # dict[0] treats uncertain (-1.0) as negative, dict[1] treats it as positive.
        self.dict = [
            {"1.0": "1", "": "0", "0.0": "0", "-1.0": "0"},
            {"1.0": "1", "": "0", "0.0": "0", "-1.0": "1"},
        ]
        with open(label_path) as f:
            header = f.readline().strip("\n").split(",")
            # Keep the five competition observations: Cardiomegaly, Edema,
            # Consolidation, Atelectasis and Pleural Effusion.
            self._label_header = [
                header[7],
                header[10],
                header[11],
                header[13],
                header[15],
            ]
            for line in f:
                labels = []
                fields = line.strip("\n").split(",")
                image_path = fields[0]
                image_root_path = os.path.join(self.data_path, image_path)
                # image_path = os.path.join(data_path, fields[0])
                flg_enhance = False
                for index, value in enumerate(fields[5:]):
                    if index == 5 or index == 8:
                        labels.append(self.dict[1].get(value))
                        if (
                            self.dict[1].get(value) == "1"
                            and self.cfg.enhance_index.count(index) > 0
                        ):
                            flg_enhance = True
                    elif index == 2 or index == 6 or index == 10:
                        labels.append(self.dict[0].get(value))
                        if (
                            self.dict[0].get(value) == "1"
                            and self.cfg.enhance_index.count(index) > 0
                        ):
                            flg_enhance = True
                # labels = ([self.dict.get(n, n) for n in fields[5:]])
                labels = list(map(int, labels))
                self._image_paths.append(image_path)
                assert os.path.exists(image_root_path), image_path
                self._labels.append(labels)
                if flg_enhance and self._mode == "train":
                    for i in range(self.cfg.enhance_times):
                        self._image_paths.append(image_path)
                        self._labels.append(labels)
        self._num_image = len(self._image_paths)

    def __len__(self):
        return self._num_image

    def __getitem__(self, idx):
        image_root_path = os.path.join(self.data_path, self._image_paths[idx])
        image = cv2.imread(image_root_path, 0)
        image = Image.fromarray(image)
        if self._mode == "train":
            image = GetTransforms(image, type=self.cfg.use_transforms_type)
        image = np.array(image)
        image = transform(image, self.cfg)
        labels = np.array(self._labels[idx]).astype(np.float32)

        path = self._image_paths[idx]

        if self._mode == "train" or self._mode == "dev":
            return (image, labels)
        elif self._mode == "test":
            return (image, path)
        elif self._mode == "heatmap":
            return (image, path, labels)
        else:
            raise Exception("Unknown mode : {}".format(self._mode))
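
To make the uncertainty handling above concrete, here is a tiny standalone sketch of how a raw CheXpert CSV value is mapped to a binary target under the two policies used in `ImageDataset`. The example values are hypothetical CSV entries.

```python
# Sketch of the label-mapping policies used by ImageDataset.
# "-1.0" marks an uncertain finding in the CheXpert CSVs.
u_zeros = {"1.0": "1", "": "0", "0.0": "0", "-1.0": "0"}  # uncertain -> negative
u_ones = {"1.0": "1", "": "0", "0.0": "0", "-1.0": "1"}   # uncertain -> positive

raw_values = ["1.0", "0.0", "", "-1.0"]  # hypothetical CSV entries
for value in raw_values:
    print(value or "<blank>",
          "-> u_zeros:", u_zeros[value],
          "u_ones:", u_ones[value])

# In the dataset above, Edema and Atelectasis use the uncertain-as-positive policy,
# while Cardiomegaly, Consolidation and Pleural Effusion use uncertain-as-negative.
```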