Commit 3341c1b

Merge pull request #80 from huridocs/benchmark-base-line

Add performance report

gabriel-piles authored Jun 21, 2024
2 parents 499906b + 23ef41a commit 3341c1b

Showing 24 changed files with 753 additions and 102 deletions.
15 changes: 7 additions & 8 deletions src/extractors/ExtractorBase.py
@@ -1,7 +1,6 @@
import json
import random
from abc import abstractmethod
from collections import Counter
from os import makedirs
from os.path import exists
from pathlib import Path
@@ -43,18 +42,18 @@ def is_multilingual(multi_option_data: ExtractionData) -> bool:
return False

@staticmethod
def get_train_test_sets(
extraction_data: ExtractionData, seed: int = 22, limit_samples: bool = True
) -> (ExtractionData, ExtractionData):
if len(extraction_data.samples) < 15:
def get_train_test_sets(extraction_data: ExtractionData, limit_samples: bool = True) -> (ExtractionData, ExtractionData):
if len(extraction_data.samples) < 8:
return extraction_data, extraction_data

train_size = int(len(extraction_data.samples) * 0.8)
random.seed(seed)
random.shuffle(extraction_data.samples)

train_set: list[TrainingSample] = extraction_data.samples[:train_size]
test_set: list[TrainingSample] = extraction_data.samples[train_size:]

if len(extraction_data.samples) < 15:
test_set: list[TrainingSample] = extraction_data.samples[-10:]
else:
test_set = extraction_data.samples[train_size:]

if limit_samples:
train_set = train_set[:80]
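A minimal sketch of the split behavior after this change, as far as it can be read from the rendered diff (the seed handling is not fully visible in this hunk, so the fixed seed of 22 below is an assumption):

import random

def split_train_test(samples: list, limit_samples: bool = True) -> tuple[list, list]:
    # Fewer than 8 samples: train and evaluate on the same data
    if len(samples) < 8:
        return samples, samples

    random.seed(22)  # assumption: the previously hardcoded seed is kept
    random.shuffle(samples)
    train_size = int(len(samples) * 0.8)
    train_set = samples[:train_size]

    # 8 to 14 samples: test on the last 10 shuffled samples,
    # which deliberately overlaps the training set
    if len(samples) < 15:
        test_set = samples[-10:]
    else:
        test_set = samples[train_size:]

    if limit_samples:
        train_set = train_set[:80]  # cap the training set size

    return train_set, test_set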
PdfMultiOptionMethod.py
@@ -5,6 +5,7 @@
from data.ExtractionIdentifier import ExtractionIdentifier
from data.Option import Option
from data.ExtractionData import ExtractionData
from data.TrainingSample import TrainingSample
from extractors.ExtractorBase import ExtractorBase
from extractors.pdf_to_multi_option_extractor.MultiLabelMethod import MultiLabelMethod
from extractors.pdf_to_multi_option_extractor.FilterSegmentsMethod import FilterSegmentsMethod
@@ -19,15 +20,17 @@ def __init__(
self.multi_label_method = multi_label_method
self.filter_segments_method = filter_segments_method
self.extraction_identifier = ExtractionIdentifier(run_name="not set", extraction_name="not set")
self.options: list[Option] = []
self.options: list[Option] = list()
self.multi_value = False
self.base_path = ""
self.extraction_data = None

def set_parameters(self, multi_option_data: ExtractionData):
self.extraction_identifier = multi_option_data.extraction_identifier
self.options = multi_option_data.options
self.multi_value = multi_option_data.multi_value
self.base_path = multi_option_data.extraction_identifier.get_path()
self.extraction_data = multi_option_data

def get_name(self):
if self.filter_segments_method and self.multi_label_method:
@@ -44,7 +47,7 @@ def get_performance(self, multi_option_data: ExtractionData, repetitions: int =
scores = list()
seeds = [22, 23, 24, 25]
for i in range(repetitions):
train_set, test_set = ExtractorBase.get_train_test_sets(multi_option_data, seeds[i])
train_set, test_set = ExtractorBase.get_train_test_sets(multi_option_data)
truth_one_hot = self.one_hot_to_options_list([x.labeled_data.values for x in test_set.samples], self.options)

self.train(train_set)
@@ -97,6 +100,12 @@ def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:

return predictions

def get_samples_for_context(self, extraction_data: ExtractionData) -> list[TrainingSample]:
if self.extraction_data:
return self.extraction_data.samples

return extraction_data.samples

def can_be_used(self, multi_option_data: ExtractionData) -> bool:
if self.multi_label_method:
multi_label = self.multi_label_method(self.extraction_identifier, self.options, self.multi_value)
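The hunk above does not show one_hot_to_options_list itself, so the following is only a plausible sketch of the one-hot encoding that get_performance appears to compare truth and predictions with; the names are illustrative:

def one_hot(values_list: list[list[str]], options: list[str]) -> list[list[int]]:
    # One row per sample, one 0/1 column per known option
    return [[1 if option in values else 0 for option in options] for values in values_list]

# Two samples over three options
print(one_hot([["a"], ["b", "c"]], ["a", "b", "c"]))  # [[1, 0, 0], [0, 1, 1]]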
PdfToMultiOptionExtractor.py
@@ -41,6 +41,19 @@
from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.FuzzySegmentSelector import (
FuzzySegmentSelector,
)
from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.NextWordsTokenSelectorFuzzy75 import (
NextWordsTokenSelectorFuzzy75,
)

from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.PreviousWordsSentenceSelectorFuzzyCommas import (
PreviousWordsSentenceSelectorFuzzyCommas,
)
from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.PreviousWordsTokenSelectorFuzzy75 import (
PreviousWordsTokenSelectorFuzzy75,
)
from extractors.pdf_to_multi_option_extractor.multi_option_extraction_methods.SentenceSelectorFuzzyCommas import (
SentenceSelectorFuzzyCommas,
)
from send_logs import send_logs


@@ -53,6 +66,10 @@ class PdfToMultiOptionExtractor(ExtractorBase):
FuzzyAll75(),
FuzzyAll88(),
FuzzyAll100(),
PreviousWordsTokenSelectorFuzzy75(),
NextWordsTokenSelectorFuzzy75(),
PreviousWordsSentenceSelectorFuzzyCommas(),
SentenceSelectorFuzzyCommas(),
FastSegmentSelectorFuzzy95(),
FastSegmentSelectorFuzzyCommas(),
FuzzySegmentSelector(),
@@ -123,7 +140,7 @@ def get_predictions(self, predictions_samples: list[PredictionSample]) -> (list[
if not self.multi_value:
prediction = [x[:1] for x in prediction]

return training_samples, prediction
return method.get_samples_for_context(extraction_data), prediction

def load_options(self):
if not exists(self.options_path) or not exists(self.multi_value_path):
@@ -139,7 +156,7 @@ def get_best_method(self, multi_option_data: ExtractionData) -> PdfMultiOptionMethod:
best_method_instance = self.METHODS[0]
best_performance = 0
for method in self.METHODS:
performance = self.get_performance(method, multi_option_data)
performance = self.get_method_performance(method, multi_option_data)

if performance == 100:
send_logs(self.extraction_identifier, f"Best method {method.get_name()} with {performance}%")
@@ -152,10 +169,10 @@ def get_best_method(self, multi_option_data: ExtractionData) -> PdfMultiOptionMethod:
send_logs(self.extraction_identifier, f"Best method {best_method_instance.get_name()}")
return best_method_instance

def get_performance(self, method, multi_option_data):
def get_method_performance(self, method: PdfMultiOptionMethod, multi_option_data: ExtractionData):
method.set_parameters(multi_option_data)

if len(self.METHODS) == 1 or not method.can_be_used(multi_option_data):
if not method.can_be_used(multi_option_data):
return 0

send_logs(self.extraction_identifier, f"Checking {method.get_name()}")
@@ -164,9 +181,9 @@ def get_performance(self, method, multi_option_data):
performance = method.get_performance(multi_option_data)
except Exception as e:
send_logs(self.extraction_identifier, f"Error checking {method.get_name()}: {e}", Severity.error)

performance = 0

self.reset_extraction_data(multi_option_data)
send_logs(self.extraction_identifier, f"Performance {method.get_name()}: {performance}%")
return performance

@@ -187,3 +204,9 @@ def can_be_used(self, extraction_data: ExtractionData) -> bool:
return True

return False

@staticmethod
def reset_extraction_data(multi_option_data: ExtractionData):
for sample in multi_option_data.samples:
for segment in sample.pdf_data.pdf_data_segments:
segment.ml_label = 0
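A compact sketch of the selection strategy visible in get_best_method and get_method_performance: score every candidate on held-out data, stop early on a perfect score, and otherwise keep the best performer. The names here are illustrative, not the repository's API:

def pick_best_method(methods: list, score) -> object:
    best_method, best_performance = methods[0], 0
    for method in methods:
        performance = score(method)  # 0 when a method cannot be used or raises
        if performance == 100:
            return method  # early exit: a perfect score cannot be beaten
        if performance > best_performance:
            best_method, best_performance = method, performance
    return best_method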
FastSegmentSelector.py
@@ -1,7 +1,7 @@
import json
import os
from collections import Counter
from os.path import join
from os.path import join, exists
from pathlib import Path

import numpy as np
@@ -26,19 +26,26 @@ def __init__(self, extraction_identifier: ExtractionIdentifier):
self.next_words_path = join(self.fast_segment_selector_path, "next_words.txt")
self.model_path = join(self.fast_segment_selector_path, "lightgbm_model.txt")

def get_features(self, segment: PdfDataSegment):
def get_features(self, segment: PdfDataSegment, segments: list[PdfDataSegment]):
features = list()
text = segment.text_content

index = self.text_segments.index(segment)
previous_segment_text = self.text_segments[index - 1].text_content if index > 0 else ""
next_segment_text = self.text_segments[index + 1].text_content if index + 1 < len(self.text_segments) else ""
if segment in self.text_segments:
index = self.text_segments.index(segment)
previous_segment_texts = self.clean_texts(self.text_segments[index - 1]) if index > 0 else []
next_segment_texts = (
self.clean_texts(self.text_segments[index + 1]) if index + 1 < len(self.text_segments) else []
)
else:
index = segments.index(segment)
previous_segment_texts = self.clean_texts(segments[index - 1]) if index > 0 else ""
next_segment_texts = self.clean_texts(segments[index + 1]) if index + 1 < len(segments) else ""

for word in self.previous_words:
features.append(1 if word in previous_segment_text.lower() else 0)
features.append(1 if word in previous_segment_texts else 0)

for word in self.next_words:
features.append(1 if word in next_segment_text.lower() else 0)
features.append(1 if word in next_segment_texts else 0)

features.append(len([x for x in text if x == ","]) / len(text) if text else 0)

@@ -52,30 +59,33 @@ def get_most_common_words(train_segments):
return [x[0] for x in counter.most_common(30)]

@staticmethod
def get_predictive_common_words(segments):
def clean_texts(pdf_segment: PdfDataSegment) -> list[str]:
clean_letters = [letter for letter in pdf_segment.text_content.lower() if letter.isalnum() or letter == " "]
return "".join(clean_letters).split()

def save_predictive_common_words(self, segments):
most_common_words = FastSegmentSelector.get_most_common_words(segments)
counter_previous_segment = Counter()
counter_next_segment = Counter()

for previous_segment, segment, next_segment in zip(segments, segments[1:], segments[2:]):
if segment.ml_label:
counter_previous_segment.update(
[x for x in previous_segment.text_content.strip().lower().split() if x not in most_common_words]
)
counter_next_segment.update(
[x for x in next_segment.text_content.strip().lower().split() if x not in most_common_words]
)
break
if not segment.ml_label:
continue

return ([x[0] for x in counter_previous_segment.most_common(3)], [x[0] for x in counter_next_segment.most_common(3)])
counter_previous_segment.update([x for x in self.clean_texts(previous_segment) if x not in most_common_words])
counter_next_segment.update([x for x in self.clean_texts(next_segment) if x not in most_common_words])
break

def create_model(self, segments: list[PdfDataSegment]):
self.text_segments = [x for x in segments if x.segment_type in self.text_types]
self.previous_words, self.next_words = self.get_predictive_common_words(self.text_segments)
self.previous_words = [x[0] for x in counter_previous_segment.most_common(2)]
self.next_words = [x[0] for x in counter_next_segment.most_common(2)]

Path(self.previous_words_path).write_text(json.dumps(self.previous_words))
Path(self.next_words_path).write_text(json.dumps(self.next_words))

def create_model(self, segments: list[PdfDataSegment]):
self.text_segments = [x for x in segments if x.segment_type in self.text_types]
self.save_predictive_common_words(self.text_segments)

x, y = self.get_x_y(segments)

train_data = lgb.Dataset(x, y)
Expand All @@ -88,7 +98,7 @@ def get_x_y(self, segments):
y = []

for segment in segments:
x_rows.append(self.get_features(segment))
x_rows.append(self.get_features(segment, segments))
y.append(segment.ml_label)

x_train = np.zeros((len(x_rows), len(x_rows[0]) if x_rows else 0))
@@ -109,9 +119,11 @@ def predict(self, segments):
return [segment for i, segment in enumerate(segments) if predictions[i] > 0.5]

def load_repeated_words(self):
try:
self.previous_words = []
self.next_words = []

if exists(self.previous_words_path):
self.previous_words = json.loads(Path(self.previous_words_path).read_text())

if exists(self.next_words_path):
self.next_words = json.loads(Path(self.next_words_path).read_text())
except:
self.previous_words = []
self.next_words = []
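Read as a whole, the features get_features now builds for the LightGBM model reduce to word-presence indicators for the neighbouring segments plus a comma-density value for the segment itself. A minimal sketch under that reading, with illustrative names:

def build_features(text: str, previous_context: list[str], next_context: list[str],
                   previous_words: list[str], next_words: list[str]) -> list[float]:
    # One 0/1 indicator per learned "previous word" and "next word"
    features = [1.0 if word in previous_context else 0.0 for word in previous_words]
    features += [1.0 if word in next_context else 0.0 for word in next_words]
    # Comma density of the segment's own text
    features.append(text.count(",") / len(text) if text else 0.0)
    return features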
FastSegmentSelectorFuzzy95.py
@@ -31,7 +31,7 @@ def get_appearances(self, pdf_segment: PdfDataSegment, options: list[str]) -> list[str]:
if fuzz.partial_ratio(option, pdf_segment.text_content.lower()) >= self.threshold:
appearances.append(option)

return list(set(appearances))
return list(dict.fromkeys(appearances))
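Both set and dict.fromkeys deduplicate, but dict.fromkeys preserves first-appearance order, which makes the returned appearances deterministic across runs:

appearances = ["b", "a", "b", "c"]
print(list(set(appearances)))             # arbitrary order, e.g. ['c', 'a', 'b']
print(list(dict.fromkeys(appearances)))   # always ['b', 'a', 'c']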

def train(self, multi_option_data: ExtractionData):
marked_segments = list()
@@ -41,27 +41,28 @@ def train(self, multi_option_data: ExtractionData):
FastSegmentSelector(self.extraction_identifier).create_model(marked_segments)

def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:
predict_data = self.get_prediction_data(multi_option_data)
return FuzzyAll95().predict(predict_data)
self.set_parameters(multi_option_data)
self.extraction_data = self.get_prediction_data(multi_option_data)
return FuzzyAll95().predict(self.extraction_data)

def get_prediction_data(self, multi_option_data):
def get_prediction_data(self, extraction_data: ExtractionData) -> ExtractionData:
fast_segment_selector = FastSegmentSelector(self.extraction_identifier)
predict_samples = list()
for sample in multi_option_data.samples:
for sample in extraction_data.samples:
selected_segments = fast_segment_selector.predict(self.fix_two_pages_segments(sample))

self.mark_segments_for_context(sample.pdf_data.pdf_data_segments, selected_segments)
self.mark_segments_for_context(selected_segments)

pdf_data = PdfData(None)
pdf_data = PdfData(None, file_name=sample.pdf_data.file_name)
pdf_data.pdf_data_segments = selected_segments

training_sample = TrainingSample(pdf_data=pdf_data, labeled_data=sample.labeled_data)
predict_samples.append(training_sample)

return ExtractionData(
samples=predict_samples,
options=multi_option_data.options,
multi_value=multi_option_data.multi_value,
options=self.extraction_data.options,
multi_value=self.extraction_data.multi_value,
extraction_identifier=self.extraction_identifier,
)

@@ -94,7 +95,7 @@ def get_cleaned_options(self, options: list[Option]) -> list[str]:

def get_marked_segments(self, training_sample: TrainingSample) -> list[PdfDataSegment]:
cleaned_values = self.get_cleaned_options(training_sample.labeled_data.values)
appearances_threshold = math.ceil(len(cleaned_values) * self.threshold / 100)
appearances_threshold = math.ceil(len(cleaned_values) * 0.68)

if not appearances_threshold:
return training_sample.pdf_data.pdf_data_segments
@@ -106,7 +107,6 @@ def get_marked_segments(self, training_sample: TrainingSample) -> list[PdfDataSegment]:

if appearances_threshold <= appearances:
segment.ml_label = 1
break

return fixed_segments

@@ -117,7 +117,6 @@ def fix_two_pages_segments(self, training_sample: TrainingSample) -> list[PdfDataSegment]:
merged_segment = None
for segment in training_sample.pdf_data.pdf_data_segments:
if segment == merged_segment:
fixed_segments.append(segment)
merged_segment = None
continue

@@ -138,19 +137,11 @@ def fix_segment(segment: PdfDataSegment, text_type_segments: list[PdfDataSegment
return segment, None

segment = deepcopy(segment)
text_type_segments[index + 1] = deepcopy(text_type_segments[index + 1])

segment.text_content += " " + text_type_segments[index + 1].text_content
text_type_segments[index + 1].text_content = segment.text_content

return segment, text_type_segments[index + 1]

@staticmethod
def mark_segments_for_context(all_segments: list[PdfDataSegment], selected_segments: list[PdfDataSegment]):
for segment in all_segments:
for selected_segment in selected_segments:
if segment.page_number != selected_segment.page_number:
continue

if segment.bounding_box.get_intersection_percentage(selected_segment.bounding_box) > 0.1:
segment.ml_label = 1
break
def mark_segments_for_context(segments: list[PdfDataSegment]):
for segment in segments:
segment.ml_label = 1
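Taken together with train above, prediction now follows a two-stage pipeline: the trained FastSegmentSelector narrows each sample to candidate segments, those segments are marked as context, and a fuzzy matcher runs on the reduced data. A schematic sketch with illustrative names:

def predict_pipeline(samples: list, segment_selector, fuzzy_matcher) -> list:
    predictions = []
    for sample in samples:
        selected = segment_selector.predict(sample.segments)  # stage 1: narrow the segments
        for segment in selected:
            segment.ml_label = 1  # mark every selected segment as context
        predictions.append(fuzzy_matcher.predict(selected))   # stage 2: fuzzy option matching
    return predictions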
FastSegmentSelectorFuzzyCommas.py
@@ -9,5 +9,10 @@
class FastSegmentSelectorFuzzyCommas(FastSegmentSelectorFuzzy95):

def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:
predict_data = self.get_prediction_data(multi_option_data)
return FuzzyCommas().predict(predict_data)
self.set_parameters(multi_option_data)
self.extraction_data = self.get_prediction_data(multi_option_data)
return FuzzyCommas().predict(self.extraction_data)

def train(self, multi_option_data: ExtractionData):
super().train(multi_option_data)
FuzzyCommas().train(multi_option_data)