Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
pkg-config
npm install -g markdownlint-cli
pip install --user yamllint codespell
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.5.0

- name: Cache Rust dependencies
uses: actions/cache@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-and-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
make \
build-essential \
pkg-config
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.5.0

- name: Cache Rust dependencies
uses: actions/cache@v4
Expand Down Expand Up @@ -74,7 +75,6 @@ jobs:
run: |
pip install -U "huggingface_hub[cli]" hf_transfer


- name: Download models (minimal on PRs)
env:
CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }}
Expand Down
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ repos:
language: system
files: \.go$

- repo: local
hooks:
- id: golang-lint
name: go lint
entry: make go-lint
language: system
files: \.go$
pass_filenames: false

# Markdown specific hooks
- repo: local
hooks:
Expand Down
3 changes: 3 additions & 0 deletions Dockerfile.precommit
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ RUN pip install --break-system-packages yamllint

# CodeSpell
RUN pip install --break-system-packages codespell

# Golangci-lint
RUN curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.5.0
17 changes: 9 additions & 8 deletions src/semantic-router/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"

"github.com/vllm-project/semantic-router/src/semantic-router/pkg/api"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/extproc"
Expand Down Expand Up @@ -63,16 +64,16 @@ func main() {
ServiceVersion: cfg.Observability.Tracing.Resource.ServiceVersion,
DeploymentEnvironment: cfg.Observability.Tracing.Resource.DeploymentEnvironment,
}
if err := observability.InitTracing(ctx, tracingCfg); err != nil {
observability.Warnf("Failed to initialize tracing: %v", err)
if tracingErr := observability.InitTracing(ctx, tracingCfg); tracingErr != nil {
observability.Warnf("Failed to initialize tracing: %v", tracingErr)
}

// Set up graceful shutdown for tracing
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := observability.ShutdownTracing(shutdownCtx); err != nil {
observability.Errorf("Failed to shutdown tracing: %v", err)
if shutdownErr := observability.ShutdownTracing(shutdownCtx); shutdownErr != nil {
observability.Errorf("Failed to shutdown tracing: %v", shutdownErr)
}
}()
}
Expand All @@ -86,8 +87,8 @@ func main() {
observability.Infof("Received shutdown signal, cleaning up...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := observability.ShutdownTracing(shutdownCtx); err != nil {
observability.Errorf("Failed to shutdown tracing: %v", err)
if shutdownErr := observability.ShutdownTracing(shutdownCtx); shutdownErr != nil {
observability.Errorf("Failed to shutdown tracing: %v", shutdownErr)
}
os.Exit(0)
}()
Expand All @@ -97,8 +98,8 @@ func main() {
http.Handle("/metrics", promhttp.Handler())
metricsAddr := fmt.Sprintf(":%d", *metricsPort)
observability.Infof("Starting metrics server on %s", metricsAddr)
if err := http.ListenAndServe(metricsAddr, nil); err != nil {
observability.Errorf("Metrics server error: %v", err)
if metricsErr := http.ListenAndServe(metricsAddr, nil); metricsErr != nil {
observability.Errorf("Metrics server error: %v", metricsErr)
}
}()

Expand Down
2 changes: 2 additions & 0 deletions src/semantic-router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
google.golang.org/grpc v1.75.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/apimachinery v0.31.4
sigs.k8s.io/yaml v1.6.0
)

require (
Expand Down Expand Up @@ -89,6 +90,7 @@ require (
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
Expand Down
8 changes: 6 additions & 2 deletions src/semantic-router/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,10 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
Expand Down Expand Up @@ -527,5 +531,5 @@ sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMm
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
44 changes: 20 additions & 24 deletions src/semantic-router/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
}

// handleHealth handles health check requests
func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleHealth(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`))
_, _ = w.Write([]byte(`{"status": "healthy", "service": "classification-api"}`))
}

// APIOverviewResponse represents the response for GET /api/v1
Expand Down Expand Up @@ -363,19 +363,15 @@ type OpenAPIComponents struct {
}

// handleAPIOverview handles GET /api/v1 for API discovery
func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleAPIOverview(w http.ResponseWriter, _ *http.Request) {
// Build endpoints list from registry, filtering out disabled endpoints
endpoints := make([]EndpointInfo, 0, len(endpointRegistry))
for _, metadata := range endpointRegistry {
// Filter out system prompt endpoints if they are disabled
if !s.enableSystemPromptAPI && (metadata.Path == "/config/system-prompts") {
continue
}
endpoints = append(endpoints, EndpointInfo{
Path: metadata.Path,
Method: metadata.Method,
Description: metadata.Description,
})
endpoints = append(endpoints, EndpointInfo(metadata))
}

response := APIOverviewResponse{
Expand Down Expand Up @@ -497,13 +493,13 @@ func (s *ClassificationAPIServer) generateOpenAPISpec() OpenAPISpec {
}

// handleOpenAPISpec serves the OpenAPI 3.0 specification at /openapi.json
func (s *ClassificationAPIServer) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleOpenAPISpec(w http.ResponseWriter, _ *http.Request) {
spec := s.generateOpenAPISpec()
s.writeJSONResponse(w, http.StatusOK, spec)
}

// handleSwaggerUI serves the Swagger UI at /docs
func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, _ *http.Request) {
// Serve a simple HTML page that loads Swagger UI from CDN
html := `<!DOCTYPE html>
<html lang="en">
Expand Down Expand Up @@ -545,7 +541,7 @@ func (s *ClassificationAPIServer) handleSwaggerUI(w http.ResponseWriter, r *http

w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write([]byte(html))
_, _ = w.Write([]byte(html))
}

// handleIntentClassification handles intent classification requests
Expand Down Expand Up @@ -609,7 +605,7 @@ func (s *ClassificationAPIServer) handleSecurityDetection(w http.ResponseWriter,
}

// Placeholder handlers for remaining endpoints
func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Combined classification not implemented yet")
}

Expand All @@ -631,7 +627,7 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite

// Check if texts field exists in JSON
var rawReq map[string]interface{}
if err := json.Unmarshal(body, &rawReq); err != nil {
if unmarshalErr := json.Unmarshal(body, &rawReq); unmarshalErr != nil {
metrics.RecordBatchClassificationError("unified", "invalid_json")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "Invalid JSON format")
return
Expand All @@ -645,9 +641,9 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
}

var req BatchClassificationRequest
if err := s.parseJSONRequest(r, &req); err != nil {
if parseErr := s.parseJSONRequest(r, &req); parseErr != nil {
metrics.RecordBatchClassificationError("unified", "parse_request_failed")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", parseErr.Error())
return
}

Expand All @@ -660,9 +656,9 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
}

// Validate task_type if provided
if err := validateTaskType(req.TaskType); err != nil {
if validateErr := validateTaskType(req.TaskType); validateErr != nil {
metrics.RecordBatchClassificationError("unified", "invalid_task_type")
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", err.Error())
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_TASK_TYPE", validateErr.Error())
return
}

Expand Down Expand Up @@ -703,12 +699,12 @@ func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWrite
s.writeJSONResponse(w, http.StatusOK, response)
}

func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleModelsInfo(w http.ResponseWriter, _ *http.Request) {
response := s.buildModelsInfoResponse()
s.writeJSONResponse(w, http.StatusOK, response)
}

func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, _ *http.Request) {
if s.config == nil {
s.writeJSONResponse(w, http.StatusOK, map[string]interface{}{
"status": "no_config",
Expand All @@ -726,7 +722,7 @@ func (s *ClassificationAPIServer) handleClassifierInfo(w http.ResponseWriter, r

// handleOpenAIModels handles OpenAI-compatible model listing at /v1/models
// It returns all models discoverable from the router configuration plus a synthetic "auto" model.
func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, _ *http.Request) {
now := time.Now().Unix()

// Start with the special "auto" model always available from the router
Expand Down Expand Up @@ -763,15 +759,15 @@ func (s *ClassificationAPIServer) handleOpenAIModels(w http.ResponseWriter, r *h
s.writeJSONResponse(w, http.StatusOK, resp)
}

func (s *ClassificationAPIServer) handleClassificationMetrics(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleClassificationMetrics(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Classification metrics not implemented yet")
}

func (s *ClassificationAPIServer) handleGetConfig(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleGetConfig(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Get config not implemented yet")
}

func (s *ClassificationAPIServer) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleUpdateConfig(w http.ResponseWriter, _ *http.Request) {
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Update config not implemented yet")
}

Expand Down Expand Up @@ -1096,7 +1092,7 @@ type SystemPromptUpdateRequest struct {
}

// handleGetSystemPrompts handles GET /config/system-prompts
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, _ *http.Request) {
cfg := s.config
if cfg == nil {
http.Error(w, "Configuration not available", http.StatusInternalServerError)
Expand Down
20 changes: 7 additions & 13 deletions src/semantic-router/pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
"testing"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus/testutil"

candle_binding "github.com/vllm-project/semantic-router/candle-binding"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache"
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics"

"github.com/prometheus/client_golang/prometheus/testutil"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

func TestCache(t *testing.T) {
Expand All @@ -29,9 +28,7 @@ var _ = BeforeSuite(func() {
})

var _ = Describe("Cache Package", func() {
var (
tempDir string
)
var tempDir string

BeforeEach(func() {
var err error
Expand Down Expand Up @@ -213,7 +210,6 @@ development:
Expect(backend).To(BeNil())
})
})

})

Describe("ValidateCacheConfig", func() {
Expand Down Expand Up @@ -412,9 +408,7 @@ development:
})

Describe("InMemoryCache", func() {
var (
inMemoryCache cache.CacheBackend
)
var inMemoryCache cache.CacheBackend

BeforeEach(func() {
options := cache.InMemoryCacheOptions{
Expand All @@ -435,7 +429,7 @@ development:

It("should implement CacheBackend interface", func() {
// Check that the concrete type implements the interface
var _ cache.CacheBackend = inMemoryCache
_ = inMemoryCache
Expect(inMemoryCache).NotTo(BeNil())
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestInMemoryCacheIntegration(t *testing.T) {

// Step 3: Access first entry multiple times to increase its frequency
for range 2 {
responseBody, found, err := cache.FindSimilar("test-model", "Hello world")
if err != nil {
t.Logf("FindSimilar failed (expected due to high threshold): %v", err)
responseBody, found, findErr := cache.FindSimilar("test-model", "Hello world")
if findErr != nil {
t.Logf("FindSimilar failed (expected due to high threshold): %v", findErr)
}
if !found {
t.Errorf("Expected to find similar entry for first query")
Expand Down
Loading
Loading