Skip to content

Commit

Permalink
Merge pull request #82 from huridocs/fix-methods-errors
Browse files Browse the repository at this point in the history
Fix methods errors
  • Loading branch information
gabriel-piles authored Aug 2, 2024
2 parents f6d3287 + 5fb0635 commit 5f06578
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 37 deletions.
8 changes: 6 additions & 2 deletions src/Extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,12 @@ def remove_old_models(extractor_identifier: ExtractionIdentifier):

@staticmethod
def calculate_task(extraction_task: ExtractionTask) -> (bool, str):
extraction_name = extraction_task.params.id
extractor_identifier = ExtractionIdentifier(run_name=extraction_task.tenant, extraction_name=extraction_name)
extractor_identifier = ExtractionIdentifier(
run_name=extraction_task.tenant,
extraction_name=extraction_task.params.id,
metadata=extraction_task.params.metadata,
)

Extractor.remove_old_models(extractor_identifier)

if extraction_task.task == Extractor.CREATE_MODEL_TASK_NAME:
Expand Down
23 changes: 18 additions & 5 deletions src/QueueProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def __init__(self):
def process(self, id, message, rc, ts):
try:
task = ExtractionTask(**message)
config_logger.info(f"New task {task.model_dump()}")
config_logger.info(f"New task {self.task_to_string(task)}")
except ValidationError:
config_logger.error(f"Not a valid Redis message: {message}")
return True

self.log_process_information(message)
self.log_process_information(task)

task_calculated, error_message = Extractor.calculate_task(task)
if error_message:
Expand Down Expand Up @@ -77,13 +77,26 @@ def process(self, id, message, rc, ts):
error_message="",
data_url=data_url,
)
config_logger.info(model_results_message.model_dump())

config_logger.info(model_results_message.to_string())
self.results_queue.sendMessage().message(model_results_message.model_dump()).execute()
return True

def log_process_information(self, message):
@staticmethod
def task_to_string(extraction_task: ExtractionTask):
    """Return the task as a string for logging, hiding long option lists.

    The task is serialized with ``model_dump()``. When the dump has a
    ``params`` entry containing more than 10 ``options``, the options are
    replaced by a ``[hidden N options]`` placeholder so that the log line
    stays short. Otherwise the dump is stringified unchanged.
    """
    dumped = extraction_task.model_dump()

    # Nothing to shorten when the dump carries no options at all.
    if "params" not in dumped or "options" not in dumped["params"]:
        return str(dumped)

    options_count = len(dumped["params"]["options"])
    if options_count > 10:
        dumped["params"]["options"] = f"[hidden {options_count} options]"

    return str(dumped)

def log_process_information(self, extraction_task: ExtractionTask):
try:
config_logger.info(f"Processing Redis message: {message}")
config_logger.info(f"Processing Redis message: {self.task_to_string(extraction_task)}")
config_logger.info(
f"Messages pending in queue: {self.task_queue.getQueueAttributes().exec_command()['msgs'] - 1}"
)
Expand Down
3 changes: 0 additions & 3 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ async def get_suggestions(tenant: str, extraction_id: str):

pdf_metadata_extraction_db.suggestions.delete_many(suggestions_filter)
config_logger.info(f"{len(suggestions_list)} suggestions created for {tenant} {extraction_id}")
if len(suggestions_list) > 2:
config_logger.info(json.dumps(suggestions_list[0]))
config_logger.info(json.dumps(suggestions_list[1]))

return json.dumps(suggestions_list)
except Exception:
Expand Down
1 change: 1 addition & 0 deletions src/data/ExtractionIdentifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class ExtractionIdentifier(BaseModel):
run_name: str
extraction_name: str
metadata: dict[str, str] = dict()

