Commit d13327b

Merge pull request #344 from doringeman/vllm-configure
refactor(scheduler): deduplicate vLLM backend selection logic
2 parents: 59a5c51 + a6f258b

1 file changed: +25 -14 lines

pkg/inference/scheduling/scheduler.go

@@ -155,6 +155,27 @@ func (s *Scheduler) Run(ctx context.Context) error {
 	return workers.Wait()
 }
 
+// selectBackendForModel selects the appropriate backend for a model based on its format.
+// If the model is in safetensors format, it will prefer vLLM if available.
+func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.Backend, modelRef string) inference.Backend {
+	config, err := model.Config()
+	if err != nil {
+		s.log.Warnln("failed to fetch model config:", err)
+		return backend
+	}
+
+	if config.Format == types.FormatSafetensors {
+		if vllmBackend, ok := s.backends[vllm.Name]; ok && vllmBackend != nil {
+			return vllmBackend
+		}
+		s.log.Warnf("Model %s is in safetensors format but vLLM backend is not available. "+
+			"Backend %s may not support this format and could fail at runtime.",
+			utils.SanitizeForLog(modelRef), backend.Name())
+	}
+
+	return backend
+}
+
 // handleOpenAIInference handles scheduling and responding to OpenAI inference
 // requests, including:
 // - POST <inference-prefix>/{backend}/v1/chat/completions
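
One detail worth calling out in the new helper: unlike the inline block it replaces, the map lookup guards both the comma-ok result and the stored value being non-nil (ok && vllmBackend != nil). In Go, a map of interface values can hold an explicit nil entry, and comma-ok alone reports success for it. A minimal, self-contained sketch of that pitfall (the Backend interface here is a hypothetical stand-in, not the repository's actual inference.Backend):

package main

import "fmt"

// Backend is a hypothetical stand-in for inference.Backend.
type Backend interface {
	Name() string
}

func main() {
	// A backend can be registered under its name but stored as nil,
	// e.g. when it failed to initialize on this platform.
	backends := map[string]Backend{"vllm": nil}

	// Comma-ok alone reports presence, not usability: ok is true here.
	b, ok := backends["vllm"]
	fmt.Println(ok, b == nil) // true true

	// The guard used in selectBackendForModel rejects the nil entry
	// and keeps the default backend instead of crashing later.
	if b, ok := backends["vllm"]; ok && b != nil {
		fmt.Println("using", b.Name())
	} else {
		fmt.Println("keeping default backend")
	}
}
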
@@ -218,20 +239,7 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
 		s.tracker.TrackModel(model, r.UserAgent(), "inference/"+backendMode.String())
 
 		// Automatically identify models for vLLM.
-		config, err := model.Config()
-		if err != nil {
-			s.log.Warnln("failed to fetch model config:", err)
-		} else {
-			if config.Format == types.FormatSafetensors {
-				if vllmBackend, ok := s.backends[vllm.Name]; ok {
-					backend = vllmBackend
-				} else {
-					s.log.Warnf("Model %s is in safetensors format but vLLM backend is not available. "+
-						"Backend %s may not support this format and could fail at runtime.",
-						utils.SanitizeForLog(request.Model), backend.Name())
-				}
-			}
-		}
+		backend = s.selectBackendForModel(model, backend, request.Model)
 	}
 
 	// Wait for the corresponding backend installation to complete or fail. We
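
Both the removed inline block and the new helper pass the user-supplied model reference through utils.SanitizeForLog before logging it; the helper takes it as the modelRef parameter so each call site can supply its own (request.Model here, configureRequest.Model in Configure below). The usual purpose of such a helper is to stop attacker-controlled strings from forging extra log lines. A rough sketch of the idea, assuming a simple escape-newlines implementation (the real utils.SanitizeForLog may differ):

package main

import (
	"fmt"
	"strings"
)

// sanitizeForLog is a hypothetical stand-in for utils.SanitizeForLog.
// It escapes CR/LF so a user-supplied string cannot inject fake log lines.
func sanitizeForLog(s string) string {
	return strings.NewReplacer("\n", `\n`, "\r", `\r`).Replace(s)
}

func main() {
	evil := "my-model\nWARN forged log entry"
	// Prints a single line; the embedded newline is rendered as \n.
	fmt.Printf("Model %s is in safetensors format\n", sanitizeForLog(evil))
}
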
@@ -440,6 +448,9 @@ func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
 	if model, err := s.modelManager.GetModel(configureRequest.Model); err == nil {
 		// Configure is called by compose for each model.
 		s.tracker.TrackModel(model, r.UserAgent(), "configure/"+mode.String())
+
+		// Automatically identify models for vLLM.
+		backend = s.selectBackendForModel(model, backend, configureRequest.Model)
 	}
 	modelID := s.modelManager.ResolveModelID(configureRequest.Model)
 	if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), modelID, mode, runnerConfig); err != nil {
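
With both handleOpenAIInference and Configure delegating to the same helper, the rule has three outcomes: config fetch fails, keep the requested backend; safetensors with vLLM registered, switch to vLLM; safetensors without vLLM, warn and keep the requested backend. A standalone sketch of that decision table, written against hypothetical stub types rather than the repository's types.Model and inference.Backend:

package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins; the real interfaces live in the repository.
type backend struct{ name string }

type model struct {
	format string
	err    error // error returned when fetching the model config
}

// selectBackend mirrors the rule in selectBackendForModel.
func selectBackend(m model, def backend, backends map[string]*backend) backend {
	if m.err != nil {
		return def // config unavailable: keep the requested backend
	}
	if m.format == "safetensors" {
		if v, ok := backends["vllm"]; ok && v != nil {
			return *v // safetensors prefers vLLM when registered
		}
		fmt.Println("warn: safetensors model but vLLM unavailable; keeping", def.name)
	}
	return def
}

func main() {
	def := backend{"llama.cpp"}
	vllm := &backend{"vllm"}

	fmt.Println(selectBackend(model{format: "safetensors"}, def, map[string]*backend{"vllm": vllm}).name) // vllm
	fmt.Println(selectBackend(model{format: "gguf"}, def, map[string]*backend{"vllm": vllm}).name)        // llama.cpp
	fmt.Println(selectBackend(model{err: errors.New("no config")}, def, nil).name)                        // llama.cpp
	fmt.Println(selectBackend(model{format: "safetensors"}, def, nil).name)                               // llama.cpp, after warning
}
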
