Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/detect_document_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

if __name__ == "__main__":
scanner = Textract()
result = scanner.detect_document_type("s3://document-extractor-gsa-dev-documents/test_dd214.jpg")
result = scanner.extract_raw_text("s3://document-extractor-gsa-dev-documents/test_dd214.jpg")

print(f"Document type is {result}")
print(f"Raw text is {result}")
11 changes: 11 additions & 0 deletions backend/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from src.external.ocr.textract import Textract
from src.forms.w2 import W2

if __name__ == "__main__":
scanner = Textract()
form = W2()
result = scanner.scan("s3://document-extractor-gsa-dev-documents/test_w2.jpg", queries=form.queries())

for key, value in result.items():
print(key)
print(f"\t{value}")
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"boto3>=1.37.1",
"iterator-chain>=1.1.0",
]

[tool.ruff]
Expand Down
15 changes: 12 additions & 3 deletions backend/src/external/lambda/text_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import boto3

from src.external.ocr.textract import Textract
from src.forms import supported_forms
from src.ocr import OcrException

s3_client = boto3.client("s3")
Expand Down Expand Up @@ -36,7 +37,7 @@ def lambda_handler(event, context):
ocr_engine = Textract()

try:
document_type = ocr_engine.detect_document_type(f"s3://{bucket_name}/{document_key}")
document_text = ocr_engine.extract_raw_text(f"s3://{bucket_name}/{document_key}")
except OcrException as e:
exception_message = f"Failed to detect the document type of s3://{bucket_name}/{document_key}: {e}"
print(exception_message)
Expand All @@ -45,8 +46,16 @@ def lambda_handler(event, context):
"body": json.dumps(exception_message),
}

identified_form = None

for text in document_text:
for form in supported_forms:
if form.form_matches() in text:
identified_form = form
break

try:
extracted_data = ocr_engine.scan(f"s3://{bucket_name}/{document_key}")
extracted_data = ocr_engine.scan(f"s3://{bucket_name}/{document_key}", queries=identified_form.queries())
except OcrException as e:
exception_message = f"Failed to extract text from S3 object s3://{bucket_name}/{document_key}: {e}"
print(exception_message)
Expand All @@ -63,7 +72,7 @@ def lambda_handler(event, context):
{
"document_key": document_key,
"extracted_data": extracted_data,
"document_type": document_type,
"document_type": identified_form.identifier(),
}
),
)
Expand Down
231 changes: 149 additions & 82 deletions backend/src/external/ocr/textract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import statistics
from typing import Any
from urllib import parse

import boto3
import iterator_chain

from src.ocr import Ocr, OcrException

Expand All @@ -10,35 +13,53 @@ class Textract(Ocr):
def __init__(self) -> None:
self.textract_client = boto3.client("textract")

def detect_document_type(self, s3_url: str) -> str | None:
def scan(self, s3_url: str, queries: list[str] | None = None) -> dict[str, dict[str, str | float]]:
try:
# Parse the S3 URL
bucket_name, object_key = self._parse_s3_url(s3_url)

if queries is None or len(queries) == 0:
print("Attempting AnalyzeDocument with forms and tables")
response = self.textract_client.analyze_document(
Document={"S3Object": {"Bucket": bucket_name, "Name": object_key}},
FeatureTypes=["FORMS"],
)
print("Parsing result")
extracted_data = self._parse_textract_forms(response)
else:
print("Attempting AnalyzeDocument with queries")
response_list = asyncio.run(self._paginated_textract_with_queries(queries, bucket_name, object_key))
print("Parsing result")
extracted_data = (
iterator_chain.from_iterable(response_list)
.map(self._parse_textract_queries)
.reduce(lambda a_dict, b_dict: {**a_dict, **b_dict}, initial={})
)

return extracted_data

except Exception as e:
raise OcrException(f"Unable to OCR the image {s3_url}") from e

def extract_raw_text(self, s3_url: str) -> list[str]:
try:
bucket_name, object_key = self._parse_s3_url(s3_url)

response = self.textract_client.detect_document_text(
Document={"S3Object": {"Bucket": bucket_name, "Name": object_key}}
)

document_type = None

for block in response.get("Blocks", []):
if block.get("BlockType") != "WORD" and block.get("BlockType") != "LINE":
continue

