-
Notifications
You must be signed in to change notification settings - Fork 0
Generalize custom transformer code by loading from DB #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: @jnnnthnn/Prototype_using_exec_to_run_transformer_code
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,8 @@ | |
| from lexy.db.session import get_session | ||
| from lexy.models.document import Document | ||
| from lexy.models.embedding import Embedding, EmbeddingCreate | ||
| from lexy.transformers.embeddings import custom_transformer, get_default_transformer | ||
| from lexy.transformers.embeddings import custom_transformer | ||
| from lexy.models.transformer import Transformer | ||
|
|
||
|
|
||
| router = APIRouter() | ||
|
|
@@ -44,8 +45,10 @@ async def add_embeddings(embeddings: list[EmbeddingCreate], session: AsyncSessio | |
| name="query_embeddings", | ||
| description="Query for similar documents") | ||
| async def query_embeddings(query_string: str, k: int = 5, session: AsyncSession = Depends(get_session)) -> dict: | ||
| transformer_result = await session.execute(select(Transformer).where(Transformer.transformer_id == "text.embeddings.minilm")) | ||
| transformer = transformer_result.scalar_one() | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will |
||
| doc = Document(content=query_string) | ||
| task = custom_transformer.apply_async(args=[doc, get_default_transformer()], priority=10) | ||
| task = custom_transformer.apply_async(args=[doc, transformer.code], priority=10) | ||
| result = task.get() | ||
| query_embedding = result.tolist() | ||
| search_result = await session.execute( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ | |
| from lexy.models.document import Document | ||
| from lexy.models.transformer import Transformer, TransformerCreate, TransformerUpdate | ||
| from lexy.transformers.counter import count_words | ||
| from lexy.transformers.embeddings import get_chunks, just_split, custom_transformer, get_default_transformer | ||
| from lexy.transformers.embeddings import get_chunks, just_split | ||
|
|
||
|
|
||
| router = APIRouter() | ||
|
|
@@ -87,32 +87,6 @@ async def delete_transformer(transformer_id: str, session: AsyncSession = Depend | |
| await session.commit() | ||
| return {"Say": "Transformer deleted!"} | ||
|
|
||
|
|
||
| @router.post("/embed/string", | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it's worth maintaining this? |
||
| response_model=dict, | ||
| status_code=status.HTTP_200_OK, | ||
| name="embed_string", | ||
| description="Get embeddings for query string") | ||
| async def embed_string(string: str) -> dict: | ||
| doc = Document(content=string) | ||
| task = custom_transformer.apply_async(args=[doc, get_default_transformer()], priority=10) | ||
| result = task.get() | ||
| return {"embedding": result.tolist()} | ||
|
|
||
|
|
||
| @router.post("/embed/documents", | ||
| response_model=dict, | ||
| status_code=status.HTTP_202_ACCEPTED, | ||
| name="embed_documents", | ||
| description="Create embeddings for a list of documents") | ||
| async def embed_documents(documents: List[Document], index_id: str = "default_text_embeddings") -> dict: | ||
| tasks = [] | ||
| for doc in documents: | ||
| task = custom_transformer.apply_async(args=[doc, get_default_transformer()], priority=10) | ||
| tasks.append({"task_id": task.id, "document_id": doc.document_id}) | ||
| return {"tasks": tasks} | ||
|
|
||
|
|
||
| @router.post("/count", | ||
| response_model=dict, | ||
| status_code=status.HTTP_200_OK, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,11 +13,10 @@ class TransformerBase(SQLModel): | |
| max_length=255, | ||
| regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$" | ||
| ) | ||
| path: Optional[str] = Field( | ||
| code: str = Field( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You will need to reset your DB state as a result of this (at least by deleting the bindings and transformers models)! |
||
| default=None, | ||
| min_length=1, | ||
| max_length=255, | ||
| regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$" | ||
| max_length=255*255, | ||
| ) | ||
| description: Optional[str] = None | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note I'm hard-coding
text.embeddings.minilmfor now. That is changed upstack.