Skip to content

Commit d90b060

Browse files
authored
Merge pull request #27 from dblasko/eval-mem-leak
Fix test set evaluation GPU memory leak
2 parents eb64933 + 41f0eaa commit d90b060

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

src/full_model/evaluate_full_model/evaluate_language_model.py

+8-3
Original file line number | Diff line number | Diff line change
@@ -982,12 +982,19 @@ def update_gen_and_ref_sentences_for_regions(
982982
index_gen_ref_sentence += 1
983983

984984

985-
def get_generated_reports(generated_sentences_for_selected_regions, selected_regions, sentence_tokenizer, bertscore_threshold):
985+
def get_generated_reports(
986+
generated_sentences_for_selected_regions,
987+
selected_regions,
988+
sentence_tokenizer,
989+
bertscore_threshold,
990+
bert_score,
991+
):
986992
"""
987993
Args:
988994
generated_sentences_for_selected_regions (List[str]): of length "num_regions_selected_in_batch"
989995
selected_regions ([batch_size x 29]): boolean array that has exactly "num_regions_selected_in_batch" True values
990996
sentence_tokenizer: used in remove_duplicate_generated_sentences to separate the generated sentences
997+
bert_score: instance of the evaluate bert score evaluation module
991998
992999
Return:
9931000
generated_reports (List[str]): list of length batch_size containing generated reports for every image in batch
@@ -1055,8 +1062,6 @@ def check_gen_sent_in_sents_to_be_removed(gen_sent, similar_generated_sents_to_b
10551062

10561063
return gen_report_single_image, similar_generated_sents_to_be_removed
10571064

1058-
bert_score = evaluate.load("bertscore")
1059-
10601065
generated_reports = []
10611066
removed_similar_generated_sentences = []
10621067
curr_index = 0

src/full_model/test_set_evaluation.py

+10
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import pandas as pd
1212
import spacy
1313
import torch
14+
import evaluate
1415
from torch.utils.data import DataLoader
1516
import torchmetrics
1617
from tqdm import tqdm
@@ -268,6 +269,7 @@ def iterate_over_test_loader(test_loader):
268269

269270
# used in function get_generated_reports
270271
sentence_tokenizer = spacy.load("en_core_web_trf")
272+
bert_score = evaluate.load("bertscore")
271273

272274
with torch.no_grad():
273275
for num_batch, batch in tqdm(enumerate(test_loader)):
@@ -342,6 +344,14 @@ def iterate_over_test_loader(test_loader):
342344
selected_regions,
343345
sentence_tokenizer,
344346
BERTSCORE_SIMILARITY_THRESHOLD
347+
generated_reports, removed_similar_generated_sentences = (
348+
get_generated_reports(
349+
generated_sents_for_selected_regions,
350+
selected_regions,
351+
sentence_tokenizer,
352+
BERTSCORE_SIMILARITY_THRESHOLD,
353+
bert_score,
354+
)
345355
)
346356

347357
gen_and_ref_sentences["generated_sentences"].extend(generated_sents_for_selected_regions)

0 commit comments

Comments
 (0)