Commit
initial asset collection implementation from selim call
samj committed Oct 3, 2024
1 parent 56bae51 commit fb6e852
Showing 8 changed files with 222 additions and 18 deletions.
161 changes: 148 additions & 13 deletions backend/managers/AssetsManager.py
@@ -1,10 +1,14 @@
from uuid import uuid4
from threading import Lock
from sqlalchemy import select, insert, update, delete, func, or_
from backend.models import Asset
from backend.models import Asset, AssetCollection
from backend.db import db_session_context
from backend.schemas import AssetSchema, AssetCreateSchema
from backend.schemas import AssetSchema, AssetCreateSchema, AssetCollectionSchema, AssetCollectionCreateSchema
from typing import List, Tuple, Optional, Dict, Any
from backend.processors.BaseProcessor import BaseProcessor
from backend.processors.SimpleTextSplitter import SimpleTextSplitter
from backend.processors.SimpleEmbedder import SimpleEmbedder
from backend.processors.SimpleVectorStore import SimpleVectorStore

class AssetsManager:
_instance = None
@@ -21,16 +25,144 @@ def __init__(self):
if not hasattr(self, '_initialized'):
with self._lock:
if not hasattr(self, '_initialized'):
# db.init_db()
self._initialized = True
self._load_processors()

def _load_processors(self):
self.text_splitters = [SimpleTextSplitter()]
self.embedders = [SimpleEmbedder()]
self.vector_stores = [SimpleVectorStore()]

# Load additional processors if available
try:
from backend.processors.LangchainTextSplitter import LangchainTextSplitter
self.text_splitters.append(LangchainTextSplitter())
except ImportError:
pass

try:
from backend.processors.OpenAIEmbedder import OpenAIEmbedder
self.embedders.append(OpenAIEmbedder())
except ImportError:
pass

try:
from backend.processors.ChromaVectorStore import ChromaVectorStore
self.vector_stores.append(ChromaVectorStore())
except ImportError:
pass

async def create_asset_collection(self, collection_data: AssetCollectionCreateSchema) -> AssetCollectionSchema:
async with db_session_context() as session:
new_collection = AssetCollection(id=str(uuid4()), **collection_data.model_dump())
session.add(new_collection)
await session.commit()
await session.refresh(new_collection)
return AssetCollectionSchema.from_orm(new_collection)

async def process_asset_collection(self, collection_id: str, text_splitter: str, embedder: str, vector_store: str):
collection = await self.retrieve_asset_collection(collection_id)
if not collection:
raise ValueError(f"Asset collection with id {collection_id} not found")

text_splitter = next((ts for ts in self.text_splitters if ts.__class__.__name__ == text_splitter), None)
embedder = next((emb for emb in self.embedders if emb.__class__.__name__ == embedder), None)
vector_store = next((vs for vs in self.vector_stores if vs.__class__.__name__ == vector_store), None)

if not all([text_splitter, embedder, vector_store]):
raise ValueError("Invalid processor selection")

if collection.track_individual_assets:
assets = await self.retrieve_assets(filters={"collection_id": collection_id})
for asset in assets:
await self._process_asset(asset, text_splitter, embedder, vector_store)
else:
await self._process_large_collection(collection, text_splitter, embedder, vector_store)

async def _process_asset(self, asset: Asset, text_splitter: BaseProcessor, embedder: BaseProcessor, vector_store: BaseProcessor):
chunks = await text_splitter.process(asset.content)
embeddings = await embedder.process(chunks)
await vector_store.process(asset.id, chunks, embeddings)

async def _process_large_collection(self, collection: AssetCollection, text_splitter: BaseProcessor, embedder: BaseProcessor, vector_store: BaseProcessor):
# Process the collection in batches rather than asset by asset.
# In practice this would stream data from an external source, such as an email server;
# for demonstration we use a dummy generator.
for batch in self._large_collection_batch_generator(collection):
chunks = await text_splitter.process(batch)
embeddings = await embedder.process(chunks)
await vector_store.process(f"{collection.id}_{uuid4()}", chunks, embeddings)

def _large_collection_batch_generator(self, collection: AssetCollection):
# This is a dummy generator. In a real scenario, this would fetch data from the actual source
for i in range(10): # Simulate 10 batches
yield f"This is batch {i} of collection {collection.name}"

async def retrieve_asset_collection(self, id: str) -> Optional[AssetCollectionSchema]:
async with db_session_context() as session:
result = await session.execute(select(AssetCollection).filter(AssetCollection.id == id))
collection = result.scalar_one_or_none()
if collection:
return AssetCollectionSchema(
id=collection.id,
name=collection.name,
description=collection.description,
track_individual_assets=collection.track_individual_assets
)
return None

async def retrieve_asset_collections(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None,
query: Optional[str] = None) -> Tuple[List[AssetCollectionSchema], int]:
async with db_session_context() as session:
stmt = select(AssetCollection)

if filters:
for key, value in filters.items():
if isinstance(value, list):
stmt = stmt.filter(getattr(AssetCollection, key).in_(value))
else:
stmt = stmt.filter(getattr(AssetCollection, key) == value)

if query:
search_condition = or_(
AssetCollection.name.ilike(f"%{query}%"),
AssetCollection.description.ilike(f"%{query}%")
)
stmt = stmt.filter(search_condition)

if sort_by and hasattr(AssetCollection, sort_by):
order_column = getattr(AssetCollection, sort_by)
stmt = stmt.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column)

stmt = stmt.offset(offset).limit(limit)

result = await session.execute(stmt)
collections = [AssetCollectionSchema(
id=collection.id,
name=collection.name,
description=collection.description,
track_individual_assets=collection.track_individual_assets
) for collection in result.scalars().all()]

# Get total count
count_stmt = select(func.count()).select_from(AssetCollection)
if filters or query:
count_stmt = count_stmt.filter(stmt.whereclause)
total_count = await session.execute(count_stmt)
total_count = total_count.scalar()

return collections, total_count

async def create_asset(self, asset_data: AssetCreateSchema) -> AssetSchema:
collection = await self.retrieve_asset_collection(asset_data.collection_id)
if not collection:
raise ValueError(f"Asset collection with id {asset_data.collection_id} not found")
if not collection.track_individual_assets:
raise ValueError("This collection does not track individual assets")

async with db_session_context() as session:
new_asset = Asset(id=str(uuid4()), **asset_data.model_dump())
session.add(new_asset)
await session.commit()
await session.refresh(new_asset)
return AssetSchema(id=new_asset.id, **asset_data.model_dump())
return AssetSchema.from_orm(new_asset)

async def update_asset(self, id: str, asset_data: AssetCreateSchema) -> Optional[AssetSchema]:
async with db_session_context() as session:
@@ -90,23 +222,26 @@ async def retrieve_assets(self, offset: int = 0, limit: int = 100, sort_by: Opti
order_column = getattr(Asset, sort_by)
stmt = stmt.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column)

# Add a join to check if the collection tracks individual assets
stmt = stmt.join(AssetCollection).filter(AssetCollection.track_individual_assets == True)

stmt = stmt.offset(offset).limit(limit)

result = await session.execute(stmt)
assets = [AssetSchema(
id=asset.id,
title=asset.title,
user_id=asset.user_id,
creator=asset.creator,
subject=asset.subject,
description=asset.description
) for asset in result.scalars().all()]
assets = [AssetSchema.from_orm(asset) for asset in result.scalars().all()]

# Get total count
count_stmt = select(func.count()).select_from(Asset)
count_stmt = select(func.count()).select_from(Asset).join(AssetCollection).filter(AssetCollection.track_individual_assets == True)
if filters or query:
count_stmt = count_stmt.filter(stmt.whereclause)
total_count = await session.execute(count_stmt)
total_count = total_count.scalar()

return assets, total_count

def get_available_processors(self):
return {
"text_splitters": [ts.__class__.__name__ for ts in self.text_splitters if ts.is_available()],
"embedders": [emb.__class__.__name__ for emb in self.embedders if emb.is_available()],
"vector_stores": [vs.__class__.__name__ for vs in self.vector_stores if vs.is_available()]
}
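
Taken together, the new manager methods form a small ingestion pipeline: create a collection, add assets, then split, embed, and store them with whichever processors are installed. A minimal sketch of a driver script, assuming the manager and schemas from this commit are importable, the database behind db_session_context is already initialized, and the collection and asset field values are purely illustrative:

import asyncio

from backend.managers.AssetsManager import AssetsManager
from backend.schemas import AssetCollectionCreateSchema, AssetCreateSchema

async def main():
    manager = AssetsManager()

    # Discover which processors are importable in this environment.
    available = manager.get_available_processors()

    collection = await manager.create_asset_collection(
        AssetCollectionCreateSchema(name="Research notes")
    )
    await manager.create_asset(
        AssetCreateSchema(title="Note 1", collection_id=collection.id)
    )

    # Chunk, embed, and persist every asset tracked by the collection.
    await manager.process_asset_collection(
        collection.id,
        text_splitter=available["text_splitters"][0],
        embedder=available["embedders"][0],
        vector_store=available["vector_stores"][0],
    )

asyncio.run(main())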
19 changes: 15 additions & 4 deletions backend/models.py
@@ -43,13 +43,22 @@ class Session(SQLModelBase, table=True):
expires_at: datetime = Field()
user: User = Relationship(back_populates="sessions")

class AssetCollection(SQLModelBase, table=True):
id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
name: str = Field()
description: Optional[str] = Field(default=None)
track_individual_assets: bool = Field(default=True)
assets: List["Asset"] = Relationship(back_populates="collection")

class Asset(SQLModelBase, table=True):
id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
user_id: str | None = Field(default=None, foreign_key="user.id")
collection_id: str = Field(foreign_key="assetcollection.id")
title: str = Field()
creator: str | None = Field(default=None)
subject: str | None = Field(default=None)
description: str | None = Field(default=None)
content: str = Field(default="")
creator: Optional[str] = Field(default=None)
subject: Optional[str] = Field(default=None)
description: Optional[str] = Field(default=None)
collection: AssetCollection = Relationship(back_populates="assets")

class Persona(SQLModelBase, table=True):
id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
@@ -69,3 +78,5 @@ class Share(SQLModelBase, table=True):
User.model_rebuild()
Cred.model_rebuild()
Session.model_rebuild()
Asset.model_rebuild()
AssetCollection.model_rebuild()
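
The new one-to-many link can be exercised directly with SQLModel. A hedged sketch, assuming SQLModelBase subclasses SQLModel, and using a throwaway synchronous in-memory engine rather than the app's async session from backend.db:

from sqlmodel import Session, SQLModel, create_engine

from backend.models import Asset, AssetCollection

engine = create_engine("sqlite://")  # throwaway in-memory database, not the app's
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    collection = AssetCollection(name="Inbox")
    asset = Asset(title="Hello")
    asset.collection = collection  # back_populates keeps both sides in sync
    session.add(asset)             # cascades to the collection as well
    session.commit()

    assert asset.collection_id == collection.id
    assert collection.assets[0].title == "Hello"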
13 changes: 13 additions & 0 deletions backend/processors/BaseProcessor.py
@@ -0,0 +1,13 @@
from typing import Any
from abc import ABC, abstractmethod

class BaseProcessor(ABC):
@abstractmethod
def is_available(self) -> bool:
"""Check if the processor is available (dependencies installed)."""
pass

@abstractmethod
async def process(self, content: str) -> Any:
"""Process the content."""
pass
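
BaseProcessor keeps the contract deliberately small: report availability, then process. (Note that the concrete processors in this commit vary the process signature; the abstract one documents the simplest case.) New processors slot into the optional-import pattern _load_processors already uses. A sketch of one, where the nltk dependency and sentence-level splitting are illustrative assumptions, not part of this commit:

from typing import List

from backend.processors.BaseProcessor import BaseProcessor

class SentenceTextSplitter(BaseProcessor):
    """Illustrative splitter that degrades gracefully when nltk is absent."""

    def is_available(self) -> bool:
        try:
            import nltk  # noqa: F401  (availability probe only)
            return True
        except ImportError:
            return False

    async def process(self, content: str) -> List[str]:
        # Requires the nltk 'punkt' tokenizer data to be downloaded.
        from nltk.tokenize import sent_tokenize
        return sent_tokenize(content)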
10 changes: 10 additions & 0 deletions backend/processors/SimpleEmbedder.py
@@ -0,0 +1,10 @@
import random
from typing import List
from backend.processors.BaseProcessor import BaseProcessor

class SimpleEmbedder(BaseProcessor):
def is_available(self) -> bool:
return True

async def process(self, chunks: List[str]) -> List[List[float]]:
return [[random.random() for _ in range(100)] for _ in chunks]
10 changes: 10 additions & 0 deletions backend/processors/SimpleTextSplitter.py
@@ -0,0 +1,10 @@
from typing import List
from backend.processors.BaseProcessor import BaseProcessor

class SimpleTextSplitter(BaseProcessor):
def is_available(self) -> bool:
return True

async def process(self, content: str) -> List[str]:
return [content[i:i+1000] for i in range(0, len(content), 1000)]

12 changes: 12 additions & 0 deletions backend/processors/SimpleVectorStore.py
@@ -0,0 +1,12 @@
from typing import List
from backend.processors.BaseProcessor import BaseProcessor

class SimpleVectorStore(BaseProcessor):
def __init__(self):
self.store = {}

def is_available(self) -> bool:
return True

async def process(self, asset_id: str, chunks: List[str], embeddings: List[List[float]]):
self.store[asset_id] = list(zip(chunks, embeddings))
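
SimpleVectorStore only writes; lookups are left for a later commit. A minimal cosine-similarity query over the same store layout could look like this (the query helper is hypothetical, not part of this commit):

import math
from typing import Dict, List, Tuple

def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def query(store: Dict[str, list], embedding: List[float], top_k: int = 3) -> List[Tuple[str, float]]:
    # Rank every stored (chunk, embedding) pair against the query embedding.
    scored = [
        (chunk, cosine(embedding, chunk_embedding))
        for pairs in store.values()
        for chunk, chunk_embedding in pairs
    ]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:top_k]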
Empty file added backend/processors/__init__.py
15 changes: 14 additions & 1 deletion backend/schemas.py
@@ -1,6 +1,6 @@
from datetime import datetime
from pydantic import BaseModel, field_serializer
from typing import Optional
from typing import Optional, List


# We have *Create schemas because API clients ideally don't set the id field; it's set by the server
@@ -55,6 +55,7 @@ class AssetBaseSchema(BaseModel):
creator: Optional[str] = None
subject: Optional[str] = None
description: Optional[str] = None
collection_id: str

class AssetCreateSchema(AssetBaseSchema):
pass
@@ -104,3 +105,15 @@ class VerifyAuthentication(BaseModel):
email: str
auth_resp: dict
challenge: str

# Asset Collection schemas
class AssetCollectionBaseSchema(BaseModel):
name: str
description: Optional[str] = None
track_individual_assets: bool = True

class AssetCollectionCreateSchema(AssetCollectionBaseSchema):
pass

class AssetCollectionSchema(AssetCollectionBaseSchema):
id: str
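
A quick round-trip of the new collection schemas, illustrating why the *Create variant omits id (the values here are made up):

from backend.schemas import AssetCollectionCreateSchema, AssetCollectionSchema

create = AssetCollectionCreateSchema(name="Mail archive", track_individual_assets=False)
stored = AssetCollectionSchema(id="server-generated-id", **create.model_dump())
assert stored.track_individual_assets is False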
