Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lexy/api/endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from lexy.db.session import get_session
from lexy.models.document import Document
from lexy.models.embedding import Embedding, EmbeddingCreate
from lexy.transformers.embeddings import custom_transformer, get_default_transformer
from lexy.transformers.embeddings import custom_transformer
from lexy.models.transformer import Transformer


router = APIRouter()
Expand Down Expand Up @@ -44,8 +45,10 @@ async def add_embeddings(embeddings: list[EmbeddingCreate], session: AsyncSessio
name="query_embeddings",
description="Query for similar documents")
async def query_embeddings(query_string: str, k: int = 5, session: AsyncSession = Depends(get_session)) -> dict:
transformer_result = await session.execute(select(Transformer).where(Transformer.transformer_id == "text.embeddings.minilm"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note I'm hard-coding text.embeddings.minilm for now. That is changed upstack.

transformer = transformer_result.scalar_one()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will raise if it fails. We can catch and rethrow a custom error to improve DX.

doc = Document(content=query_string)
task = custom_transformer.apply_async(args=[doc, get_default_transformer()], priority=10)
task = custom_transformer.apply_async(args=[doc, transformer.code], priority=10)
result = task.get()
query_embedding = result.tolist()
search_result = await session.execute(
Expand Down
7 changes: 5 additions & 2 deletions lexy/api/endpoints/index_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from lexy.db.session import get_session
from lexy.models.document import Document
from lexy.models.index import Index
from lexy.transformers.embeddings import custom_transformer, get_default_transformer
from lexy.transformers.embeddings import custom_transformer
from lexy.models.transformer import Transformer


router = APIRouter()
Expand Down Expand Up @@ -43,7 +44,9 @@ async def query_records(query_string: str, k: int = 5, query_field: str = "embed

# get embedding for query string
doc = Document(content=query_string)
task = custom_transformer.apply_async(args=[doc, get_default_transformer()], priority=10)
transformer_result = await session.execute(select(Transformer).where(Transformer.transformer_id == "text.embeddings.minilm"))
transformer = transformer_result.scalar_one()
task = custom_transformer.apply_async(args=[doc, transformer.code], priority=10)
result = task.get()
query_embedding = result.tolist()

Expand Down
28 changes: 1 addition & 27 deletions lexy/api/endpoints/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lexy.models.document import Document
from lexy.models.transformer import Transformer, TransformerCreate, TransformerUpdate
from lexy.transformers.counter import count_words
from lexy.transformers.embeddings import get_chunks, just_split, custom_transformer, get_default_transformer
from lexy.transformers.embeddings import get_chunks, just_split


router = APIRouter()
Expand Down Expand Up @@ -87,32 +87,6 @@ async def delete_transformer(transformer_id: str, session: AsyncSession = Depend
await session.commit()
return {"Say": "Transformer deleted!"}


@router.post("/embed/string",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's worth maintaining this?

response_model=dict,
status_code=status.HTTP_200_OK,
name="embed_string",
description="Get embeddings for query string")
async def embed_string(string: str) -> dict:
    """Embed a single string with the default transformer task.

    Wraps the string in a Document, dispatches the Celery task at high
    priority, blocks on the result, and returns the embedding as a list.
    """
    document = Document(content=string)
    async_result = custom_transformer.apply_async(
        args=[document, get_default_transformer()], priority=10
    )
    # .get() blocks until the worker finishes; result is an ndarray-like.
    return {"embedding": async_result.get().tolist()}


@router.post("/embed/documents",
             response_model=dict,
             status_code=status.HTTP_202_ACCEPTED,
             name="embed_documents",
             description="Create embeddings for a list of documents")
async def embed_documents(documents: List[Document], index_id: str = "default_text_embeddings") -> dict:
    """Queue one embedding task per document and return task/document id pairs.

    Tasks are submitted asynchronously (202 Accepted); callers poll the
    returned task ids for completion.
    """
    # NOTE(review): index_id is accepted but not used in this body — presumably
    # consumed downstream; confirm before removing.
    queued = []
    for document in documents:
        async_result = custom_transformer.apply_async(
            args=[document, get_default_transformer()], priority=10
        )
        queued.append({"task_id": async_result.id, "document_id": document.document_id})
    return {"tasks": queued}


@router.post("/count",
response_model=dict,
status_code=status.HTTP_200_OK,
Expand Down
10 changes: 8 additions & 2 deletions lexy/db/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,14 @@
},
"transformer_1": {
"transformer_id": "text.embeddings.minilm",
"path": "lexy.transformers.embeddings.text_embeddings",
"description": "Text embeddings using Hugging Face model 'sentence-transformers/all-MiniLM-L6-v2'"
"description": "Text embeddings using Hugging Face model 'sentence-transformers/all-MiniLM-L6-v2'",
"code": """import torch
from sentence_transformers import SentenceTransformer
torch.set_num_threads(1)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def transform(document):
return model.encode([document.content], batch_size=len([document.content]))"""
},
"index_1": {
"index_id": "default_text_embeddings",
Expand Down
5 changes: 2 additions & 3 deletions lexy/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ class TransformerBase(SQLModel):
max_length=255,
regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$"
)
path: Optional[str] = Field(
code: str = Field(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will need to reset your DB state as a result of this (at least by deleting the bindings and transformers models)!

default=None,
min_length=1,
max_length=255,
regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$"
max_length=255*255,
)
description: Optional[str] = None

Expand Down
11 changes: 0 additions & 11 deletions lexy/transformers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@
torch.set_num_threads(1)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def get_default_transformer():
    """Return the source code of the default embedding transformer.

    The returned string is Python source defining a `transform(document)`
    function backed by the MiniLM sentence-transformers model; it is shipped
    to workers as the transformer code payload.
    """
    default_code = """
import torch
from sentence_transformers import SentenceTransformer
torch.set_num_threads(1)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

def transform(document):
    return model.encode([document.content], batch_size=len([document.content]))
"""
    return default_code

@shared_task(name="custom_transformer")
def custom_transformer(document: Document, transformer: str) -> list[dict]:
""" Apply a custom transformer to a document.
Expand Down