Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
pkg-config
npm install -g markdownlint-cli
pip install --user yamllint codespell
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.5.0

- name: Cache Rust dependencies
uses: actions/cache@v4
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/test-and-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ jobs:
run: |
pip install -U "huggingface_hub[cli]" hf_transfer


- name: Download models (minimal on PRs)
env:
CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }}
Expand Down
3 changes: 3 additions & 0 deletions Dockerfile.precommit
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ RUN pip install --break-system-packages yamllint

# CodeSpell
RUN pip install --break-system-packages codespell

# Golangci-lint
RUN curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.5.0
17 changes: 9 additions & 8 deletions src/semantic-router/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"

"github.com/vllm-project/semantic-router/src/semantic-router/pkg/api"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/extproc"
Expand Down Expand Up @@ -63,16 +64,16 @@ func main() {
ServiceVersion: cfg.Observability.Tracing.Resource.ServiceVersion,
DeploymentEnvironment: cfg.Observability.Tracing.Resource.DeploymentEnvironment,
}
if err := observability.InitTracing(ctx, tracingCfg); err != nil {
observability.Warnf("Failed to initialize tracing: %v", err)
if tracingErr := observability.InitTracing(ctx, tracingCfg); tracingErr != nil {
observability.Warnf("Failed to initialize tracing: %v", tracingErr)
}

// Set up graceful shutdown for tracing
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := observability.ShutdownTracing(shutdownCtx); err != nil {
observability.Errorf("Failed to shutdown tracing: %v", err)
if shutdownErr := observability.ShutdownTracing(shutdownCtx); shutdownErr != nil {
observability.Errorf("Failed to shutdown tracing: %v", shutdownErr)
}
}()
}
Expand All @@ -86,8 +87,8 @@ func main() {
observability.Infof("Received shutdown signal, cleaning up...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := observability.ShutdownTracing(shutdownCtx); err != nil {
observability.Errorf("Failed to shutdown tracing: %v", err)
if shutdownErr := observability.ShutdownTracing(shutdownCtx); shutdownErr != nil {
observability.Errorf("Failed to shutdown tracing: %v", shutdownErr)
}
os.Exit(0)
}()
Expand All @@ -97,8 +98,8 @@ func main() {
http.Handle("/metrics", promhttp.Handler())
metricsAddr := fmt.Sprintf(":%d", *metricsPort)
observability.Infof("Starting metrics server on %s", metricsAddr)
if err := http.ListenAndServe(metricsAddr, nil); err != nil {
observability.Errorf("Metrics server error: %v", err)
if metricsErr := http.ListenAndServe(metricsAddr, nil); metricsErr != nil {
observability.Errorf("Metrics server error: %v", metricsErr)
}
}()

Expand Down
4 changes: 3 additions & 1 deletion src/semantic-router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0
google.golang.org/grpc v1.75.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/apimachinery v0.31.4
sigs.k8s.io/yaml v1.6.0
)

require (
Expand Down Expand Up @@ -89,6 +89,7 @@ require (
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
Expand All @@ -99,6 +100,7 @@ require (
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
Expand Down
8 changes: 6 additions & 2 deletions src/semantic-router/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,10 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
Expand Down Expand Up @@ -527,5 +531,5 @@ sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMm
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
44 changes: 20 additions & 24 deletions src/semantic-router/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
}

// handleHealth handles health check requests
func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`))
_, _ = w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`))
}

// APIOverviewResponse represents the response for GET /api/v1
Expand Down Expand Up @@ -363,19 +363,15 @@ type OpenAPIComponents struct {
}