def get_path(self):
return join(DATA_PATH, self.run_name, self.extraction_name)
Expand Down
1 change: 1 addition & 0 deletions src/data/Params.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class Params(BaseModel):
id: str
options: list[Option] = list()
multi_value: bool = False
metadata: dict[str, str] = dict()
3 changes: 3 additions & 0 deletions src/data/ResultsMessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ class ResultsMessage(BaseModel):
success: bool
error_message: str
data_url: Optional[str] = None

def to_string(self):
    """Return a one-line summary of this message for logging.

    The summary lists tenant, params id, task, success flag and error
    message, in that order; ``data_url`` is not part of it.
    """
    parts = (
        f"tenant: {self.tenant}",
        f"id: {self.params.id}",
        f"task: {self.task}",
        f"success: {self.success}",
        f"error_message: {self.error_message}",
    )
    return ", ".join(parts)
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import shutil
from collections import Counter
from os.path import join, exists
from pathlib import Path

Expand Down Expand Up @@ -90,7 +92,7 @@ def __init__(self, extraction_identifier: ExtractionIdentifier):
def create_model(self, extraction_data: ExtractionData):
self.options = extraction_data.options
self.multi_value = extraction_data.multi_value

send_logs(self.extraction_identifier, self.get_stats(extraction_data))
method = self.get_best_method(extraction_data)
method.train(extraction_data)

Expand Down Expand Up @@ -149,24 +151,28 @@ def load_options(self):
def get_best_method(self, multi_option_data: ExtractionData) -> PdfMultiOptionMethod:
    """Evaluate every method in ``self.METHODS`` and return the best one.

    Each method is scored with ``get_method_performance``. A method scoring
    exactly 100 is returned immediately, without evaluating the remaining
    ones. Otherwise the method with the strictly highest score wins (on
    ties the earlier one is kept); if no method scores above 0, the first
    entry of ``self.METHODS`` is returned.

    A per-method performance summary and the chosen method are sent to the
    logs before returning.
    """
    winner = self.METHODS[0]
    top_score = 0
    report_lines = ["Performance aggregation:"]

    for candidate in self.METHODS:
        score = self.get_method_performance(candidate, multi_option_data)
        report_lines.append(f"{candidate.get_name()}: {round(score, 2)}%")

        # A perfect score cannot be beaten, so stop evaluating here.
        if score == 100:
            send_logs(self.extraction_identifier, "\n".join(report_lines) + "\n")
            send_logs(self.extraction_identifier, f"Best method {candidate.get_name()} with {score}%")
            return candidate

        if score > top_score:
            top_score, winner = score, candidate

    send_logs(self.extraction_identifier, "\n".join(report_lines) + "\n")
    send_logs(self.extraction_identifier, f"Best method {winner.get_name()}")
    return winner

def get_method_performance(self, method: PdfMultiOptionMethod, multi_option_data: ExtractionData):
def get_method_performance(self, method: PdfMultiOptionMethod, multi_option_data: ExtractionData) -> float:
method.set_parameters(multi_option_data)

if not method.can_be_used(multi_option_data):
send_logs(self.extraction_identifier, f"Not valid method {method.get_name()}")
return 0

