Skip to content

Commit ad4c897

Browse files
GaosCode and vincentkoc authored
feat(embed): add embedding job drain
* feat(embed): add embedding job drain * fix(embed): migrate legacy jobs and requeue rate limits safely --------- Co-authored-by: Vincent Koc <vincentkoc@ieee.org>
1 parent 2f07416 commit ad4c897

13 files changed

Lines changed: 1449 additions & 3 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ All notable changes to `discrawl` will be documented in this file.
88
- `messages` and `mentions` now use composite read-path indexes so larger archives spend less time sorting/filtering common guild, channel, and author queries
99
- normalized message text is now sanitized before it reaches SQLite and FTS5, repairing malformed UTF-8 and stripping invisible/control-character noise that can poison search content
1010
- local embedding providers now support OpenAI-compatible endpoints, Ollama, and llama.cpp, and `doctor` can probe the configured provider before you queue vectors
11+
- `embed` now drains the queued embedding backlog in bounded batches, requeues safely on provider throttling, and drops stale stored vectors when messages no longer have embeddable content
1112

1213
## 0.3.0 - 2026-04-21
1314

internal/cli/admin_commands.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,61 @@ func (r *runtime) runStatus(args []string) error {
168168
return r.print(status)
169169
}
170170

171+
func (r *runtime) runEmbed(args []string) error {
172+
fs := flag.NewFlagSet("embed", flag.ContinueOnError)
173+
fs.SetOutput(io.Discard)
174+
limit := fs.Int("limit", store.DefaultEmbedLimit(), "")
175+
batchSize := fs.Int("batch-size", r.cfg.Search.Embeddings.BatchSize, "")
176+
rebuild := fs.Bool("rebuild", false, "")
177+
if err := fs.Parse(args); err != nil {
178+
return usageErr(err)
179+
}
180+
if fs.NArg() != 0 {
181+
return usageErr(fmt.Errorf("embed takes no positional arguments"))
182+
}
183+
if *limit <= 0 {
184+
return usageErr(fmt.Errorf("--limit must be positive"))
185+
}
186+
if *batchSize <= 0 {
187+
return usageErr(fmt.Errorf("--batch-size must be positive"))
188+
}
189+
if !r.cfg.Search.Embeddings.Enabled {
190+
return usageErr(fmt.Errorf("embeddings are disabled in config"))
191+
}
192+
providerFactory := r.newEmbed
193+
if providerFactory == nil {
194+
providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
195+
return embed.NewProvider(cfg)
196+
}
197+
}
198+
provider, err := providerFactory(r.cfg.Search.Embeddings)
199+
if err != nil {
200+
return configErr(err)
201+
}
202+
opts := store.EmbeddingDrainOptions{
203+
Provider: r.cfg.Search.Embeddings.Provider,
204+
Model: r.cfg.Search.Embeddings.Model,
205+
InputVersion: store.EmbeddingInputVersion,
206+
Limit: *limit,
207+
BatchSize: *batchSize,
208+
MaxInputChars: r.cfg.Search.Embeddings.MaxInputChars,
209+
Now: r.now,
210+
}
211+
requeued := 0
212+
if *rebuild {
213+
requeued, err = r.store.RequeueAllEmbeddingJobs(r.ctx, opts)
214+
if err != nil {
215+
return err
216+
}
217+
}
218+
stats, err := r.store.DrainEmbeddingJobs(r.ctx, provider, opts)
219+
if err != nil {
220+
return err
221+
}
222+
stats.Requeued = requeued
223+
return r.print(stats)
224+
}
225+
171226
func (r *runtime) runDoctor(args []string) error {
172227
if len(args) != 0 {
173228
return usageErr(fmt.Errorf("doctor takes no arguments"))

internal/cli/cli.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/bwmarrin/discordgo"
1414
"github.com/steipete/discrawl/internal/config"
1515
"github.com/steipete/discrawl/internal/discord"
16+
"github.com/steipete/discrawl/internal/embed"
1617
"github.com/steipete/discrawl/internal/share"
1718
"github.com/steipete/discrawl/internal/store"
1819
"github.com/steipete/discrawl/internal/syncer"
@@ -96,6 +97,7 @@ type runtime struct {
9697
openStore func(context.Context, string) (*store.Store, error)
9798
newDiscord func(config.Config) (discordClient, error)
9899
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
100+
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
99101
now func() time.Time
100102
}
101103

@@ -130,6 +132,8 @@ func (r *runtime) dispatch(rest []string) error {
130132
return r.withServicesAuto(hasBoolFlag(rest[1:], "--sync"), true, func() error { return r.runMessages(rest[1:]) })
131133
case "mentions":
132134
return r.withServices(false, func() error { return r.runMentions(rest[1:]) })
135+
case "embed":
136+
return r.withServices(false, func() error { return r.runEmbed(rest[1:]) })
133137
case "sql":
134138
return r.withServices(false, func() error { return r.runSQL(rest[1:]) })
135139
case "members":

internal/cli/cli_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cli
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"log/slog"
78
"net/http"
89
"net/http/httptest"
@@ -489,6 +490,70 @@ func runGit(t *testing.T, dir string, args ...string) {
489490
require.NoError(t, err, string(out))
490491
}
491492

493+
func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
494+
ctx := context.Background()
495+
dir := t.TempDir()
496+
cfgPath := filepath.Join(dir, "config.toml")
497+
dbPath := filepath.Join(dir, "discrawl.db")
498+
499+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
500+
require.Equal(t, "/embeddings", r.URL.Path)
501+
var req struct {
502+
Input []string `json:"input"`
503+
}
504+
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
505+
require.Len(t, req.Input, 1)
506+
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1,2]}]}`))
507+
}))
508+
defer server.Close()
509+
510+
cfg := config.Default()
511+
cfg.DBPath = dbPath
512+
cfg.Search.Embeddings.Enabled = true
513+
cfg.Search.Embeddings.Provider = "openai_compatible"
514+
cfg.Search.Embeddings.Model = "local-model"
515+
cfg.Search.Embeddings.BaseURL = server.URL
516+
cfg.Search.Embeddings.APIKeyEnv = ""
517+
require.NoError(t, config.Write(cfgPath, cfg))
518+
519+
s, err := store.Open(ctx, dbPath)
520+
require.NoError(t, err)
521+
for _, id := range []string{"m1", "m2"} {
522+
require.NoError(t, s.UpsertMessageWithOptions(ctx, store.MessageRecord{
523+
ID: id,
524+
GuildID: "g1",
525+
ChannelID: "c1",
526+
MessageType: 0,
527+
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
528+
Content: "hello",
529+
NormalizedContent: "hello",
530+
RawJSON: `{}`,
531+
}, store.WriteOptions{EnqueueEmbedding: true}))
532+
}
533+
require.NoError(t, s.Close())
534+
535+
var out bytes.Buffer
536+
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--limit", "1"}, &out, &bytes.Buffer{}))
537+
require.Contains(t, out.String(), "processed=1")
538+
require.Contains(t, out.String(), "succeeded=1")
539+
require.Contains(t, out.String(), "remaining_backlog=1")
540+
require.Contains(t, out.String(), "provider=openai_compatible")
541+
542+
s, err = store.Open(ctx, dbPath)
543+
require.NoError(t, err)
544+
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
545+
require.NoError(t, err)
546+
require.Equal(t, "1", rows[0][0])
547+
require.NoError(t, s.Close())
548+
549+
out.Reset()
550+
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--rebuild", "--limit", "1"}, &out, &bytes.Buffer{}))
551+
require.Contains(t, out.String(), "processed=1")
552+
require.Contains(t, out.String(), "succeeded=1")
553+
require.Contains(t, out.String(), "remaining_backlog=1")
554+
require.Contains(t, out.String(), "requeued=2")
555+
}
556+
492557
type fakeDiscordClient struct {
493558
guilds []*discordgo.UserGuild
494559
self *discordgo.User

internal/cli/output.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Commands:
8080
search
8181
messages
8282
mentions
83+
embed
8384
sql
8485
members
8586
channels
@@ -108,6 +109,21 @@ func printHuman(w io.Writer, value any) error {
108109
v.DBPath, v.GuildCount, v.ChannelCount, v.ThreadCount, v.MessageCount, v.MemberCount, v.EmbeddingBacklog,
109110
formatTime(v.LastSyncAt), formatTime(v.LastTailEventAt))
110111
return err
112+
case store.EmbeddingDrainStats:
113+
_, err := fmt.Fprintf(w, "processed=%d\nsucceeded=%d\nfailed=%d\nskipped=%d\nremaining_backlog=%d\nprovider=%s\nmodel=%s\ninput_version=%s\n",
114+
v.Processed, v.Succeeded, v.Failed, v.Skipped, v.RemainingBacklog, v.Provider, v.Model, v.InputVersion)
115+
if err != nil {
116+
return err
117+
}
118+
if v.Requeued > 0 {
119+
if _, err := fmt.Fprintf(w, "requeued=%d\n", v.Requeued); err != nil {
120+
return err
121+
}
122+
}
123+
if v.RateLimited {
124+
_, err = fmt.Fprintln(w, "rate_limited=true")
125+
}
126+
return err
111127
case []store.SearchResult:
112128
for _, row := range v {
113129
if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s\n\n", row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), row.Content); err != nil {

internal/embed/ollama.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string,
8282
defer func() { _ = resp.Body.Close() }()
8383
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
8484
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
85-
return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, string(msg))
85+
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
8686
}
8787
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
8888
return fmt.Errorf("decode embedding response: %w", err)

internal/embed/provider.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ type EmbeddingBatch struct {
4141
Vectors [][]float32
4242
}
4343

44+
// HTTPError describes an embedding request that failed with an HTTP error
// status. It keeps the status code so callers can branch on it (see
// IsRateLimitError) instead of parsing the message text.
type HTTPError struct {
	// StatusCode is the HTTP status code of the failed response.
	StatusCode int
	// Body is the response body text, included verbatim in the error message.
	Body string
}

// Error implements the error interface, reporting the status code followed by
// the response body.
func (e *HTTPError) Error() string {
	const layout = "embedding request failed with HTTP %d: %s"
	return fmt.Sprintf(layout, e.StatusCode, e.Body)
}

// IsRateLimitError reports whether err is, or wraps, an *HTTPError whose status
// code is 429 Too Many Requests.
func IsRateLimitError(err error) bool {
	var target *HTTPError
	if !errors.As(err, &target) {
		return false
	}
	return target.StatusCode == http.StatusTooManyRequests
}
57+
4458
type CheckResult struct {
4559
Provider string
4660
Model string

internal/embed/provider_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,27 @@ func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
172172
require.False(t, result.Probed)
173173
}
174174

175+
func TestProviderExposesRateLimitErrors(t *testing.T) {
176+
t.Parallel()
177+
178+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
179+
http.Error(w, "rate limited", http.StatusTooManyRequests)
180+
}))
181+
defer server.Close()
182+
183+
provider, err := NewProvider(config.EmbeddingsConfig{
184+
Provider: ProviderOpenAICompatible,
185+
Model: "local-model",
186+
BaseURL: server.URL,
187+
RequestTimeout: "5s",
188+
})
189+
require.NoError(t, err)
190+
191+
_, err = provider.Embed(context.Background(), []string{"one"})
192+
require.ErrorContains(t, err, "HTTP 429")
193+
require.True(t, IsRateLimitError(err))
194+
}
195+
175196
func TestProviderRejectsInvalidResponses(t *testing.T) {
176197
t.Parallel()
177198

0 commit comments

Comments
 (0)