diff --git a/backend/problem/management/commands/import_fracas.py b/backend/problem/management/commands/import_fracas.py index 2af1b8a..5f18879 100644 --- a/backend/problem/management/commands/import_fracas.py +++ b/backend/problem/management/commands/import_fracas.py @@ -4,13 +4,20 @@ from tqdm import tqdm from langpro_annotator.logger import logger -from problem.services import get_fracas_problems +from problem.services import FracasData from problem.models import Problem class Command(BaseCommand): help = "Import FraCaS problems from fracas.xml." + ENTAILMENT_LABELS = { + "yes": Problem.EntailmentLabel.ENTAILMENT, + "no": Problem.EntailmentLabel.CONTRADICTION, + "unknown": Problem.EntailmentLabel.NEUTRAL, + "undefined": Problem.EntailmentLabel.UNKNOWN, + } + def add_arguments(self, parser): parser.add_argument( "--file", @@ -25,13 +32,6 @@ def handle(self, *args, **options): fracas_path = options["fracas_path"] self.import_fracas_problems(fracas_path) - @staticmethod - def _text_from_element(element: ET.Element) -> str: - """ - Extracts stripped text from an XML element, returning an empty string if the element is None or has no text. - """ - return element.text.strip() if element is not None and element.text else "" - @staticmethod def _annotate_section_subsections(tree: ET.ElementTree) -> None: """ @@ -72,7 +72,9 @@ def import_fracas_problems(self, fracas_path: str) -> None: created = 0 skipped = 0 - existing_fracas_problems = get_fracas_problems() + existing_fracas_problems = Problem.objects.filter( + dataset=Problem.Dataset.FRACAS + ) existing_fracas_ids = {p.fracas_id for p in existing_fracas_problems} for problem in tqdm(all_problems, desc="Importing FraCaS problems"): @@ -88,33 +90,22 @@ def import_fracas_problems(self, fracas_path: str) -> None: skipped += 1 continue - question = self._text_from_element(problem.find("q")) - hypothesis = self._text_from_element(problem.find("h")) - answer = self._text_from_element(problem.find("a")) - note = self._text_from_element(problem.find("note")) - - section = problem.get("section") - subsection = problem.get("subsection") + hypothesis = FracasData._text_from_element(problem.find("h")) fracas_answer = problem.get("fracas_answer") - fracas_nonstandard = problem.get("fracas_nonstandard", False) == "true" - premise_nodes = problem.findall("p") premises = [node.text.strip() for node in premise_nodes if node.text] + entailment_label = self.ENTAILMENT_LABELS.get( + fracas_answer, Problem.EntailmentLabel.UNKNOWN + ) + + extra_data = FracasData.import_data(problem) Problem.objects.create( - type=Problem.ProblemType.FRACAS, - content={ - "fracas_id": int(problem_id), - "question": question, - "hypothesis": hypothesis, - "answer": answer, - "fracas_answer": fracas_answer, - "fracas_non_standard": fracas_nonstandard, - "note": note, - "section_name": section, - "subsection_name": subsection, - "premises": premises, - }, + dataset=Problem.Dataset.FRACAS, + premises=premises, + hypothesis=hypothesis, + entailment_label=entailment_label, + extra_data=extra_data, ) created += 1 diff --git a/backend/problem/management/commands/import_sick.py b/backend/problem/management/commands/import_sick.py index f35c8f9..eeada16 100644 --- a/backend/problem/management/commands/import_sick.py +++ b/backend/problem/management/commands/import_sick.py @@ -5,12 +5,18 @@ from langpro_annotator.logger import logger from problem.models import Problem -from problem.services import get_sick_problems +from problem.services import SickData class Command(BaseCommand): help = "Import SICK problems from SICK.txt (a TSV file)." + ENTAILMENT_LABELS = { + "NEUTRAL": Problem.EntailmentLabel.NEUTRAL, + "ENTAILMENT": Problem.EntailmentLabel.ENTAILMENT, + "CONTRADICTION": Problem.EntailmentLabel.CONTRADICTION, + } + def add_arguments(self, parser): parser.add_argument( "--file", @@ -33,7 +39,7 @@ def import_sick_problems(self, sick_path: str) -> None: skipped = 0 created = 0 - existing_sick_problems = get_sick_problems() + existing_sick_problems = Problem.objects.filter(dataset=Problem.Dataset.SICK) existing_pair_ids = {p.pair_id for p in existing_sick_problems} with open(sick_path, "r", encoding="utf-8") as file: @@ -45,11 +51,20 @@ def import_sick_problems(self, sick_path: str) -> None: skipped += 1 continue - created += 1 + entailment_label = self.ENTAILMENT_LABELS.get( + problem["entailment_label"], Problem.EntailmentLabel.UNKNOWN + ) + + extra_data = SickData.import_data(problem) + Problem.objects.create( - type=Problem.ProblemType.SICK, - content=problem, + dataset=Problem.Dataset.SICK, + premises=[problem["sentence_A"]], + hypothesis=problem["sentence_B"], + entailment_label=entailment_label, + extra_data=extra_data, ) + created += 1 logger.info( f"SICK problems import complete! Created: {created} | Skipped: {skipped}" diff --git a/backend/problem/management/commands/import_snli.py b/backend/problem/management/commands/import_snli.py index 83607fc..1d8cbe2 100644 --- a/backend/problem/management/commands/import_snli.py +++ b/backend/problem/management/commands/import_snli.py @@ -5,11 +5,19 @@ from langpro_annotator.logger import logger from problem.models import Problem +from problem.services import SNLIData class Command(BaseCommand): help = "Import SNLI 1.0 problems and save them in the DB. Use the flags --dev, --train, --test to specify the paths to the SNLI files. The development set contains 10K problems, the training set contains 550K problems, and the test set contains 10K problems." + ENTAILMENT_LABELS = { + "entailment": Problem.EntailmentLabel.ENTAILMENT, + "contradiction": Problem.EntailmentLabel.CONTRADICTION, + "neutral": Problem.EntailmentLabel.NEUTRAL, + "none": Problem.EntailmentLabel.UNKNOWN, # For empty gold labels. + } + def add_arguments(self, parser): parser.add_argument( "--dev", @@ -53,8 +61,10 @@ def import_snli_problems(self, snli_paths: list[tuple[str, str]]) -> None: skipped = 0 created = 0 - existing_snli_problems = Problem.objects.filter(type=Problem.ProblemType.SNLI) - existing_pair_ids = {p.content.get("pairID") for p in existing_snli_problems} + existing_snli_problems = Problem.objects.filter(dataset=Problem.Dataset.SNLI) + existing_pair_ids = existing_snli_problems.values_list( + "extra_data__pair_id", flat=True + ) for subset, snli_path in snli_paths: try: @@ -70,20 +80,31 @@ def import_snli_problems(self, snli_paths: list[tuple[str, str]]) -> None: skipped += 1 continue - problem["subset"] = subset - - # Handle empty gold labels. - if problem["gold_label"] == "-": - problem["gold_label"] = "none" - # Handle empty labels. - for key in ["label1", "label2", "label3", "label4", "label5"]: - if problem[key] == "": - problem[key] = "none" + for key in [ + "gold_label", + "label1", + "label2", + "label3", + "label4", + "label5", + ]: + label_value = problem.get(key, "") + if label_value in ["-", ""]: + problem[key] = self.ENTAILMENT_LABELS["none"] + else: + problem[key] = self.ENTAILMENT_LABELS.get( + label_value, Problem.EntailmentLabel.UNKNOWN + ) + + extra_data = SNLIData.import_data(problem, subset) Problem.objects.create( - type=Problem.ProblemType.SNLI, - content=problem, + dataset=Problem.Dataset.SNLI, + premises=[problem["sentence1"]], + hypothesis=problem["sentence2"], + entailment_label=problem["gold_label"], + extra_data=extra_data, ) created += 1 existing_pair_ids.add(problem["pairID"]) diff --git a/backend/problem/migrations/0003_rename_content_problem_extra_data_and_more.py b/backend/problem/migrations/0003_rename_content_problem_extra_data_and_more.py new file mode 100644 index 0000000..5bf4049 --- /dev/null +++ b/backend/problem/migrations/0003_rename_content_problem_extra_data_and_more.py @@ -0,0 +1,63 @@ +# Generated by Django 4.2.20 on 2025-07-08 10:16 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("problem", "0002_alter_problem_type"), + ] + + operations = [ + migrations.RenameField( + model_name="problem", + old_name="content", + new_name="extra_data", + ), + migrations.RemoveField( + model_name="problem", + name="type", + ), + migrations.AddField( + model_name="problem", + name="dataset", + field=models.CharField( + choices=[ + ("sick", "Sick"), + ("fracas", "FraCaS"), + ("snli", "SNLI"), + ("user", "User"), + ], + default="user", + max_length=255, + ), + ), + migrations.AddField( + model_name="problem", + name="entailment_label", + field=models.CharField( + choices=[ + ("neutral", "Neutral"), + ("entailment", "Entailment"), + ("contradiction", "Contradiction"), + ("unknown", "Unknown"), + ], + default="unknown", + max_length=255, + ), + ), + migrations.AddField( + model_name="problem", + name="hypothesis", + field=models.CharField(blank=True, max_length=512, null=True), + ), + migrations.AddField( + model_name="problem", + name="premises", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(max_length=512), default=list, size=None + ), + ), + ] diff --git a/backend/problem/models.py b/backend/problem/models.py index 89d5a0d..54b8a05 100644 --- a/backend/problem/models.py +++ b/backend/problem/models.py @@ -1,20 +1,47 @@ from django.db import models +from django.contrib.postgres.fields import ArrayField +from problem.services import FracasData, SNLIData, SickData from langpro_annotator.logger import logger class Problem(models.Model): - class ProblemType(models.TextChoices): + class Dataset(models.TextChoices): SICK = "sick", "Sick" FRACAS = "fracas", "FraCaS" SNLI = "snli", "SNLI" + USER = "user", "User" - type = models.CharField( + class EntailmentLabel(models.TextChoices): + NEUTRAL = "neutral", "Neutral" + ENTAILMENT = "entailment", "Entailment" + CONTRADICTION = "contradiction", "Contradiction" + UNKNOWN = "unknown", "Unknown" + + dataset = models.CharField( max_length=255, - choices=ProblemType.choices, + choices=Dataset.choices, + default=Dataset.USER, + ) + + premises = ArrayField( + models.CharField(max_length=512), + default=list, ) - content = models.JSONField() + hypothesis = models.CharField( + max_length=512, + blank=True, + null=True, + ) + + entailment_label = models.CharField( + max_length=255, + choices=EntailmentLabel.choices, + default=EntailmentLabel.UNKNOWN, + ) + + extra_data = models.JSONField() def get_index(self) -> int | None: """ @@ -25,3 +52,27 @@ def get_index(self) -> int | None: except Exception as e: logger.error(f"Error getting index for problem {self.id}: {e}") return None + + def serialize(self) -> dict: + """ + Serialize the Problem instance to a dictionary. + """ + + match self.dataset: + case self.Dataset.SICK: + serialized_extra_data = SickData.serialize(self.extra_data) + case self.Dataset.FRACAS: + serialized_extra_data = FracasData.serialize(self.extra_data) + case self.Dataset.SNLI: + serialized_extra_data = SNLIData.serialize(self.extra_data) + case _: + serialized_extra_data = {} + + return { + "id": self.id, + "dataset": self.dataset, + "premises": self.premises, + "hypothesis": self.hypothesis, + "entailmentLabel": self.entailment_label, + "extraData": serialized_extra_data, + } diff --git a/backend/problem/problem_details.py b/backend/problem/problem_details.py new file mode 100644 index 0000000..2bce2c0 --- /dev/null +++ b/backend/problem/problem_details.py @@ -0,0 +1,26 @@ +from typing import Tuple +from langpro_annotator.logger import logger +from problem.models import Problem + + +def get_related_problem_ids(problem_id: int) -> Tuple[int, int, int]: + """ + Retrieves the IDs of the next, previous, and random Problem objects + in the database relative to the given problem ID. + """ + + try: + problem = Problem.objects.get(id=problem_id) + except Problem.DoesNotExist: + logger.warning(f"Problem ID {problem_id} does not exist.") + return None, None, None + + next_problem = Problem.objects.filter(id__gt=problem.id).order_by("id").first() + previous_problem = Problem.objects.filter(id__lt=problem.id).order_by("-id").first() + random_problem = Problem.objects.exclude(id=problem.id).order_by("?").first() + + return ( + next_problem.id if next_problem else None, + previous_problem.id if previous_problem else None, + random_problem.id if random_problem else None, + ) diff --git a/backend/problem/services.py b/backend/problem/services.py index d7f9394..794fa94 100644 --- a/backend/problem/services.py +++ b/backend/problem/services.py @@ -1,160 +1,102 @@ -import json -from typing import Tuple -from langpro_annotator.logger import logger -from problem.models import Problem -from problem.types import CombinedProblem, FracasProblem, SNLIProblem, SickProblem - - -def instance_to_sick_problem(instance: Problem) -> SickProblem | None: - """ - Converts a Problem instance to a SickProblem object. - """ - try: - content: dict = instance.content - return SickProblem( - pair_id=content["pair_ID"], - sentence_one=content["sentence_A"], - sentence_two=content["sentence_B"], - entailment_label=content["entailment_label"], - relatedness_score=float(content["relatedness_score"]), - ) - except json.JSONDecodeError as e: - logger.warning(f"Could not decode JSON for Problem ID {instance.id}: {e}") - return None - except Exception as e: - logger.warning( - f"Could not convert Problem ID {instance.id} to SickProblem: {e}" - ) - return None - - -def instance_to_fracas_problem(instance: Problem) -> FracasProblem | None: - """ - Converts a Problem instance to a FracasProblem object. - """ - try: - content: dict = instance.content - return FracasProblem( - fracas_id=content["fracas_id"], - question=content["question"], - hypothesis=content["hypothesis"], - answer=content["answer"], - fracas_answer=content["fracas_answer"], - fracas_non_standard=content["fracas_non_standard"], - note=content["note"], - section_name=content["section_name"], - subsection_name=content["subsection_name"], - premises=content.get("premises", []), - ) - except json.JSONDecodeError as e: - logger.warning(f"Could not decode JSON for Problem ID {instance.id}: {e}") - return None - except Exception as e: - logger.warning( - f"Could not convert Problem ID {instance.id} to FracasProblem: {e}" - ) - return None - - -def instance_to_snli_problem(instance: Problem) -> SNLIProblem | None: - """ - Converts a Problem instance to a SNLIProblem object. - """ - try: - content: dict = instance.content - return SNLIProblem( - pair_id=content["pairID"], - subset=content["subset"], - sentence_one=content["sentence1"], - sentence_two=content["sentence2"], - gold_label=content["gold_label"], - labels=[ - content["label1"], - content["label2"], - content["label3"], - content["label4"], - content["label5"], - ] - ) - except json.JSONDecodeError as e: - logger.warning(f"Could not decode JSON for Problem ID {instance.id}: {e}") - return None - except Exception as e: - logger.warning( - f"Could not convert Problem ID {instance.id} to SNLIProblem: {e}" - ) - return None - - -def get_sick_problems() -> list[SickProblem]: - """ - Retrieves all Problem objects of type 'SICK' from the database - and converts them into SickProblem instances. - """ - problems = Problem.objects.filter(type=Problem.ProblemType.SICK) - return [ - converted - for problem in problems - if (converted := instance_to_sick_problem(problem)) is not None - ] - - -def get_fracas_problems() -> list[FracasProblem]: - """ - Retrieves all Problem objects of type 'Fracas' from the database - and converts them into FracasProblem instances. - """ - problems = Problem.objects.filter(type=Problem.ProblemType.FRACAS) - return [ - converted - for problem in problems - if (converted := instance_to_fracas_problem(problem)) is not None - ] - -def get_snli_problems() -> list[SNLIProblem]: - """ - Retrieves all Problem objects of type 'SNLI' from the database - and converts them into SNLIProblem instances. - """ - problems = Problem.objects.filter(type=Problem.ProblemType.SNLI) - return [ - converted - for problem in problems - if (converted := instance_to_snli_problem(problem)) is not None - ] - -def convert_to_subtype(problem: Problem) -> CombinedProblem | None: - """ - Converts a Django Problem model instance to a specific subtype (dataclass) - based on its type. - """ - if problem.type == Problem.ProblemType.SICK: - return instance_to_sick_problem(problem) - elif problem.type == Problem.ProblemType.FRACAS: - return instance_to_fracas_problem(problem) - elif problem.type == Problem.ProblemType.SNLI: - return instance_to_snli_problem(problem) - else: - return None - - -def get_related_problem_ids(problem_id: int) -> Tuple[int, int, int]: - """ - Retrieves the IDs of the next, previous, and random Problem objects - in the database relative to the given problem ID. - """ - try: - problem = Problem.objects.get(id=problem_id) - except Problem.DoesNotExist: - logger.warning(f"Problem ID {problem_id} does not exist.") - return None, None, None - - next_problem = Problem.objects.filter(id__gt=problem.id).order_by("id").first() - previous_problem = Problem.objects.filter(id__lt=problem.id).order_by("-id").first() - random_problem = Problem.objects.exclude(id=problem.id).order_by("?").first() - - return ( - next_problem.id if next_problem else None, - previous_problem.id if previous_problem else None, - random_problem.id if random_problem else None, - ) +import xml.etree.ElementTree as ET +from typing import Literal + + +class SickData: + @staticmethod + def import_data(problem: dict) -> dict: + """ + Import SICK-specific data from a problem dictionary. + """ + pair_id = problem.get("pair_ID", "") + relatedness_score = float(problem.get("relatedness_score", 0.0)) + + return { + "pair_id": pair_id, + "relatedness_score": relatedness_score, + } + + @staticmethod + def serialize(extra_data: dict) -> dict: + """ + Serialize SICK-specific data from a Problem instance. + """ + return { + "pairId": extra_data.get("pair_id", ""), + "relatednessScore": extra_data.get("relatedness_score", 0.0), + } + + +class FracasData: + + @staticmethod + def _text_from_element(element: ET.Element) -> str: + """ + Extracts stripped text from an XML element, returning an empty string if the element is None or has no text. + """ + return element.text.strip() if element is not None and element.text else "" + + @staticmethod + def import_data(problem: dict) -> dict: + problem_id = problem.get("id") + question = FracasData._text_from_element(problem.find("q")) + answer = FracasData._text_from_element(problem.find("a")) + note = FracasData._text_from_element(problem.find("note")) + + section = problem.get("section") + subsection = problem.get("subsection") + fracas_nonstandard = problem.get("fracas_nonstandard", False) == "true" + + return { + "fracas_id": int(problem_id), + "question": question, + "answer": answer, + "note": note, + "section_name": section, + "subsection_name": subsection, + "fracas_non_standard": fracas_nonstandard, + } + + @staticmethod + def serialize(extra_data: dict) -> dict: + """ + Serialize FraCaS-specific data from a Problem instance. + """ + return { + "fracasId": extra_data.get("fracas_id", 0), + "question": extra_data.get("question", ""), + "answer": extra_data.get("answer", ""), + "note": extra_data.get("note", ""), + "sectionName": extra_data.get("section_name", ""), + "subsectionName": extra_data.get("subsection_name", ""), + "fracasNonStandard": extra_data.get("fracas_non_standard", False), + } + + +class SNLIData: + @staticmethod + def import_data(problem: dict, subset: Literal["dev", "train", "test"]) -> dict: + return { + "pair_id": problem["pairID"], + "subset": subset, + "label1": problem["label1"], + "label2": problem["label2"], + "label3": problem["label3"], + "label4": problem["label4"], + "label5": problem["label5"], + } + + @staticmethod + def serialize(extra_data: dict) -> dict: + """ + Serialize SNLI-specific data from a Problem instance. + """ + return { + "pairId": extra_data.get("pair_ID", ""), + "subset": extra_data.get("subset", ""), + "label1": extra_data.get("label1", ""), + "label2": extra_data.get("label2", ""), + "label3": extra_data.get("label3", ""), + "label4": extra_data.get("label4", ""), + "label5": extra_data.get("label5", ""), + } diff --git a/backend/problem/types.py b/backend/problem/types.py index 0788d9e..8b13789 100644 --- a/backend/problem/types.py +++ b/backend/problem/types.py @@ -1,71 +1 @@ -from typing import Literal -from dataclasses import dataclass, field - -@dataclass(frozen=True) -class SickProblem: - pair_id: int - sentence_one: str - sentence_two: str - entailment_label: Literal["neutral", "contradiction", "entailment"] - relatedness_score: float - - def serialize(self) -> dict: - return { - "pairId": self.pair_id, - "sentenceOne": self.sentence_one, - "sentenceTwo": self.sentence_two, - "entailmentLabel": self.entailment_label, - "relatednessScore": self.relatedness_score, - } - - -@dataclass(frozen=True) -class FracasProblem: - fracas_id: int - question: str - hypothesis: str - answer: str - fracas_answer: Literal["yes", "no", "unknown", "undefined"] - fracas_non_standard: bool - note: str - section_name: str - subsection_name: str - premises: list[str] = field(default_factory=list) - - def serialize(self) -> dict: - return { - "fracasId": self.fracas_id, - "question": self.question, - "hypothesis": self.hypothesis, - "answer": self.answer, - "fracasAnswer": self.fracas_answer, - "fracasNonStandard": self.fracas_non_standard, - "note": self.note, - "sectionName": self.section_name, - "subsectionName": self.subsection_name, - "premises": self.premises, - } - - -@dataclass(frozen=True) -class SNLIProblem: - pair_id: int - subset: Literal["train", "dev", "test"] - sentence_one: str - sentence_two: str - gold_label: Literal["neutral", "contradiction", "entailment", "none"] - labels: list[Literal["neutral", "contradiction", "entailment", "none"]] - - def serialize(self) -> dict: - return { - "pairId": self.pair_id, - "subset": self.subset, - "sentenceOne": self.sentence_one, - "sentenceTwo": self.sentence_two, - "goldLabel": self.gold_label, - "labels": self.labels, - } - - -type CombinedProblem = SickProblem | FracasProblem | SNLIProblem diff --git a/backend/problem/views.py b/backend/problem/views.py index a59d904..95bcc92 100644 --- a/backend/problem/views.py +++ b/backend/problem/views.py @@ -1,22 +1,16 @@ from dataclasses import dataclass -from typing import Literal from django.http import JsonResponse from rest_framework.views import APIView +from problem.problem_details import get_related_problem_ids from problem.models import Problem -from problem.types import CombinedProblem -from problem.services import ( - convert_to_subtype, - get_related_problem_ids, -) @dataclass class ProblemResponse: id: int | None = None index: int | None = None - type: Literal["sick", "fracas", "snli"] | None = None - problem: CombinedProblem | None = None + problem: Problem | None = None error: str | None = None next: str | None = None previous: str | None = None @@ -27,7 +21,6 @@ def json_response(self, status=200) -> JsonResponse: { "id": self.id, "index": self.index, - "type": self.type, "problem": self.problem.serialize() if self.problem else None, "error": self.error, "next": self.next, @@ -66,15 +59,8 @@ def get(self, request, problem_id: int): error="Problem not found", ).json_response(status=404) - converted_problem = convert_to_subtype(problem) - problem_index = problem.get_index() - if converted_problem is None: - return ProblemResponse( - error="Problem not found", - ).json_response(status=404) - next_problem_id, previous_problem_id, random_problem_id = ( get_related_problem_ids(problem_id) ) @@ -82,8 +68,7 @@ def get(self, request, problem_id: int): return ProblemResponse( id=problem.id, index=problem_index, - type=problem.type, - problem=converted_problem, + problem=problem, next=next_problem_id, previous=previous_problem_id, random=random_problem_id, diff --git a/frontend/src/app/annotate/annotation-input/annotation-input.component.html b/frontend/src/app/annotate/annotation-input/annotation-input.component.html index 6326cb6..e5929ba 100644 --- a/frontend/src/app/annotate/annotation-input/annotation-input.component.html +++ b/frontend/src/app/annotate/annotation-input/annotation-input.component.html @@ -1,7 +1,7 @@

Parser input

- Edit the premises, conclusion and associated Knowledge Bases using the forms + Edit the premises, hypothesis and associated Knowledge Bases using the forms below.

diff --git a/frontend/src/app/annotate/annotation-input/annotation-input.component.ts b/frontend/src/app/annotate/annotation-input/annotation-input.component.ts index 7fd471e..6a2b69d 100644 --- a/frontend/src/app/annotate/annotation-input/annotation-input.component.ts +++ b/frontend/src/app/annotate/annotation-input/annotation-input.component.ts @@ -8,7 +8,6 @@ import { Validators, } from "@angular/forms"; import { - Premises, PremisesFormComponent, } from "./premises-form/premises-form.component"; import { @@ -17,11 +16,10 @@ import { } from "./knowledge-base-form/knowledge-base-form.component"; import { AnnotateService } from "../../services/annotate.service"; import { toSignal } from "@angular/core/rxjs-interop"; -import { Dataset, Judgement, ProblemResponse } from "../../types"; +import { ProblemResponse } from "../../types"; import { FontAwesomeModule } from "@fortawesome/angular-fontawesome"; import { faCheck } from "@fortawesome/free-solid-svg-icons"; import { - ProblemDetails, ProblemDetailsComponent, } from "./problem-details/problem-details.component"; import { map, Subject } from "rxjs"; @@ -34,7 +32,7 @@ type KnowledgeBaseItemsForm = FormGroup<{ export type AnnotationInputForm = FormGroup<{ premises: FormArray>; - conclusion: FormControl; + hypothesis: FormControl; kbItems: FormArray; }>; @@ -56,7 +54,7 @@ export class AnnotationInputComponent { public problem$ = this.annotateService.problem$; public form$ = this.problem$.pipe( - map((response) => this.buildForm(response)) + map((response) => this.buildForm(response)), ); private formSignal = toSignal(this.form$, { @@ -77,47 +75,21 @@ export class AnnotationInputComponent { if (form.valid) { console.log( "submitting from AnnotationInputComponent!", - form.value + form.value, ); } } - private getPremisesAndConclusion(problem: ProblemResponse): Premises { - if (!problem.problem || !problem.type) { - return { - premises: [], - conclusion: "", - }; - } - // TODO: move this to the backend. - switch (problem.type) { - case Dataset.SICK: - return { - premises: [problem.problem.sentenceOne], - conclusion: problem.problem.sentenceTwo, - }; - case Dataset.FRACAS: - return { - premises: problem.problem.premises, - conclusion: problem.problem.hypothesis, - }; - case Dataset.SNLI: - return { - premises: [problem.problem.sentenceOne], - conclusion: problem.problem.sentenceTwo, - }; - } - } - private buildForm( - response: ProblemResponse | null + response: ProblemResponse | null, ): AnnotationInputForm | null { if (!response) { return null; } - const { premises, conclusion } = - this.getPremisesAndConclusion(response); + const premises = response.problem?.premises || []; + const hypothesis = response.problem?.hypothesis || ""; + return new FormGroup({ premises: new FormArray( premises.map( @@ -125,10 +97,10 @@ export class AnnotationInputComponent { new FormControl(premise, { validators: [Validators.required], nonNullable: true, - }) - ) + }), + ), ), - conclusion: new FormControl(conclusion, { + hypothesis: new FormControl(hypothesis, { validators: [Validators.required], nonNullable: true, }), diff --git a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.html b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.html index 6508928..a70befa 100644 --- a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.html +++ b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.html @@ -2,7 +2,7 @@
-

Premises and Conclusion

+

Premises and Hypothesis

@@ -55,14 +55,14 @@

Premises and Conclusion

-
diff --git a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.spec.ts b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.spec.ts index 403aafd..ad9a940 100644 --- a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.spec.ts +++ b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.spec.ts @@ -22,8 +22,8 @@ describe("PremisesFormComponent", () => { "form", new FormGroup({ premises: new FormArray([]), - conclusion: new FormControl("", { nonNullable: true }), - }) + hypothesis: new FormControl("", { nonNullable: true }), + }), ); fixture.detectChanges(); diff --git a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.ts b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.ts index 4827669..4b8aaac 100644 --- a/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.ts +++ b/frontend/src/app/annotate/annotation-input/premises-form/premises-form.component.ts @@ -1,9 +1,5 @@ import { Component, input } from "@angular/core"; -import { - ReactiveFormsModule, - FormControl, - Validators, -} from "@angular/forms"; +import { ReactiveFormsModule, FormControl, Validators } from "@angular/forms"; import { CommonModule } from "@angular/common"; import { FontAwesomeModule } from "@fortawesome/angular-fontawesome"; import { faCheck, faPlus, faTrash } from "@fortawesome/free-solid-svg-icons"; @@ -11,7 +7,7 @@ import { AnnotationInputForm } from "../annotation-input.component"; export interface Premises { premises: string[]; - conclusion: string; + hypothesis: string; } @Component({ @@ -36,7 +32,7 @@ export class PremisesFormComponent { new FormControl(value, { nonNullable: true, validators: [Validators.required], - }) + }), ); } diff --git a/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.html b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.html new file mode 100644 index 0000000..faa61ff --- /dev/null +++ b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.html @@ -0,0 +1 @@ +{{ entailmentText() }} diff --git a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.scss b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.scss similarity index 100% rename from frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.scss rename to frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.scss diff --git a/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.test.ts b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.test.ts new file mode 100644 index 0000000..f2a3028 --- /dev/null +++ b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.test.ts @@ -0,0 +1,25 @@ +import { ComponentFixture, TestBed } from "@angular/core/testing"; + +import { EntailmentLabelBadgeComponent } from "./entailment-label-badge.component"; +import { EntailmentLabel } from "../../../../types"; + +describe("EntailmentLabelBadgeComponent", () => { + let component: EntailmentLabelBadgeComponent; + let fixture: ComponentFixture; + + beforeEach(async () => { + await TestBed.configureTestingModule({ + imports: [EntailmentLabelBadgeComponent], + }).compileComponents(); + + fixture = TestBed.createComponent(EntailmentLabelBadgeComponent); + component = fixture.componentInstance; + const componentRef = fixture.componentRef; + componentRef.setInput("judgement", EntailmentLabel.ENTAILMENT); + fixture.detectChanges(); + }); + + it("should create", () => { + expect(component).toBeTruthy(); + }); +}); diff --git a/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.ts b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.ts new file mode 100644 index 0000000..ddef343 --- /dev/null +++ b/frontend/src/app/annotate/annotation-input/problem-details/entailment-label-badge/entailment-label-badge.component.ts @@ -0,0 +1,41 @@ +import { Component, computed, input } from "@angular/core"; +import { EntailmentLabel } from "../../../../types"; + +@Component({ + selector: "la-entailment-label-badge", + standalone: true, + imports: [], + templateUrl: "./entailment-label-badge.component.html", + styleUrl: "./entailment-label-badge.component.scss", +}) +export class EntailmentLabelBadgeComponent { + public entailmentLabel = input.required(); + + public entailmentText = computed(() => { + const entailment = this.entailmentLabel(); + switch (entailment) { + case EntailmentLabel.ENTAILMENT: + return $localize`Entailment`; + case EntailmentLabel.CONTRADICTION: + return $localize`Contradiction`; + case EntailmentLabel.NEUTRAL: + return $localize`Neutral`; + case EntailmentLabel.UNKNOWN: + return $localize`Unknown`; + } + }); + + public entailmentClass = computed(() => { + const entailment = this.entailmentLabel(); + switch (entailment) { + case EntailmentLabel.ENTAILMENT: + return "badge text-bg-success"; + case EntailmentLabel.CONTRADICTION: + return "badge text-bg-danger"; + case EntailmentLabel.NEUTRAL: + return "badge text-bg-secondary"; + case EntailmentLabel.UNKNOWN: + return "badge text-bg-warning"; + } + }); +} diff --git a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.html b/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.html deleted file mode 100644 index 303c8cd..0000000 --- a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.html +++ /dev/null @@ -1 +0,0 @@ -{{ judgementText() }} diff --git a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.spec.ts b/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.spec.ts deleted file mode 100644 index 4b735bb..0000000 --- a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.spec.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { ComponentFixture, TestBed } from "@angular/core/testing"; - -import { JudgementBadgeComponent } from "./judgement-badge.component"; -import { Judgement } from "../../../../types"; - -describe("JudgementBadgeComponent", () => { - let component: JudgementBadgeComponent; - let fixture: ComponentFixture; - - beforeEach(async () => { - await TestBed.configureTestingModule({ - imports: [JudgementBadgeComponent], - }).compileComponents(); - - fixture = TestBed.createComponent(JudgementBadgeComponent); - component = fixture.componentInstance; - const componentRef = fixture.componentRef; - componentRef.setInput("judgement", Judgement.ENTAILMENT); - fixture.detectChanges(); - }); - - it("should create", () => { - expect(component).toBeTruthy(); - }); -}); diff --git a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.ts b/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.ts deleted file mode 100644 index a25ee78..0000000 --- a/frontend/src/app/annotate/annotation-input/problem-details/judgement-badge/judgement-badge.component.ts +++ /dev/null @@ -1,41 +0,0 @@ -import { Component, computed, input } from "@angular/core"; -import { Judgement } from "../../../../types"; - -@Component({ - selector: "la-judgement-badge", - standalone: true, - imports: [], - templateUrl: "./judgement-badge.component.html", - styleUrl: "./judgement-badge.component.scss", -}) -export class JudgementBadgeComponent { - public judgement = input.required(); - - public judgementText = computed(() => { - const judgement = this.judgement(); - switch (judgement) { - case Judgement.ENTAILMENT: - return "Entailment"; - case Judgement.CONTRADICTION: - return "Contradiction"; - case Judgement.NEUTRAL: - return "Neutral"; - case Judgement.UNKNOWN: - return "Unknown"; - } - }); - - public judgementClass = computed(() => { - const judgement = this.judgement(); - switch (judgement) { - case Judgement.ENTAILMENT: - return "badge text-bg-success"; - case Judgement.CONTRADICTION: - return "badge text-bg-danger"; - case Judgement.NEUTRAL: - return "badge text-bg-secondary"; - case Judgement.UNKNOWN: - return "badge text-bg-warning"; - } - }); -} diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html index e5da779..aef630c 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html @@ -1,45 +1,46 @@ @if (problemDetails(); as details) { -
-
- Judgement: - -
+
+
+ Entailment label: + +
- - - - - - - - - - - @if (sectionString()) { - - - - - } @if (details.comment) { - - - - - } - -
ID:{{ details.problemId }}
Dataset:{{ details.dataset }}
Section:{{ sectionString() }}
- Comment: - - {{ details.comment }} -
-
+ + + + + + + + + + + @if (sectionString()) { + + + + + } + @if (details.comment) { + + + + + } + +
ID:{{ details.problemId }}
Dataset:{{ details.dataset }}
Section:{{ sectionString() }}
+ Comment: + + {{ details.comment }} +
+
} diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.scss b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.scss index d664b51..dd4f874 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.scss +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.scss @@ -2,7 +2,7 @@ position: relative; } -.judgement-badge-container { +.entailment-label-container { position: absolute; top: 1em; right: 1em; diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts index dbdcfc2..0e010ca 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts @@ -17,13 +17,9 @@ describe("ProblemDetailsComponent", () => { const componentRef = fixture.componentRef; componentRef.setInput("problem", { problem: { - pairId: "1", - sentenceOne: "This is a sentence.", - sentenceTwo: "This is another sentence.", - entailmentLabel: "NEUTRAL", - relatednessScore: 0.5, + id: 1 }, - type: Dataset.SICK, + dataset: Dataset.SICK, }); fixture.detectChanges(); }); diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts index a20a7a2..4fbff48 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts @@ -1,45 +1,27 @@ import { Component, computed, input } from "@angular/core"; -import { Dataset, Judgement, ProblemResponse } from "../../../types"; -import { JudgementBadgeComponent } from "./judgement-badge/judgement-badge.component"; +import { Dataset, EntailmentLabel, ProblemResponse } from "../../../types"; +import { EntailmentLabelBadgeComponent } from "./entailment-label-badge/entailment-label-badge.component"; import { faQuestionCircle } from "@fortawesome/free-solid-svg-icons"; import { FontAwesomeModule } from "@fortawesome/angular-fontawesome"; import { NgbTooltipModule } from "@ng-bootstrap/ng-bootstrap"; -import { AnnotateService } from "../../../services/annotate.service"; -import { toSignal } from "@angular/core/rxjs-interop"; export interface ProblemDetails { problemId: string; dataset: Dataset; - judgement: Judgement; + entailmentLabel: EntailmentLabel; section: string | null; subsection: string | null; comment: string | null; } -const judgementMap: Record> = { - [Dataset.SICK]: { - ENTAILMENT: Judgement.ENTAILMENT, - CONTRADICTION: Judgement.CONTRADICTION, - NEUTRAL: Judgement.NEUTRAL, - }, - [Dataset.FRACAS]: { - yes: Judgement.ENTAILMENT, - no: Judgement.CONTRADICTION, - unknown: Judgement.NEUTRAL, - undefined: Judgement.UNKNOWN, - }, - [Dataset.SNLI]: { - entailment: Judgement.ENTAILMENT, - contradiction: Judgement.CONTRADICTION, - neutral: Judgement.NEUTRAL, - none: Judgement.UNKNOWN, - }, -}; - @Component({ selector: "la-problem-details", standalone: true, - imports: [JudgementBadgeComponent, FontAwesomeModule, NgbTooltipModule], + imports: [ + EntailmentLabelBadgeComponent, + FontAwesomeModule, + NgbTooltipModule, + ], templateUrl: "./problem-details.component.html", styleUrl: "./problem-details.component.scss", }) @@ -74,68 +56,50 @@ export class ProblemDetailsComponent { }); private extractDetails( - response: ProblemResponse | null + response: ProblemResponse | null, ): ProblemDetails | null { if (!response?.problem) { return null; } - const judgement = this.getJudgement(response); - switch (response.type) { + + const shared: Pick< + ProblemDetails, + "problemId" | "dataset" | "entailmentLabel" + > = { + problemId: response.problem.id.toString(), + dataset: response.problem.dataset, + entailmentLabel: response.problem.entailmentLabel, + }; + + switch (response.problem.dataset) { case Dataset.SICK: return { - problemId: response.problem.pairId.toString(), - dataset: response.type, - judgement, + ...shared, section: null, subsection: null, comment: null, }; case Dataset.FRACAS: return { - problemId: response.problem.fracasId.toString(), - dataset: response.type, - judgement, - section: response.problem.sectionName, - subsection: response.problem.subsectionName, - comment: response.problem.note || null, + ...shared, + section: response.problem.extraData.sectionName, + subsection: response.problem.extraData.subsectionName, + comment: response.problem.extraData.note || null, }; case Dataset.SNLI: return { - problemId: response.problem.pairId.toString(), - dataset: response.type, - judgement, + ...shared, + section: null, + subsection: null, + comment: null, + }; + case Dataset.USER: + return { + ...shared, section: null, subsection: null, comment: null, }; } } - - private getJudgement(response: ProblemResponse): Judgement { - // This should never happen, as we check for a problem in the calling - // function, but TypeScript does not know this. - if (!response.problem) { - return Judgement.UNKNOWN; - } - - const { type, problem } = response; - // Use the judgementMap to get the judgement based on the dataset and - // the problem's entailment label or answer. - // TODO: move this to the backend. - const label = - type === Dataset.SICK - ? problem.entailmentLabel - : type === Dataset.FRACAS - ? problem.fracasAnswer - : type === Dataset.SNLI - ? problem.goldLabel - : undefined; - - if (!label) { - // If the label is not defined, we return UNKNOWN. - return Judgement.UNKNOWN; - } - - return judgementMap[response.type][label]; - } } diff --git a/frontend/src/app/services/annotate.service.ts b/frontend/src/app/services/annotate.service.ts index 8019419..d70bad6 100644 --- a/frontend/src/app/services/annotate.service.ts +++ b/frontend/src/app/services/annotate.service.ts @@ -1,5 +1,5 @@ import { Injectable } from "@angular/core"; -import { catchError, Observable, of, share, shareReplay, Subject, switchMap } from "rxjs"; +import { catchError, Observable, of, shareReplay, Subject, switchMap } from "rxjs"; import { HttpClient } from "@angular/common/http"; import { ProblemResponse, ProofBankStats } from "../types"; diff --git a/frontend/src/app/types.ts b/frontend/src/app/types.ts index fb973a8..31526f8 100644 --- a/frontend/src/app/types.ts +++ b/frontend/src/app/types.ts @@ -1,60 +1,66 @@ -export interface SickProblem { +interface SickData { pairId: number; - sentenceOne: string; - sentenceTwo: string; - entailmentLabel: "NEUTRAL" | "CONTRADICTION" | "ENTAILMENT"; relatednessScore: number; } -export interface FracasProblem { +interface FracasData { fracasId: number; question: string; - hypothesis: string; answer: string; - fracasAnswer: "yes" | "no" | "unknown" | "undefined"; - fracasNonStandard: boolean; note: string; sectionName: string; subsectionName: string; - premises: string[]; + fracasNonStandard: boolean; } -type SNLILabel = "neutral" | "contradiction" | "entailment" | "none"; - -export interface SNLIProblem { +interface SNLIData { pairId: number; - subset: 'dev' | 'test' | 'train'; - sentenceOne: string; - sentenceTwo: string; - goldLabel: SNLILabel; - labels: SNLILabel[]; + subset: "dev" | "test" | "train"; + label1: string; + label2: string; + label3: string; + label4: string; + label5: string; } -interface ProblemResponseBase { +interface ProblemBase { id: number; - index: number | null; - error: string | null; - next: string | null; - previous: string | null; - random: string | null; + premises: string[]; + hypothesis: string | null; + entailmentLabel: EntailmentLabel; } -interface SickProblemResponse extends ProblemResponseBase { - problem: SickProblem | null; - type: Dataset.SICK; +interface SickProblem extends ProblemBase { + dataset: Dataset.SICK; + extraData: SickData; } -interface FracasProblemResponse extends ProblemResponseBase { - problem: FracasProblem | null; - type: Dataset.FRACAS; +interface FracasProblem extends ProblemBase { + dataset: Dataset.FRACAS; + extraData: FracasData; } -export interface SNLIProblemResponse extends ProblemResponseBase { - problem: SNLIProblem | null; - type: Dataset.SNLI; +interface SNLIProblem extends ProblemBase { + dataset: Dataset.SNLI; + extraData: SNLIData; } -export type ProblemResponse = SickProblemResponse | FracasProblemResponse | SNLIProblemResponse; +interface UserProblem extends ProblemBase { + dataset: Dataset.USER; + extraData: null; +} + +type Problem = SickProblem | FracasProblem | SNLIProblem | UserProblem; + +export interface ProblemResponse { + id: number; + index: number | null; + next: string | null; + previous: string | null; + random: string | null; + error: string | null; + problem: Problem | null; +} export interface ProofBankStats { firstProblemId: string; @@ -66,9 +72,10 @@ export enum Dataset { SICK = "sick", FRACAS = "fracas", SNLI = "snli", + USER = "user", } -export enum Judgement { +export enum EntailmentLabel { ENTAILMENT = "entailment", CONTRADICTION = "contradiction", NEUTRAL = "neutral",