if block.get("Text") == "W-2":
document_type = "W2"
break
elif block.get("Text") == "1099-NEC":
document_type = "1099-NEC"
break
elif block.get("Text").startswith("DD FORM 214"):
document_type = "DD214"
break
return (
iterator_chain.from_iterable(response.get("Blocks", []))
.filter(lambda block: block["BlockType"] == "LINE")
.filter(lambda block: "Text" in block)
.map(lambda block: block["Text"])
.list()
)

except Exception as e:
raise OcrException(f"Failure while trying to detect the document type of {s3_url}") from e

return document_type

def _parse_s3_url(self, s3_url: str) -> tuple[str, str]:
parsed_url = parse.urlparse(s3_url)

Expand All @@ -53,85 +74,131 @@ def _parse_s3_url(self, s3_url: str) -> tuple[str, str]:

return bucket_name, object_key

def scan(self, s3_url: str) -> dict[str, dict[str, Any]]:
try:
# Parse the S3 URL
bucket_name, object_key = self._parse_s3_url(s3_url)
def _split_list_by_30(self, the_list: list[Any]) -> list[list[Any]]:
sublist_size = 30
return [the_list[i : i + sublist_size] for i in range(0, len(the_list), sublist_size)]

async def _paginated_textract_with_queries(self, queries, bucket_name, object_key) -> list[Any]:
queries_config = [{"Text": query, "Pages": ["*"]} for query in queries]

paginated_queries_config = self._split_list_by_30(queries_config)

tasks = [
asyncio.create_task(self._call_textract_with_queries(bucket_name, object_key, sub_queries_config))
for sub_queries_config in paginated_queries_config
]
results_list = await asyncio.gather(*tasks)
return results_list

async def _call_textract_with_queries(self, bucket_name, object_key, queries_config):
print("Initiating document analysis")
initiate_response = self.textract_client.start_document_analysis(
DocumentLocation={"S3Object": {"Bucket": bucket_name, "Name": object_key}},
FeatureTypes=["QUERIES"],
QueriesConfig={"Queries": queries_config},
)
job_id = initiate_response["JobId"]
response = self.textract_client.get_document_analysis(JobId=job_id)
while response["JobStatus"] == "IN_PROGRESS":
await asyncio.sleep(1)
print(f"Checking if job {job_id} is complete")
response = self.textract_client.get_document_analysis(JobId=job_id)

print(f"Completed document analysis for job {job_id}")
return response

def _parse_textract_queries(self, textract_response):
extracted_data = {}

# Download the image
print("Attempting AnalyzeDocument (Structured Mode)...")
response = self.textract_client.analyze_document(
Document={"S3Object": {"Bucket": bucket_name, "Name": object_key}},
FeatureTypes=["FORMS", "TABLES"],
)
extracted_data = self._parse_textract_analyze_document_response(response)
blocks = textract_response.get("Blocks", [])
query_blocks = []
query_result_blocks = {}

# Check if AnalyzeDocument works
if not extracted_data:
print("AnalyzeDocument yielded no data. Falling back to DetectDocumentText...")
response = self.textract_client.detect_document_text(
Document={"S3Object": {"Bucket": bucket_name, "Name": object_key}}
)
extracted_data = self._parse_ocr_response(response)
for block in blocks:
if block["BlockType"] == "QUERY":
query_blocks.append(block)
elif block["BlockType"] == "QUERY_RESULT":
query_result_blocks[block["Id"]] = block

return extracted_data
for query_block in query_blocks:
value, confidence = self._get_text_and_confidence_from_relationship_blocks(
query_block, query_result_blocks, "ANSWER"
)

except Exception as e:
raise OcrException(f"Unable to OCR the image {s3_url}") from e
extracted_data[query_block["Query"]["Text"]] = {"value": value, "confidence": confidence}

return extracted_data

def _parse_textract_analyze_document_response(self, response):
def _parse_textract_forms(self, response):
"""Parses structured data from AnalyzeDocument response into a simple key-value format."""
extracted_data = {}
block_map = {block["Id"]: block for block in response.get("Blocks", [])}

