Skip to content

Commit 6803c09

Browse files
committed
Create generate_reports_for_images.py
1 parent f09fcc4 commit 6803c09

File tree

3 files changed

+210
-3
lines changed

3 files changed

+210
-3
lines changed

README.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,13 @@ As a side note - we cannot provide you these files directly (i.e. you have to cr
5151

5252
Please read [README_TRAIN_TEST.md](README_TRAIN_TEST.md) for specific information on training and testing the model.
5353

54+
## Inference
55+
56+
To generate reports for a list of images, run "**python generate_reports_for_images.py**" in src/full_model/. Specify the model checkpoint, the list of image paths, and the path to the txt file for the generated reports in the main function.
57+
5458
## Model checkpoint
5559

56-
You can download the full model checkpoint from this [google drive link](https://drive.google.com/file/d/1P0ewzWKCAS86-poH4ZSf-xibRGpcUFQR/view?usp=sharing).
60+
You can download the full model checkpoint from this [google drive link](https://drive.google.com/file/d/1rDxqzOhjqydsOrITJrX0Rj1PAdMeP7Wy/view?usp=sharing).
5761

5862
## Citation
5963

environment.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ dependencies:
2929
- scikit-learn==1.1.2
3030
- scipy==1.9.1
3131
- setuptools==61.2.0
32-
- spacy==3.4.1
32+
- spacy==3.5.3
3333
- spacy-alignments==0.8.5
3434
- spacy-legacy==3.0.10
3535
- spacy-loggers==1.0.3
36-
- spacy-transformers==1.1.8
36+
- spacy-transformers==1.2.5
3737
- statsmodels==0.13.2
3838
- tensorboard==2.9.0
3939
- tensorboard-data-server==0.6.1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""
2+
Specify the checkpoint_path, images_paths and generated_reports_txt_path in the main function
3+
before running this script.
4+
5+
If you encounter any spacy-related errors, try upgrading spacy to version 3.5.3 and spacy-transformers to version 1.2.5
6+
pip install -U spacy
7+
pip install -U spacy-transformers
8+
"""
9+
10+
from collections import defaultdict
11+
12+
import albumentations as A
13+
import cv2
14+
import evaluate
15+
import spacy
16+
import torch
17+
from albumentations.pytorch import ToTensorV2
18+
from tqdm import tqdm
19+
20+
from src.full_model.report_generation_model import ReportGenerationModel
21+
from src.full_model.train_full_model import get_tokenizer
22+
23+
# Run inference on GPU when available; model and image tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sentence pairs whose BERTScore F1 exceeds this threshold are treated as
# near-duplicates, and one of the two sentences is dropped from the report.
BERTSCORE_SIMILARITY_THRESHOLD = 0.9
# Images are resized/padded to a square of this side length before inference.
IMAGE_INPUT_SIZE = 512
# Upper bound on the number of tokens generated per report.
MAX_NUM_TOKENS_GENERATE = 300
# Beam width used during beam-search decoding.
NUM_BEAMS = 4
mean = 0.471  # see get_transforms in src/dataset/compute_mean_std_dataset.py
std = 0.302
33+
def write_generated_reports_to_txt(images_paths, generated_reports, generated_reports_txt_path):
    """Write each (image path, generated report) pair to a txt file, separated by a divider line."""
    divider = "=" * 30
    with open(generated_reports_txt_path, "w") as txt_file:
        for img_path, gen_report in zip(images_paths, generated_reports):
            txt_file.write(f"Image path: {img_path}\n")
            txt_file.write(f"Generated report: {gen_report}\n\n")
            txt_file.write(divider)
            txt_file.write("\n\n")
40+
41+
42+
def remove_duplicate_generated_sentences(generated_report, bert_score, sentence_tokenizer):
    """Remove exact and near-duplicate sentences from a generated report.

    Since different (closely related) regions can produce the same sentence,
    exact duplicates are dropped first. Remaining sentences that are "soft"
    duplicates (e.g. "... is normal." vs "... is unremarkable.") are detected
    via BERTScore and the shorter of the two is discarded.
    """
    def is_marked_for_removal(sentence, removal_map):
        # a sentence is to be removed if it appears in any of the removal lists
        return any(sentence in removed for removed in removal_map.values())

    # split the report into sentences; spacy Spans are converted to str via .text
    sentences = [span.text for span in sentence_tokenizer(generated_report).sents]

    # drop exact duplicates while keeping first-seen order
    # (dicts are insertion ordered as of Python 3.7)
    sentences = list(dict.fromkeys(sentences))

    # maps a kept sentence to the list of similar sentences that are to be removed
    removal_map = defaultdict(list)

    for idx, first_sent in enumerate(sentences):
        for second_sent in sentences[idx + 1:]:
            # once first_sent is itself marked for removal, stop comparing it
            if is_marked_for_removal(first_sent, removal_map):
                break
            if is_marked_for_removal(second_sent, removal_map):
                continue

            score = bert_score.compute(
                lang="en", predictions=[first_sent], references=[second_sent], model_type="distilbert-base-uncased"
            )

            if score["f1"][0] > BERTSCORE_SIMILARITY_THRESHOLD:
                # of the two similar sentences, keep the longer one
                if len(first_sent) > len(second_sent):
                    removal_map[first_sent].append(second_sent)
                else:
                    removal_map[second_sent].append(first_sent)

    return " ".join(
        sent for sent in sentences if not is_marked_for_removal(sent, removal_map)
    )
98+
99+
100+
def convert_generated_sentences_to_report(generated_sents_for_selected_regions, bert_score, sentence_tokenizer):
    """Concatenate the per-region sentences into one report string and strip duplicate sentences."""
    report = " ".join(generated_sents_for_selected_regions)
    return remove_duplicate_generated_sentences(report, bert_score, sentence_tokenizer)
105+
106+
107+
def get_report_for_image(model, image_tensor, tokenizer, bert_score, sentence_tokenizer):
    """Run beam-search generation on one preprocessed image tensor and return the report (str)."""
    # mixed-precision inference
    # NOTE(review): autocast is hard-coded to device_type="cuda" — confirm behavior on CPU-only setups
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model.generate(
            image_tensor.to(device, non_blocking=True),
            max_length=MAX_NUM_TOKENS_GENERATE,
            num_beams=NUM_BEAMS,
            early_stopping=True,
        )

    # model.generate returns a 4-tuple; only the beam-search token ids are needed here
    beam_search_output = output[0]

    # decode token ids into one sentence per selected region (list[str])
    generated_sents_for_selected_regions = tokenizer.batch_decode(
        beam_search_output, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # join the region sentences and remove duplicates -> final report (str)
    return convert_generated_sentences_to_report(
        generated_sents_for_selected_regions, bert_score, sentence_tokenizer
    )
127+
128+
129+
def get_image_tensor(image_path):
    """Load a grayscale image and preprocess it into a batched tensor.

    Args:
        image_path: path to a grayscale image file, e.g. of shape (3056, 2544).

    Returns:
        Tensor of shape (1, 1, IMAGE_INPUT_SIZE, IMAGE_INPUT_SIZE), i.e. (1, 1, 512, 512).

    Raises:
        FileNotFoundError: if the image cannot be read from image_path.
    """
    # cv2.imread by default loads an image with 3 channels;
    # since we have grayscale images (1 channel), use cv2.IMREAD_UNCHANGED to keep the single channel
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread silently returns None for a missing/unreadable file, which would
        # otherwise surface as a confusing error inside the transforms below
        raise FileNotFoundError(f"Could not read image from path: {image_path}")

    val_test_transforms = A.Compose(
        [
            # resize the longest side to IMAGE_INPUT_SIZE (keeps aspect ratio), then pad to a square
            A.LongestMaxSize(max_size=IMAGE_INPUT_SIZE, interpolation=cv2.INTER_AREA),
            A.PadIfNeeded(min_height=IMAGE_INPUT_SIZE, min_width=IMAGE_INPUT_SIZE, border_mode=cv2.BORDER_CONSTANT),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
    )

    image_transformed = val_test_transforms(image=image)["image"]  # shape (1, 512, 512)
    # add a batch dimension -> shape (1, 1, 512, 512)
    return image_transformed.unsqueeze(0)
148+
149+
150+
def get_model(checkpoint_path):
    """Instantiate the report generation model, load checkpoint weights, move to device, set eval mode."""
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # if there is a key error when loading the checkpoint, try uncommenting the two lines below,
    # since the state dict keys can differ depending on the torch version
    # checkpoint["model"]["object_detector.rpn.head.conv.weight"] = checkpoint["model"].pop("object_detector.rpn.head.conv.0.0.weight")
    # checkpoint["model"]["object_detector.rpn.head.conv.bias"] = checkpoint["model"].pop("object_detector.rpn.head.conv.0.0.bias")

    model = ReportGenerationModel(pretrain_without_lm_model=True)
    model.load_state_dict(checkpoint["model"])
    model.to(device, non_blocking=True)
    model.eval()

    # release the checkpoint's memory once the weights have been loaded
    del checkpoint

    return model
168+
169+
170+
def main():
    """Generate a report for every image in images_paths and write the results to a txt file."""
    checkpoint_path = ".../___.pt"
    model = get_model(checkpoint_path)
    print("Model instantiated.")

    # paths to the images that we want to generate reports for
    images_paths = [
        ".../___.jpg",
        ".../___.jpg",
        ".../___.jpg",
    ]
    generated_reports_txt_path = ".../___.txt"

    bert_score = evaluate.load("bertscore")
    sentence_tokenizer = spacy.load("en_core_web_trf")
    tokenizer = get_tokenizer()

    # if you encounter a spacy-related error, try upgrading spacy to version 3.5.3
    # and spacy-transformers to version 1.2.5:
    # pip install -U spacy
    # pip install -U spacy-transformers

    generated_reports = []
    for image_path in tqdm(images_paths):
        image_tensor = get_image_tensor(image_path)  # shape (1, 1, 512, 512)
        generated_reports.append(
            get_report_for_image(model, image_tensor, tokenizer, bert_score, sentence_tokenizer)
        )

    write_generated_reports_to_txt(images_paths, generated_reports, generated_reports_txt_path)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)