// handleAPIOverview handles GET /api/v1 for API discovery
func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, _ *http.Request) {
// Build endpoints list from registry, filtering out disabled endpoints
endpoints := make([]EndpointInfo, 0, len(endpointRegistry))
for _, metadata := range endpointRegistry {
// Filter out system prompt endpoints if they are disabled
if !s.enableSystemPromptAPI && (metadata.Path == "/config/system-prompts") {
continue
}
endpoints = append(endpoints, EndpointInfo{
Path: metadata.Path,
Method: metadata.Method,
Description: metadata.Description,
})
endpoints = append(endpoints, EndpointInfo(metadata))
}

response := APIOverviewResponse{
Expand Down Expand Up @@ -497,13 +493,13 @@ func (s *ClassificationAPIServer) generateOpenAPISpec() OpenAPISpec {
}

// handleOpenAPISpec serves the OpenAPI 3.0 specification at /openapi.json
func (s *ClassificationAPIServer) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleOpenAPISpec(w http.ResponseWriter, _ *http.Request) {
spec := s.generateOpenAPISpec()
s.writeJSONResponse(w, http.StatusOK, spec)
}

// handleSwaggerUI serves the Swagger UI at /docs
func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, _ *http.Request) {
// Serve a simple HTML page that loads Swagger UI from CDN
html := `<!DOCTYPE html>
<html lang="en">
Expand Down Expand Up @@ -545,7 +541,7 @@ func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, r *http

w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write([]byte(html))
_, _ = w.Write([]byte(html))
}

// handleIntentClassification handles intent classification requests
Expand Down Expand Up @@ -609,7 +605,7 @@ func (s *ClassificationAPIServer) handleSecurityDetection(w http.ResponseWriter,
}

// Placeholder handlers for remaining endpoints
func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Combined classification not implemented yet")
}

Expand All @@ -631,7 +627,7 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite

// Check if texts field exists in JSON
var rawReq map[string]interface{}
if err := json.Unmarshal(body, &rawReq); err != nil {
if unmarshalErr := json.Unmarshal(body, &rawReq); unmarshalErr != nil {
metrics.RecordBatchClassificationError("unified", "invalid_json")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "Invalid JSON format")
return
Expand All @@ -645,9 +641,9 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
}

var req BatchClassificationRequest
if err := s.parseJSONRequest(r, &req); err != nil {
if parseErr := s.parseJSONRequest(r, &req); parseErr != nil {
metrics.RecordBatchClassificationError("unified", "parse_request_failed")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", parseErr.Error())
return
}

Expand All @@ -660,9 +656,9 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
}

// Validate task_type if provided
if err := validateTaskType(req.TaskType); err != nil {
if validateErr := validateTaskType(req.TaskType); validateErr != nil {
metrics.RecordBatchClassificationError("unified", "invalid_task_type")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", err.Error())
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", validateErr.Error())
return
}

Expand Down Expand Up @@ -703,12 +699,12 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
s.writeJSONResponse(w, http.StatusOK, response)
}

func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, _ *http.Request) {
response := s.buildModelsInfoResponse()
s.writeJSONResponse(w, http.StatusOK, response)
}

func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, _ *http.Request) {
if s.config == nil {
s.writeJSONResponse(w, http.StatusOK, map[string]interface{}{
"status": "no_config",
Expand All @@ -726,7 +722,7 @@ func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, r

// handleOpenAIModels handles OpenAI-compatible model listing at /v1/models
// It returns all models discoverable from the router configuration plus a synthetic "auto" model.
func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, _ *http.Request) {
now := time.Now().Unix()

// Start with the special "auto" model always available from the router
Expand Down Expand Up @@ -763,15 +759,15 @@ func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, r *h
s.writeJSONResponse(w, http.StatusOK, resp)
}

func (s *ClassificationAPIServer) handleClassificationMetrics(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleClassificationMetrics(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Classification metrics not implemented yet")
}

func (s *ClassificationAPIServer) handleGetConfig(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleGetConfig(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Get config not implemented yet")
}

func (s *ClassificationAPIServer) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleUpdateConfig(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Update config not implemented yet")
}

Expand Down Expand Up @@ -1096,7 +1092,7 @@ type SystemPromptUpdateRequest struct {
}

// handleGetSystemPrompts handles GET /config/system-prompts
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, _ *http.Request) {
cfg := s.config
if cfg == nil {
http.Error(w, "Configuration not available", http.StatusInternalServerError)
Expand Down
20 changes: 7 additions & 13 deletions src/semantic-router/pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
"testing"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus/testutil"

candle_binding "github.com/vllm-project/semantic-router/candle-binding"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics"

"github.com/prometheus/client_golang/prometheus/testutil"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestCache(t *testing.T) {
Expand All @@ -29,9 +28,7 @@ var _ = BeforeSuite(func() {
})

var _ = Describe("Cache Package", func() {
var (
tempDir string
)
var tempDir string

BeforeEach(func() {
var err error
Expand Down Expand Up @@ -213,7 +210,6 @@ development:
Expect(backend).To(BeNil())
})
})

})

Describe("ValidateCacheConfig", func() {
Expand Down Expand Up @@ -412,9 +408,7 @@ development:
})

Describe("InMemoryCache", func() {
var (
inMemoryCache cache.CacheBackend
)
var inMemoryCache cache.CacheBackend

BeforeEach(func() {
options := cache.InMemoryCacheOptions{
Expand All @@ -435,7 +429,7 @@ development:

It("should implement CacheBackend interface", func() {
// Check that the concrete type implements the interface
var _ cache.CacheBackend = inMemoryCache
_ = inMemoryCache
Expect(inMemoryCache).NotTo(BeNil())
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestInMemoryCacheIntegration(t *testing.T) {

// Step 3: Access first entry multiple times to increase its frequency
for range 2 {
responseBody, found, err := cache.FindSimilar("test-model", "Hello world")
if err != nil {
t.Logf("FindSimilar failed (expected due to high threshold): %v", err)
responseBody, found, findErr := cache.FindSimilar("test-model", "Hello world")
if findErr != nil {
t.Logf("FindSimilar failed (expected due to high threshold): %v", findErr)
}
if !found {
t.Errorf("Expected to find similar entry for first query")
Expand Down
Loading
Loading