# Extract form data
for block in response.get("Blocks", []):
if block["BlockType"] == "KEY_VALUE_SET" and "KEY" in block.get("EntityTypes", []):
key_text, key_conf = self._get_text_from_block(block, block_map)
value_text = ""
for rel in block.get("Relationships", []):
if rel["Type"] == "VALUE":
for value_id in rel["Ids"]:
value_block = block_map.get(value_id)
if value_block:
value_text, value_conf = self._get_text_from_block(value_block, block_map)
if key_text:
extracted_data[key_text] = {"value": value_text, "confidence": key_conf}
if block["BlockType"] != "KEY_VALUE_SET" or "KEY" not in block.get("EntityTypes", []):
continue

return extracted_data
key_text, key_confidence = self._get_text_and_confidence_from_relationship_blocks(block, block_map, "CHILD")

def _get_text_from_block(self, block, block_map):
"""Helper to extract text from a block."""
text = ""
confidence = block.get("Confidence", 0.0)
if "Relationships" in block:
for rel in block["Relationships"]:
if rel["Type"] == "CHILD":
for child_id in rel["Ids"]:
word_block = block_map.get(child_id)
if word_block and word_block.get("Text"):
text += word_block["Text"] + " "
return text.strip(), confidence

def _parse_ocr_response(self, response):
"""Parses text from DetectDocumentText response with pseudo-keys based on content."""
extracted_data = {}
line_count = 1
relationships = block.get("Relationships", [])

for block in response.get("Blocks", []):
if block["BlockType"] == "LINE":
line_text = block.get("DetectedText", "")
confidence = block.get("Confidence", 0.0)
value_texts = []
value_confidences = []

# Generate a key based on the first 3 words
words = line_text.split()[:3]
key = "_".join(words).replace(":", "").replace(".", "").strip() or f"Line_{line_count}"
for relationship in relationships:
if relationship["Type"] != "VALUE":
continue

for related_value_block_id in relationship["Ids"]:
value_block = block_map[related_value_block_id]
value_text, value_confidence = self._get_text_and_confidence_from_relationship_blocks(
value_block, block_map, "CHILD"
)

# Ensure key uniqueness
while key in extracted_data:
key += f"_{line_count}"
if value_text != "":
value_texts.append(value_text)
value_confidences.append(value_confidence)

extracted_data[key] = {"value": line_text, "confidence": confidence}
line_count += 1
confidence = -1
if len(value_texts) > 0:
confidence = statistics.fmean(value_confidences)
extracted_data[key_text] = {"value": " ".join(value_texts), "confidence": confidence}

return extracted_data

def _get_text_and_confidence_from_relationship_blocks(
self, block: Any, blocks: dict[str, Any], wanted_relationship: str
) -> tuple[str, float]:
relationships = block.get("Relationships", [])

texts = []
confidences = []

for relationship in relationships:
if relationship["Type"] != wanted_relationship:
continue

related_blocks = [blocks[related_block_id] for related_block_id in relationship.get("Ids", [])]

relation_texts = []
relation_confidences = []

for related_block in related_blocks:
if "Text" not in related_block:
continue

relation_texts.append(related_block["Text"])
relation_confidences.append(related_block["Confidence"])

if len(relation_texts) > 0:
relation_text = " ".join(relation_texts)
texts.append(relation_text)

relation_confidence = statistics.fmean(relation_confidences)
confidences.append(relation_confidence)

confidence = -1
if len(confidences) > 0:
confidence = statistics.fmean(confidences)

return " ".join(texts), confidence
12 changes: 12 additions & 0 deletions backend/src/forms/1099.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from src.forms.form import Form


class TenNinetyNineNec(Form):
def identifier(self) -> str:
return "1099-NEC"

def form_matches(self) -> str:
return "1099-NEC"

def queries(self) -> list[str]:
return []
27 changes: 27 additions & 0 deletions backend/src/forms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import importlib
import inspect
import os
import pkgutil

from src.forms.form import Form


def find_form_implementations():
implementations = []

# Iterate over all modules in the package
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
# Import the module
module = importlib.import_module(f"{__name__}.{module_name}")

# Iterate over all classes in the module
for _, clazz in inspect.getmembers(module, inspect.isclass):
# Check if the class is a subclass of the base class and is not the base class itself
if issubclass(clazz, Form) and clazz is not Form:
instantiation = clazz()
implementations.append(instantiation)

return implementations


supported_forms = find_form_implementations()
Loading