send_logs(self.extraction_identifier, f"Checking {method.get_name()}")
Expand All @@ -178,7 +184,11 @@ def get_method_performance(self, method: PdfMultiOptionMethod, multi_option_data
performance = 0

self.reset_extraction_data(multi_option_data)
send_logs(self.extraction_identifier, f"Performance {method.get_name()}: {performance}%")

if method.multi_label_method:
shutil.rmtree(method.base_path, ignore_errors=True)

send_logs(self.extraction_identifier, f"Performance {method.get_name()}: {round(performance, 2)}%")
return performance

def get_predictions_method(self):
Expand All @@ -204,3 +214,21 @@ def reset_extraction_data(multi_option_data: ExtractionData):
for sample in multi_option_data.samples:
for segment in sample.pdf_data.pdf_data_segments:
segment.ml_label = 0

@staticmethod
def get_stats(extraction_data: ExtractionData):
    """Build a multi-line text summary of the extraction data for logging.

    The summary contains the number of configured options, the number of
    samples, and two frequency listings ordered from most to least common:
    the ``language_iso`` of each sample, and the ``label`` of every option
    found in the samples' ``labeled_data.values``.
    """
    samples = extraction_data.samples
    label_counts = Counter(option.label for sample in samples for option in sample.labeled_data.values)
    language_counts = Counter(sample.labeled_data.language_iso for sample in samples)

    languages_block = "\n".join(f"{language} {count}" for language, count in language_counts.most_common())
    options_block = "\n".join(f"{label} {count}" for label, count in label_counts.most_common())

    return (
        f"\nNumber of options: {len(extraction_data.options)}\n"
        f"Number of samples: {len(samples)}\n"
        "Languages\n"
        f"{languages_block}"
        "\nOptions\n"
        f"{options_block}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os.path import join, exists

import pandas as pd
import torch
from datasets import load_dataset

from data.ExtractionData import ExtractionData
Expand All @@ -14,6 +15,7 @@
from extractors.bert_method_scripts.EarlyStoppingAfterInitialTraining import EarlyStoppingAfterInitialTraining
from extractors.bert_method_scripts.get_batch_size import get_batch_size, get_max_steps
from extractors.pdf_to_multi_option_extractor.MultiLabelMethod import MultiLabelMethod
from send_logs import send_logs


class SetFitMethod(MultiLabelMethod):
Expand Down Expand Up @@ -105,6 +107,10 @@ def predict(self, multi_option_data: ExtractionData) -> list[list[Option]]:
return self.predictions_to_options_list(predictions.tolist())

def can_be_used(self, extraction_data: ExtractionData) -> bool:
if not torch.cuda.is_available():
send_logs(self.extraction_identifier, f"GPU not available for {self.get_name()}")
return False

if not extraction_data.multi_value:
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from os.path import join, exists

import pandas as pd
import torch.cuda
from datasets import load_dataset

from data.ExtractionData import ExtractionData
Expand All @@ -14,13 +15,18 @@
from extractors.bert_method_scripts.EarlyStoppingAfterInitialTraining import EarlyStoppingAfterInitialTraining
from extractors.bert_method_scripts.get_batch_size import get_batch_size, get_max_steps
from extractors.pdf_to_multi_option_extractor.MultiLabelMethod import MultiLabelMethod
from send_logs import send_logs


class SingleLabelSetFitMethod(MultiLabelMethod):

model_name = "sentence-transformers/paraphrase-mpnet-base-v2"

def can_be_used(self, extraction_data: ExtractionData) -> bool:
if not torch.cuda.is_available():
send_logs(self.extraction_identifier, f"GPU not available for {self.get_name()}")
return False

if extraction_data.multi_value:
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def get_sentence_segment_list(self, pdf_data_segments) -> list[(str, PdfDataSegm

sentence_segment_list.append((text, segment))

if not sentence_segment_list:
return list()

sentences_across_pages = list()
sentences_across_pages.append(sentence_segment_list[0])
for sentence, next_sentence in zip(sentence_segment_list, sentence_segment_list[1:]):
Expand All @@ -63,6 +66,9 @@ def get_sentence_segment_list(self, pdf_data_segments) -> list[(str, PdfDataSegm

def get_segments_merged(self, segments):
segments = [segment for segment in segments if segment.text_content.strip()]
if not segments:
return list()

merged_sentences = [segments[0]]
for segment in segments[1:]:
previous_segment_text = " ".join(merged_sentences[-1].text_content.split())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@ class SameInputOutputMethod(TextToTextMethod):
def train(self, extraction_data: ExtractionData):
pass

@staticmethod
def trim_text(tag_texts: list[str]) -> str:
text = " ".join(tag_texts)
return " ".join(text.split())

def predict(self, predictions_samples: list[PredictionSample]) -> list[str]:
return [" ".join(x.tags_texts) for x in predictions_samples]
return [self.trim_text(x.tags_texts) for x in predictions_samples]
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def test_performance_100_with_multiline(self):
former Yugoslav Republic of Macedonia, United Kingdom of Great Britain and Northern
Ireland, Venezuela (Bolivarian Republic of)"""

tags_text = (
"Albania, Algeria, Argentina, Bolivia (Plurinational State of), Brazil, Congo, Côte d’Ivoire, "
"El Salvador, Estonia, France, Gabon, Germany, Ireland, Kazakhstan, Latvia, Mexico, "
"Montenegro, Namibia, Netherlands, Paraguay, Portugal, Sierra Leone, South Africa, "
"the former Yugoslav Republic of Macedonia, United Kingdom of Great Britain and "
"Northern Ireland, Venezuela (Bolivarian Republic of)"
)

sample = TrainingSample(labeled_data=LabeledData(label_text=label_text, language_iso="en"), tags_texts=[tags_text])
tags_text = [
"Albania, Algeria, Argentina, Bolivia (Plurinational State of), Brazil, Congo, Côte d’Ivoire, ",
"El Salvador, Estonia, France, Gabon, Germany, Ireland, Kazakhstan, Latvia, Mexico, ",
"Montenegro, Namibia, Netherlands, Paraguay, Portugal, Sierra Leone, South Africa, ",
"the former Yugoslav Republic of Macedonia, United Kingdom of Great Britain and",
"Northern Ireland, Venezuela (Bolivarian Republic of)",
]

sample = TrainingSample(labeled_data=LabeledData(label_text=label_text, language_iso="en"), tags_texts=tags_text)

extraction_data = ExtractionData(samples=[sample], extraction_identifier=extraction_identifier)

Expand Down
2 changes: 1 addition & 1 deletion src/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestApp(TestCase):

def test_info(self):
with TestClient(app) as client:
response = client.get("/info")
response = client.get("/")

self.assertEqual(200, response.status_code)

Expand Down
24 changes: 12 additions & 12 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def test_pdf_to_text(self):
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)
QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()

results_message = self.get_results_message()
expected_result = ResultsMessage(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
success=True,
error_message="",
data_url=None,
Expand All @@ -94,15 +94,15 @@ def test_pdf_to_text(self):
task = ExtractionTask(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)
QUEUE.sendMessage(delay=0).message(str(task.model_dump_json())).execute()

results_message = self.get_results_message()
expected_result = ResultsMessage(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
success=True,
error_message="",
data_url=f"{SERVER_URL}/get_suggestions/{tenant}/{extraction_id}",
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_create_model_without_data(self):
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand All @@ -146,7 +146,7 @@ def test_create_model_without_data(self):
expected_result = ResultsMessage(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
success=False,
error_message="No data to create model",
data_url=None,
Expand All @@ -157,15 +157,15 @@ def test_create_model_without_data(self):
task = ExtractionTask(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)
QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()

results_message = self.get_results_message()
expected_result = ResultsMessage(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
success=False,
error_message="No data to calculate suggestions",
data_url=None,
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_pdf_to_multi_option(self):
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id, options=options, multi_value=False),
params=Params(id=extraction_id, options=options, multi_value=False, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand All @@ -226,7 +226,7 @@ def test_pdf_to_multi_option(self):
task = ExtractionTask(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand Down Expand Up @@ -280,7 +280,7 @@ def test_text_to_multi_option(self):
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id, options=options, multi_value=True),
params=Params(id=extraction_id, options=options, multi_value=True, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand Down Expand Up @@ -308,7 +308,7 @@ def test_text_to_multi_option(self):
task = ExtractionTask(
tenant=tenant,
task="suggestions",
params=Params(id=extraction_id),
params=Params(id=extraction_id, metadata={"name": "test"}),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand Down

0 comments on commit 5f06578

Please sign in to comment.