14 changes: 14 additions & 0 deletions pkg/api/schemas/chat_stream.go
@@ -24,6 +24,20 @@ var (
UnknownError ErrorCode = "unknown_error"
)

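// StreamingCacheEntry tracks a cached streaming response: the originating query and the cache keys of its stored chunks.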
type StreamingCacheEntry struct {
Key string
Query string
ResponseChunks []string
Complete bool
}

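// StreamingCacheEntryChunk holds a single cached chunk of a streaming response, addressable by its own cache key.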
type StreamingCacheEntryChunk struct {
Key string
Index int
Content ChatStreamChunk
Complete bool
}

type StreamRequestID = string

// ChatStreamRequest defines a message that requests a new streaming chat
27 changes: 27 additions & 0 deletions pkg/cache/memory_cache.go
@@ -0,0 +1,27 @@
package cache

import "sync"

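// MemoryCache is a minimal, thread-safe in-memory key-value store used by the router to cache chat responses.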
type MemoryCache struct {
cache map[string]interface{}
lock sync.RWMutex
}

func NewMemoryCache() *MemoryCache {
return &MemoryCache{
cache: make(map[string]interface{}),
}
}

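// Get returns the value cached under key, if present.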
func (m *MemoryCache) Get(key string) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
val, found := m.cache[key]
return val, found
}

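// Set stores value under key, overwriting any previous entry.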
func (m *MemoryCache) Set(key string, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
m.cache[key] = value
}
98 changes: 85 additions & 13 deletions pkg/routers/router.go
@@ -3,7 +3,10 @@ package routers
import (
"context"
"errors"
"fmt"
"log"

"github.com/EinStack/glide/pkg/cache"
"github.com/EinStack/glide/pkg/routers/retry"
"go.uber.org/zap"

@@ -33,6 +36,7 @@ type LangRouter struct {
retry *retry.ExpRetry
tel *telemetry.Telemetry
logger *zap.Logger
cache *cache.MemoryCache
}

func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
@@ -56,6 +60,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter
chatStreamRouting: chatStreamRouting,
tel: tel,
logger: tel.L().With(zap.String("routerID", cfg.ID)),
cache: cache.NewMemoryCache(),
}

return router, err
@@ -70,6 +75,17 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
return nil, ErrNoModels
}

// Use the raw request message content as the cache key, so only exact repeat prompts hit the cache.
cacheKey := req.Message.Content
if cachedResponse, found := r.cache.Get(cacheKey); found {
if response, ok := cachedResponse.(*schemas.ChatResponse); ok {
log.Println("returning cached chat response for key: ", cacheKey)
return response, nil
}

log.Println("failed to cast cached value to *schemas.ChatResponse")
}

retryIterator := r.retry.Iterator()

for retryIterator.HasNext() {
@@ -101,17 +117,17 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
zap.String("provider", langModel.Provider()),
zap.Error(err),
)

continue
}

resp.RouterID = r.routerID

// Store response in cache
r.cache.Set(cacheKey, resp)

return resp, nil
}

// no providers were available to handle the request,
// so we have to wait a bit with a hope there is some available next time
r.logger.Warn("No healthy model found to serve chat request, wait and retry")

err := retryIterator.WaitNext(ctx)
@@ -121,7 +137,6 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
}
}

// if we reach this part, then we are in trouble
r.logger.Error("No model was available to handle chat request")

return nil, ErrNoModelAvailable
@@ -141,10 +156,43 @@ func (r *LangRouter) ChatStream(
req.Metadata,
&schemas.ErrorReason,
)

return
}

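// Streaming responses are cached per query: if this exact message was seen before, replay its cached chunks.
// If the cached entry is incomplete, the request falls through and is streamed (and cached) again.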
cacheKey := req.Message.Content
if streamingCacheEntry, found := r.cache.Get(cacheKey); found {
if entry, ok := streamingCacheEntry.(*schemas.StreamingCacheEntry); ok {
for _, chunkKey := range entry.ResponseChunks {
if cachedChunk, found := r.cache.Get(chunkKey); found {
// Chunks are stored as *StreamingCacheEntryChunk, so unwrap the embedded ChatStreamChunk before replaying it.
if chunk, ok := cachedChunk.(*schemas.StreamingCacheEntryChunk); ok {
respC <- schemas.NewChatStreamChunk(
req.ID,
r.routerID,
req.Metadata,
&chunk.Content,
)
} else {
log.Println("Failed to cast cached chunk to StreamingCacheEntryChunk")
}
}
}

if entry.Complete {
return
}
} else {
log.Println("Failed to cast cached entry to StreamingCacheEntry")
}
} else {
streamingCacheEntry := &schemas.StreamingCacheEntry{
Key: cacheKey,
Query: req.Message.Content,
ResponseChunks: []string{},
Complete: false,
}
r.cache.Set(cacheKey, streamingCacheEntry)
}

retryIterator := r.retry.Iterator()

for retryIterator.HasNext() {
@@ -172,6 +220,7 @@ func (r *LangRouter) ChatStream(
continue
}

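// Buffer streamed chunks so they can be written to the cache in batches rather than one Set per chunk.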
buffer := []schemas.ChatStreamChunk{}
for chunkResult := range modelRespC {
err = chunkResult.Error()
if err != nil {
@@ -182,9 +231,6 @@
zap.Error(err),
)

// It's challenging to hide an error in case of streaming chat as consumer apps
// may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does),
// so we cannot easily restart that process from scratch
respC <- schemas.NewChatStreamError(
req.ID,
r.routerID,
@@ -198,25 +244,52 @@
}

chunk := chunkResult.Chunk()

buffer = append(buffer, *chunk)
respC <- schemas.NewChatStreamChunk(
req.ID,
r.routerID,
req.Metadata,
chunk,
)

if len(buffer) >= 1048 { // placeholder flush threshold; should become a named bufferSize constant
// Persist each buffered chunk under its own cache key and record those keys on the shared cache entry.
if cached, found := r.cache.Get(cacheKey); found {
if entry, ok := cached.(*schemas.StreamingCacheEntry); ok {
for _, bufferedChunk := range buffer {
chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(entry.ResponseChunks))
r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
Key: chunkKey,
Index: len(entry.ResponseChunks),
Content: bufferedChunk,
})
entry.ResponseChunks = append(entry.ResponseChunks, chunkKey)
}
r.cache.Set(cacheKey, entry)
}
}
buffer = buffer[:0] // Reset buffer
}
}

// Flush any remaining buffered chunks, then mark the cached entry as complete so future
// requests for the same query can be served entirely from the cache.
if cached, found := r.cache.Get(cacheKey); found {
if entry, ok := cached.(*schemas.StreamingCacheEntry); ok {
for _, bufferedChunk := range buffer {
chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, len(entry.ResponseChunks))
r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
Key: chunkKey,
Index: len(entry.ResponseChunks),
Content: bufferedChunk,
})
entry.ResponseChunks = append(entry.ResponseChunks, chunkKey)
}
entry.Complete = true
r.cache.Set(cacheKey, entry)
}
}

return
}

// no providers were available to handle the request,
// so we have to wait a bit with a hope there is some available next time
r.logger.Warn("No healthy model found to serve streaming chat request, wait and retry")

err := retryIterator.WaitNext(ctx)
if err != nil {
// something has cancelled the context
respC <- schemas.NewChatStreamError(
req.ID,
r.routerID,
@@ -230,7 +303,6 @@
}
}

// if we reach this part, then we are in trouble
r.logger.Error(
"No model was available to handle streaming chat request. " +
"Try to configure more fallback models to avoid this",