Skip to content

Commit 85ed9ae

Browse files
committed
centralized request middleware
Signed-off-by: Dave Lee <[email protected]>
1 parent 835932e commit 85ed9ae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+834
-714
lines changed

core/backend/llm.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ type TokenUsage struct {
3232
Completion int
3333
}
3434

35-
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
35+
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
3636
modelFile := c.Model
3737

3838
var inferenceModel grpc.Backend
3939
var err error
4040

41-
opts := ModelOptions(c, o, []model.Option{})
41+
opts := ModelOptions(*c, o, []model.Option{})
4242

4343
if c.Backend != "" {
4444
opts = append(opts, model.WithBackendString(c.Backend))
@@ -96,7 +96,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
9696

9797
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
9898
fn := func() (LLMResponse, error) {
99-
opts := gRPCPredictOpts(c, loader.ModelPath)
99+
opts := gRPCPredictOpts(*c, loader.ModelPath)
100100
opts.Prompt = s
101101
opts.Messages = protoMessages
102102
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate

core/backend/rerank.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ import (
99
model "github.com/mudler/LocalAI/pkg/model"
1010
)
1111

12-
func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
12+
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
1313

14-
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
14+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
1515
rerankModel, err := loader.BackendLoader(opts...)
1616
if err != nil {
1717
return nil, err

core/backend/soundgeneration.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
)
1414

1515
func SoundGeneration(
16-
modelFile string,
1716
text string,
1817
duration *float32,
1918
temperature *float32,
@@ -25,7 +24,7 @@ func SoundGeneration(
2524
backendConfig config.BackendConfig,
2625
) (string, *proto.Result, error) {
2726

28-
opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
27+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
2928

3029
soundGenModel, err := loader.BackendLoader(opts...)
3130
if err != nil {
@@ -45,7 +44,7 @@ func SoundGeneration(
4544

4645
res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
4746
Text: text,
48-
Model: modelFile,
47+
Model: backendConfig.Model,
4948
Dst: filePath,
5049
Sample: doSample,
5150
Duration: duration,

core/backend/tokenize.go

+1-5
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,10 @@ import (
99

1010
func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {
1111

12-
modelFile := backendConfig.Model
13-
1412
var inferenceModel grpc.Backend
1513
var err error
1614

17-
opts := ModelOptions(backendConfig, appConfig, []model.Option{
18-
model.WithModel(modelFile),
19-
})
15+
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
2016

2117
if backendConfig.Backend == "" {
2218
inferenceModel, err = loader.GreedyLoader(opts...)

core/backend/tts.go

+16-25
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,23 @@ import (
1414
)
1515

1616
func ModelTTS(
17-
backend,
1817
text,
19-
modelFile,
2018
voice,
2119
language string,
2220
loader *model.ModelLoader,
2321
appConfig *config.ApplicationConfig,
2422
backendConfig config.BackendConfig,
2523
) (string, *proto.Result, error) {
26-
bb := backend
27-
if bb == "" {
28-
bb = model.PiperBackend
29-
}
30-
31-
opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
32-
model.WithBackendString(bb),
33-
model.WithModel(modelFile),
24+
opts := ModelOptions(*&backendConfig, appConfig, []model.Option{
25+
model.WithDefaultBackendString(model.PiperBackend),
3426
})
3527
ttsModel, err := loader.BackendLoader(opts...)
3628
if err != nil {
3729
return "", nil, err
3830
}
3931

4032
if ttsModel == nil {
41-
return "", nil, fmt.Errorf("could not load piper model")
33+
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
4234
}
4335

4436
if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@@ -48,22 +40,21 @@ func ModelTTS(
4840
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
4941
filePath := filepath.Join(appConfig.AudioDir, fileName)
5042

51-
// If the model file is not empty, we pass it joined with the model path
43+
// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
44+
// This should be addressed in a follow up PR soon.
45+
// Copying it over nearly verbatim, as TTS backends are not functional without this.
5246
modelPath := ""
53-
if modelFile != "" {
54-
// If the model file is not empty, we pass it joined with the model path
55-
// Checking first that it exists and is not outside ModelPath
56-
// TODO: we should actually first check if the modelFile is looking like
57-
// a FS path
58-
mp := filepath.Join(loader.ModelPath, modelFile)
59-
if _, err := os.Stat(mp); err == nil {
60-
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
61-
return "", nil, err
62-
}
63-
modelPath = mp
64-
} else {
65-
modelPath = modelFile
47+
// Checking first that it exists and is not outside ModelPath
48+
// TODO: we should actually first check if the modelFile is looking like
49+
// a FS path
50+
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
51+
if _, err := os.Stat(mp); err == nil {
52+
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
53+
return "", nil, err
6654
}
55+
modelPath = mp
56+
} else {
57+
modelPath = backendConfig.Model // skip this step if it fails?????
6758
}
6859

6960
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

core/cli/soundgeneration.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
8686
options := config.BackendConfig{}
8787
options.SetDefaults()
8888
options.Backend = t.Backend
89+
options.Model = t.Model
8990

9091
var inputFile *string
9192
if t.InputFile != "" {
9293
inputFile = &t.InputFile
9394
}
9495

95-
filePath, _, err := backend.SoundGeneration(t.Model, text,
96+
filePath, _, err := backend.SoundGeneration(text,
9697
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
9798
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)
9899

core/cli/tts.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
5252

5353
options := config.BackendConfig{}
5454
options.SetDefaults()
55+
options.Backend = t.Backend
56+
options.Model = t.Model
5557

56-
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
58+
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
5759
if err != nil {
5860
return err
5961
}

core/config/backend_config.go

+20-11
Original file line numberDiff line numberDiff line change
@@ -432,19 +432,20 @@ func (c *BackendConfig) HasTemplate() bool {
432432
type BackendConfigUsecases int
433433

434434
const (
435-
FLAG_ANY BackendConfigUsecases = 0b000000000
436-
FLAG_CHAT BackendConfigUsecases = 0b000000001
437-
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
438-
FLAG_EDIT BackendConfigUsecases = 0b000000100
439-
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
440-
FLAG_RERANK BackendConfigUsecases = 0b000010000
441-
FLAG_IMAGE BackendConfigUsecases = 0b000100000
442-
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
443-
FLAG_TTS BackendConfigUsecases = 0b010000000
444-
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
435+
FLAG_ANY BackendConfigUsecases = 0b0000000000
436+
FLAG_CHAT BackendConfigUsecases = 0b0000000001
437+
FLAG_COMPLETION BackendConfigUsecases = 0b0000000010
438+
FLAG_EDIT BackendConfigUsecases = 0b0000000100
439+
FLAG_EMBEDDINGS BackendConfigUsecases = 0b0000001000
440+
FLAG_RERANK BackendConfigUsecases = 0b0000010000
441+
FLAG_IMAGE BackendConfigUsecases = 0b0000100000
442+
FLAG_TRANSCRIPT BackendConfigUsecases = 0b0001000000
443+
FLAG_TTS BackendConfigUsecases = 0b0010000000
444+
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b0100000000
445+
FLAG_TOKENIZE BackendConfigUsecases = 0b1000000000
445446

446447
// Common Subsets
447-
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
448+
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
448449
)
449450

450451
func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@@ -459,6 +460,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
459460
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
460461
"FLAG_TTS": FLAG_TTS,
461462
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
463+
"FLAG_TOKENIZE": FLAG_TOKENIZE,
462464
"FLAG_LLM": FLAG_LLM,
463465
}
464466
}
@@ -544,5 +546,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
544546
}
545547
}
546548

549+
if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
550+
tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
551+
if !slices.Contains(tokenizeCapableBackends, c.Backend) {
552+
return false
553+
}
554+
}
555+
547556
return true
548557
}

core/config/backend_config_loader.go

+21-9
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
8181
c := &[]*BackendConfig{}
8282
f, err := os.ReadFile(file)
8383
if err != nil {
84-
return nil, fmt.Errorf("cannot read config file: %w", err)
84+
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
8585
}
8686
if err := yaml.Unmarshal(f, c); err != nil {
87-
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
87+
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
8888
}
8989

9090
for _, cc := range *c {
@@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
101101
c := &BackendConfig{}
102102
f, err := os.ReadFile(file)
103103
if err != nil {
104-
return nil, fmt.Errorf("cannot read config file: %w", err)
104+
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
105105
}
106106
if err := yaml.Unmarshal(f, c); err != nil {
107-
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
107+
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
108108
}
109109

110110
c.SetDefaults(opts...)
@@ -117,7 +117,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
117117
// Load a config file if present after the model name
118118
cfg := &BackendConfig{
119119
PredictionOptions: schema.PredictionOptions{
120-
Model: modelName,
120+
BasicModelRequest: schema.BasicModelRequest{
121+
Model: modelName,
122+
},
121123
},
122124
}
123125

@@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
145147
return cfg, nil
146148
}
147149

150+
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
151+
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
152+
LoadOptionDebug(appConfig.Debug),
153+
LoadOptionThreads(appConfig.Threads),
154+
LoadOptionContextSize(appConfig.ContextSize),
155+
LoadOptionF16(appConfig.F16),
156+
ModelPath(appConfig.ModelPath))
157+
}
158+
148159
// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
149160
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
150161
bcl.Lock()
@@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
167178
defer bcl.Unlock()
168179
c, err := readBackendConfigFromFile(file, opts...)
169180
if err != nil {
170-
return fmt.Errorf("cannot read config file: %w", err)
181+
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
171182
}
172183

173184
if c.Validate() {
@@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
324335
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
325336
bcl.Lock()
326337
defer bcl.Unlock()
338+
327339
entries, err := os.ReadDir(path)
328340
if err != nil {
329-
return fmt.Errorf("cannot read directory '%s': %w", path, err)
341+
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
330342
}
331343
files := make([]fs.FileInfo, 0, len(entries))
332344
for _, entry := range entries {
@@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
344356
}
345357
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
346358
if err != nil {
347-
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
359+
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
348360
continue
349361
}
350362
if c.Validate() {
351363
bcl.configs[c.Name] = *c
352364
} else {
353-
log.Error().Err(err).Msgf("config is not valid")
365+
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
354366
}
355367
}
356368

core/config/guesser.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ const (
2626
type settingsConfig struct {
2727
StopWords []string
2828
TemplateConfig TemplateConfig
29-
RepeatPenalty float64
29+
RepeatPenalty float64
3030
}
3131

3232
// default settings to adopt with a given model family
3333
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
3434
Gemma: {
3535
RepeatPenalty: 1.0,
36-
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
36+
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
3737
TemplateConfig: TemplateConfig{
3838
Chat: "{{.Input }}\n<start_of_turn>model\n",
3939
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
@@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
161161
}
162162

163163
// We try to guess only if we don't have a template defined already
164-
f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
164+
guessPath := filepath.Join(modelPath, cfg.ModelFileName())
165+
f, err := gguf.ParseGGUFFile(guessPath)
165166
if err != nil {
166167
// Only valid for gguf files
167-
log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
168+
log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
168169
return
169170
}
170171

core/http/app.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
121121
return metricsService.Shutdown()
122122
})
123123
}
124-
125124
}
126125
// Health Checks should always be exempt from auth, so register these first
127126
routes.HealthRoutes(app)
@@ -158,13 +157,15 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
158157
galleryService := services.NewGalleryService(appConfig)
159158
galleryService.Start(appConfig.Context, cl)
160159

161-
routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
162-
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
163-
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
160+
requestExtractor := middleware.NewRequestExtractor(cl, ml, appConfig)
161+
162+
routes.RegisterElevenLabsRoutes(app, requestExtractor, cl, ml, appConfig)
163+
routes.RegisterLocalAIRoutes(app, requestExtractor, cl, ml, appConfig, galleryService)
164+
routes.RegisterOpenAIRoutes(app, requestExtractor, cl, ml, appConfig)
164165
if !appConfig.DisableWebUI {
165166
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
166167
}
167-
routes.RegisterJINARoutes(app, cl, ml, appConfig)
168+
routes.RegisterJINARoutes(app, requestExtractor, cl, ml, appConfig)
168169

169170
httpFS := http.FS(embedDirStatic)
170171

0 commit comments

Comments
 (0)