From d6c6b20b8883dfbb8ca5d144487bb19d86ec9b6f Mon Sep 17 00:00:00 2001 From: Dan Piths <85949566+danpiths@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:01:51 +0530 Subject: [PATCH] feat: add support for per-model and per-provider level budgeting and rate limiting in governance plugin --- framework/configstore/clientconfig.go | 14 +- framework/configstore/migrations.go | 116 ++- framework/configstore/rdb.go | 218 +++++- framework/configstore/store.go | 13 + framework/configstore/tables/modelconfig.go | 59 ++ framework/configstore/tables/provider.go | 17 + plugins/governance/main.go | 50 +- plugins/governance/resolver.go | 83 +- plugins/governance/store.go | 822 +++++++++++++++++++- plugins/governance/tracker.go | 39 +- 10 files changed, 1353 insertions(+), 78 deletions(-) create mode 100644 framework/configstore/tables/modelconfig.go diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 047676d36..9efebb987 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -725,10 +725,12 @@ type AuthConfig struct { type ConfigMap map[schemas.ModelProvider]ProviderConfig type GovernanceConfig struct { - VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` - Teams []tables.TableTeam `json:"teams"` - Customers []tables.TableCustomer `json:"customers"` - Budgets []tables.TableBudget `json:"budgets"` - RateLimits []tables.TableRateLimit `json:"rate_limits"` - AuthConfig *AuthConfig `json:"auth_config,omitempty"` + VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` + Teams []tables.TableTeam `json:"teams"` + Customers []tables.TableCustomer `json:"customers"` + Budgets []tables.TableBudget `json:"budgets"` + RateLimits []tables.TableRateLimit `json:"rate_limits"` + ModelConfigs []tables.TableModelConfig `json:"model_configs"` + Providers []tables.TableProvider `json:"providers"` + AuthConfig *AuthConfig `json:"auth_config,omitempty"` } diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 4ac808636..3a7d76ad5 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -128,6 +128,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx, db); err != nil { return err } + if err := migrationAddModelConfigTable(ctx, db); err != nil { + return err + } + if err := migrationAddProviderGovernanceColumns(ctx, db); err != nil { + return err + } return nil } @@ -1201,7 +1207,6 @@ func migrationAddEnabledColumnToKeyTable(ctx context.Context, db *gorm.DB) error if err := mg.AddColumn(&tables.TableKey{}, "enabled"); err != nil { return fmt.Errorf("failed to add enabled column: %w", err) } - } // Set default = true for existing rows if err := tx.Exec("UPDATE config_keys SET enabled = TRUE WHERE enabled IS NULL").Error; err != nil { @@ -2114,3 +2119,112 @@ func migrationAddUseForBatchAPIColumnAndS3BucketsConfig(ctx context.Context, db } return nil } + +// migrationAddModelConfigTable adds the governance_model_configs table +func migrationAddModelConfigTable(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_model_config_table", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasTable(&tables.TableModelConfig{}) { + if err := migrator.CreateTable(&tables.TableModelConfig{}); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropTable(&tables.TableModelConfig{}); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running add model config table migration: %s", err.Error()) + } + return nil +} + +// migrationAddProviderGovernanceColumns adds budget_id and rate_limit_id columns to config_providers table +func migrationAddProviderGovernanceColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_provider_governance_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + provider := &tables.TableProvider{} + + // Add budget_id column if it doesn't exist + if !migrator.HasColumn(provider, "budget_id") { + if err := migrator.AddColumn(provider, "budget_id"); err != nil { + return fmt.Errorf("failed to add budget_id column: %w", err) + } + // Create index for budget_id + if !migrator.HasIndex(provider, "idx_provider_budget") { + if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_provider_budget ON config_providers (budget_id)").Error; err != nil { + return fmt.Errorf("failed to create budget_id index: %w", err) + } + } + } + + // Add rate_limit_id column if it doesn't exist + if !migrator.HasColumn(provider, "rate_limit_id") { + if err := migrator.AddColumn(provider, "rate_limit_id"); err != nil { + return fmt.Errorf("failed to add rate_limit_id column: %w", err) + } + // Create index for rate_limit_id + if !migrator.HasIndex(provider, "idx_provider_rate_limit") { + if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_provider_rate_limit ON config_providers (rate_limit_id)").Error; err != nil { + return fmt.Errorf("failed to create rate_limit_id index: %w", err) + } + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + provider := &tables.TableProvider{} + + // Drop indexes first + if migrator.HasIndex(provider, "idx_provider_rate_limit") { + if err := tx.Exec("DROP INDEX IF EXISTS idx_provider_rate_limit").Error; err != nil { + return fmt.Errorf("failed to drop rate_limit_id index: %w", err) + } + } + + if migrator.HasIndex(provider, "idx_provider_budget") { + if err := tx.Exec("DROP INDEX IF EXISTS idx_provider_budget").Error; err != nil { + return fmt.Errorf("failed to drop budget_id index: %w", err) + } + } + + // Drop rate_limit_id column if it exists + if migrator.HasColumn(provider, "rate_limit_id") { + if err := migrator.DropColumn(provider, "rate_limit_id"); err != nil { + return fmt.Errorf("failed to drop rate_limit_id column: %w", err) + } + } + + // Drop budget_id column if it exists + if migrator.HasColumn(provider, "budget_id") { + if err := migrator.DropColumn(provider, "budget_id"); err != nil { + return fmt.Errorf("failed to drop budget_id column: %w", err) + } + } + + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running add provider governance columns migration: %s", err.Error()) + } + return nil +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index e420cf328..0a4d8dc40 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -576,6 +576,9 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo return err } + // Store the budget and rate limit IDs before deleting + budgetID := dbProvider.BudgetID + rateLimitID := dbProvider.RateLimitID // Delete the provider first (keys will be deleted due to CASCADE constraint) if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -584,6 +587,19 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo return err } + // Delete the budget if it exists + if budgetID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit if it exists + if rateLimitID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil } @@ -700,6 +716,30 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo return processedProviders, nil } +// GetProviders retrieves all providers from the database with their governance relationships. +func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) { + var providers []tables.TableProvider + if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&providers).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return providers, nil +} + +// GetProviderByName retrieves a provider by name from the database with governance relationships. +func (s *RDBConfigStore) GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) { + var provider tables.TableProvider + if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Where("name = ?", name).First(&provider).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &provider, nil +} + // GetMCPConfig retrieves the MCP configuration from the database. func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient @@ -1959,6 +1999,160 @@ func (s *RDBConfigStore) UpdateBudget(ctx context.Context, budget *tables.TableB return nil } +// UpdateBudgetUsage updates only the current_usage field of a budget. +// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. +func (s *RDBConfigStore) UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error { + result := s.db.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&tables.TableBudget{}). + Where("id = ?", id). + Update("current_usage", currentUsage) + if result.Error != nil { + return s.parseGormError(result.Error) + } + return nil +} + +// UpdateRateLimitUsage updates only the usage fields of a rate limit. +// Uses SkipHooks to avoid triggering BeforeSave validation since we're only updating usage. +func (s *RDBConfigStore) UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error { + result := s.db.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&tables.TableRateLimit{}). + Where("id = ?", id). + Updates(map[string]interface{}{ + "token_current_usage": tokenCurrentUsage, + "request_current_usage": requestCurrentUsage, + }) + if result.Error != nil { + return s.parseGormError(result.Error) + } + return nil +} + +// GetModelConfigs retrieves all model configs from the database. +func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) { + var modelConfigs []tables.TableModelConfig + if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").Find(&modelConfigs).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return modelConfigs, nil +} + +// GetModelConfig retrieves a specific model config from the database by model name and optional provider. +func (s *RDBConfigStore) GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) { + var modelConfig tables.TableModelConfig + query := s.db.WithContext(ctx).Where("model_name = ?", modelName) + if provider != nil { + query = query.Where("provider = ?", *provider) + } else { + query = query.Where("provider IS NULL") + } + if err := query.Preload("Budget").Preload("RateLimit").First(&modelConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &modelConfig, nil +} + +// GetModelConfigByID retrieves a specific model config from the database by ID. +func (s *RDBConfigStore) GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) { + var modelConfig tables.TableModelConfig + if err := s.db.WithContext(ctx).Preload("Budget").Preload("RateLimit").First(&modelConfig, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &modelConfig, nil +} + +// CreateModelConfig creates a new model config in the database. +func (s *RDBConfigStore) CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + if err := txDB.WithContext(ctx).Create(modelConfig).Error; err != nil { + return s.parseGormError(err) + } + return nil +} + +// UpdateModelConfig updates a model config in the database. +func (s *RDBConfigStore) UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + if err := txDB.WithContext(ctx).Save(modelConfig).Error; err != nil { + return s.parseGormError(err) + } + return nil +} + +// UpdateModelConfigs updates multiple model configs in the database. +func (s *RDBConfigStore) UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + for _, mc := range modelConfigs { + if err := txDB.WithContext(ctx).Save(mc).Error; err != nil { + return s.parseGormError(err) + } + } + return nil +} + +// DeleteModelConfig deletes a model config from the database. +func (s *RDBConfigStore) DeleteModelConfig(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // First fetch the model config to get budget and rate limit IDs + var modelConfig tables.TableModelConfig + if err := tx.First(&modelConfig, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Store the budget and rate limit IDs before deleting + budgetID := modelConfig.BudgetID + rateLimitID := modelConfig.RateLimitID + // Delete the model config first + if err := tx.Delete(&tables.TableModelConfig{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return s.parseGormError(err) + } + // Delete the budget if it exists + if budgetID != nil { + if err := tx.Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit if it exists + if rateLimitID != nil { + if err := tx.Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil + }) +} + // GetGovernanceConfig retrieves the governance configuration from the database. func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error) { var virtualKeys []tables.TableVirtualKey @@ -1966,6 +2160,8 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo var customers []tables.TableCustomer var budgets []tables.TableBudget var rateLimits []tables.TableRateLimit + var modelConfigs []tables.TableModelConfig + var providers []tables.TableProvider var governanceConfigs []tables.TableGovernanceConfig if err := s.db.WithContext(ctx).Preload("ProviderConfigs").Find(&virtualKeys).Error; err != nil { @@ -1983,12 +2179,18 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil { return nil, err } + if err := s.db.WithContext(ctx).Find(&modelConfigs).Error; err != nil { + return nil, err + } + if err := s.db.WithContext(ctx).Find(&providers).Error; err != nil { + return nil, err + } // Fetching governance config for username and password if err := s.db.WithContext(ctx).Find(&governanceConfigs).Error; err != nil { return nil, err } // Check if any config is present - if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(governanceConfigs) == 0 { + if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 { return nil, nil } var authConfig *AuthConfig @@ -2016,12 +2218,14 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo } } return &GovernanceConfig{ - VirtualKeys: virtualKeys, - Teams: teams, - Customers: customers, - Budgets: budgets, - RateLimits: rateLimits, - AuthConfig: authConfig, + VirtualKeys: virtualKeys, + Teams: teams, + Customers: customers, + Budgets: budgets, + RateLimits: rateLimits, + ModelConfigs: modelConfigs, + Providers: providers, + AuthConfig: authConfig, }, nil } diff --git a/framework/configstore/store.go b/framework/configstore/store.go index ade2d3a78..fe965643c 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -33,6 +33,8 @@ type ConfigStore interface { UpdateProvider(ctx context.Context, provider schemas.ModelProvider, config ProviderConfig, envKeys map[string][]EnvKeyInfo, tx ...*gorm.DB) error DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) + GetProviders(ctx context.Context) ([]tables.TableProvider, error) + GetProviderByName(ctx context.Context, name string) (*tables.TableProvider, error) // MCP config CRUD GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) @@ -113,6 +115,17 @@ type ConfigStore interface { CreateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error UpdateBudget(ctx context.Context, budget *tables.TableBudget, tx ...*gorm.DB) error UpdateBudgets(ctx context.Context, budgets []*tables.TableBudget, tx ...*gorm.DB) error + UpdateBudgetUsage(ctx context.Context, id string, currentUsage float64) error + UpdateRateLimitUsage(ctx context.Context, id string, tokenCurrentUsage int64, requestCurrentUsage int64) error + + // Model config CRUD + GetModelConfigs(ctx context.Context) ([]tables.TableModelConfig, error) + GetModelConfig(ctx context.Context, modelName string, provider *string) (*tables.TableModelConfig, error) + GetModelConfigByID(ctx context.Context, id string) (*tables.TableModelConfig, error) + CreateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error + UpdateModelConfig(ctx context.Context, modelConfig *tables.TableModelConfig, tx ...*gorm.DB) error + UpdateModelConfigs(ctx context.Context, modelConfigs []*tables.TableModelConfig, tx ...*gorm.DB) error + DeleteModelConfig(ctx context.Context, id string) error // Governance config CRUD GetGovernanceConfig(ctx context.Context) (*GovernanceConfig, error) diff --git a/framework/configstore/tables/modelconfig.go b/framework/configstore/tables/modelconfig.go new file mode 100644 index 000000000..5e6b5ba6d --- /dev/null +++ b/framework/configstore/tables/modelconfig.go @@ -0,0 +1,59 @@ +package tables + +import ( + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +// TableModelConfig represents a model configuration with rate limiting and budgeting +type TableModelConfig struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + ModelName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_model_provider" json:"model_name"` + Provider *string `gorm:"type:varchar(50);uniqueIndex:idx_model_provider" json:"provider,omitempty"` // Optional provider, nullable + BudgetID *string `gorm:"type:varchar(255);index:idx_model_config_budget" json:"budget_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index:idx_model_config_rate_limit" json:"rate_limit_id,omitempty"` + + // Relationships + Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"` + RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"` + + // Config hash is used to detect the changes synced from config.json file + // Every time we sync the config.json file, we will update the config hash + ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` + + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableName sets the table name for each model +func (TableModelConfig) TableName() string { + return "governance_model_configs" +} + +// BeforeSave hook for ModelConfig to validate required fields +func (mc *TableModelConfig) BeforeSave(tx *gorm.DB) error { + // Validate that ModelName is not empty + if strings.TrimSpace(mc.ModelName) == "" { + return fmt.Errorf("model_name cannot be empty") + } + + // Validate that if BudgetID is provided, it's not an empty string + if mc.BudgetID != nil && strings.TrimSpace(*mc.BudgetID) == "" { + return fmt.Errorf("budget_id cannot be an empty string") + } + + // Validate that if RateLimitID is provided, it's not an empty string + if mc.RateLimitID != nil && strings.TrimSpace(*mc.RateLimitID) == "" { + return fmt.Errorf("rate_limit_id cannot be an empty string") + } + + // Validate that if Provider is provided, it's not an empty string + if mc.Provider != nil && strings.TrimSpace(*mc.Provider) == "" { + return fmt.Errorf("provider cannot be an empty string") + } + + return nil +} diff --git a/framework/configstore/tables/provider.go b/framework/configstore/tables/provider.go index b85af6cb5..1a95c26d6 100644 --- a/framework/configstore/tables/provider.go +++ b/framework/configstore/tables/provider.go @@ -38,6 +38,14 @@ type TableProvider struct { // Foreign keys Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"` + // Governance fields - Budget and Rate Limit for provider-level governance + BudgetID *string `gorm:"type:varchar(255);index:idx_provider_budget" json:"budget_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index:idx_provider_rate_limit" json:"rate_limit_id,omitempty"` + + // Governance relationships + Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"` + RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"` + // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` @@ -79,6 +87,15 @@ func (p *TableProvider) BeforeSave(tx *gorm.DB) error { } p.CustomProviderConfigJSON = string(data) } + + // Validate governance fields + if p.BudgetID != nil && *p.BudgetID == "" { + return fmt.Errorf("budget_id cannot be an empty string") + } + if p.RateLimitID != nil && *p.RateLimitID == "" { + return fmt.Errorf("rate_limit_id cannot be an empty string") + } + return nil } diff --git a/plugins/governance/main.go b/plugins/governance/main.go index e53631cf3..0a11895f8 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -41,7 +41,7 @@ type InMemoryStore interface { type BaseGovernancePlugin interface { GetName() string - TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) PostHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) Cleanup() error @@ -474,6 +474,9 @@ func (p *GovernancePlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif // Extract governance headers and virtual key using utility functions virtualKeyValue := getStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) requestID := getStringFromContext(ctx, schemas.BifrostContextKeyRequestID) + provider, model, _ := req.GetRequestFields() + + // Check if virtual key is mandatory when none is provided if virtualKeyValue == "" { if p.isVkMandatory != nil && *p.isVkMandatory { return req, &schemas.PluginShortCircuit{ @@ -485,23 +488,18 @@ func (p *GovernancePlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif }, }, }, nil - } else { - return req, nil, nil } } - provider, model, _ := req.GetRequestFields() + // First evaluate model and provider checks (applies even when virtual keys are disabled or not present) + result := p.resolver.EvaluateModelAndProviderRequest(ctx, provider, model, requestID) - // Create request context for evaluation - evaluationRequest := &EvaluationRequest{ - VirtualKey: virtualKeyValue, - Provider: provider, - Model: model, - RequestID: requestID, + // If model/provider checks passed and virtual key exists, evaluate virtual key checks + // This will overwrite the result with virtual key-specific decision + if result.Decision == DecisionAllow && virtualKeyValue != "" { + result = p.resolver.EvaluateVirtualKeyRequest(ctx, virtualKeyValue, provider, model, requestID) } - - // Use resolver to make governance decision (pure decision engine) - result := p.resolver.EvaluateRequest(ctx, evaluationRequest) + // If model/provider checks failed, skip virtual key evaluation and proceed to final decision handling if result.Decision != DecisionAllow { if ctx != nil { @@ -581,11 +579,6 @@ func (p *GovernancePlugin) PostHook(ctx *schemas.BifrostContext, result *schemas virtualKey := getStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) requestID := getStringFromContext(ctx, schemas.BifrostContextKeyRequestID) - // Skip if no virtual key - if virtualKey == "" { - return result, err, nil - } - // Extract request type, provider, and model requestType, provider, model := bifrost.GetResponseFields(result, err) @@ -603,11 +596,17 @@ func (p *GovernancePlugin) PostHook(ctx *schemas.BifrostContext, result *schemas } } - p.wg.Add(1) - go func() { - defer p.wg.Done() - p.postHookWorker(result, provider, model, requestType, virtualKey, requestID, isCacheRead, isBatch, bifrost.IsFinalChunk(ctx)) - }() + // Always process usage tracking (with or without virtual key) + // If virtualKey is empty, it will be passed as empty string to postHookWorker + // The tracker will handle empty virtual keys gracefully by only updating provider-level and model-level usage + if model != "" { + p.wg.Add(1) + go func() { + defer p.wg.Done() + // Pass virtualKey (empty string if not present) - tracker handles this case + p.postHookWorker(result, provider, model, requestType, virtualKey, requestID, isCacheRead, isBatch, bifrost.IsFinalChunk(ctx)) + }() + } return result, err, nil } @@ -627,12 +626,14 @@ func (p *GovernancePlugin) Cleanup() error { // postHookWorker is a worker function that processes the response and updates usage tracking // It is used to avoid blocking the main thread when updating usage tracking +// Handles both cases: with virtual key and without virtual key (empty string) +// When virtualKey is empty, the tracker will only update provider-level and model-level usage // Parameters: // - result: The Bifrost response to be processed // - provider: The provider of the request // - model: The model of the request // - requestType: The type of the request -// - virtualKey: The virtual key of the request +// - virtualKey: The virtual key of the request (empty string if not present) // - requestID: The request ID // - isCacheRead: Whether the request is a cache read // - isBatch: Whether the request is a batch request @@ -687,6 +688,7 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi } // Queue usage update asynchronously using tracker + // UpdateUsage handles empty virtual keys gracefully by only updating provider-level and model-level usage p.tracker.UpdateUsage(p.ctx, usageUpdate) } } diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index e37f92a97..edde360a0 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -74,10 +74,60 @@ func NewBudgetResolver(store GovernanceStore, logger schemas.Logger) *BudgetReso } } -// EvaluateRequest evaluates a request against the new hierarchical governance system -func (r *BudgetResolver) EvaluateRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest) *EvaluationResult { +// EvaluateModelAndProviderRequest evaluates provider-level and model-level rate limits and budgets +// This applies even when virtual keys are disabled or not present +func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, requestID string) *EvaluationResult { + // 1. Check provider-level rate limits FIRST (before model-level checks) + if provider != "" { + if err, decision := r.store.CheckProviderRateLimit(ctx, provider, requestID, nil, nil); err != nil { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Provider-level rate limit check failed: %s", err.Error()), + } + } + + // 2. Check provider-level budgets FIRST (before model-level checks) + if err := r.store.CheckProviderBudget(ctx, provider, nil); err != nil { + return &EvaluationResult{ + Decision: DecisionBudgetExceeded, + Reason: fmt.Sprintf("Provider-level budget exceeded: %s", err.Error()), + } + } + } + + // 3. Check model-level rate limits (after provider-level checks) + if model != "" { + var providerPtr *schemas.ModelProvider + if provider != "" { + providerPtr = &provider + } + if err, decision := r.store.CheckModelRateLimit(ctx, model, providerPtr, requestID, nil, nil); err != nil { + return &EvaluationResult{ + Decision: decision, + Reason: fmt.Sprintf("Model-level rate limit check failed: %s", err.Error()), + } + } + + // 4. Check model-level budgets (after provider-level checks) + if err := r.store.CheckModelBudget(ctx, model, providerPtr, nil); err != nil { + return &EvaluationResult{ + Decision: DecisionBudgetExceeded, + Reason: fmt.Sprintf("Model-level budget exceeded: %s", err.Error()), + } + } + } + + // All provider-level and model-level checks passed + return &EvaluationResult{ + Decision: DecisionAllow, + Reason: "Request allowed by governance policy (provider-level and model-level checks passed)", + } +} + +// EvaluateVirtualKeyRequest evaluates virtual key-specific checks including validation, filtering, rate limits, and budgets +func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestID string) *EvaluationResult { // 1. Validate virtual key exists and is active - vk, exists := r.store.GetVirtualKey(evaluationRequest.VirtualKey) + vk, exists := r.store.GetVirtualKey(virtualKeyValue) if !exists { return &EvaluationResult{ Decision: DecisionVirtualKeyNotFound, @@ -109,25 +159,32 @@ func (r *BudgetResolver) EvaluateRequest(ctx *schemas.BifrostContext, evaluation } // 2. Check provider filtering - if !r.isProviderAllowed(vk, evaluationRequest.Provider) { + if !r.isProviderAllowed(vk, provider) { return &EvaluationResult{ Decision: DecisionProviderBlocked, - Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", evaluationRequest.Provider), + Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", provider), VirtualKey: vk, } } // 3. Check model filtering - if !r.isModelAllowed(vk, evaluationRequest.Provider, evaluationRequest.Model) { + if !r.isModelAllowed(vk, provider, model) { return &EvaluationResult{ Decision: DecisionModelBlocked, - Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", evaluationRequest.Model), + Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", model), VirtualKey: vk, } } - // 4. Check rate limits hierarchy (Provider level first, then VK level) - if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, string(evaluationRequest.Provider), evaluationRequest.Model, evaluationRequest.RequestID); rateLimitResult != nil { + evaluationRequest := &EvaluationRequest{ + VirtualKey: virtualKeyValue, + Provider: provider, + Model: model, + RequestID: requestID, + } + + // 4. Check rate limits hierarchy (VK level) + if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, evaluationRequest); rateLimitResult != nil { return rateLimitResult } @@ -138,7 +195,7 @@ func (r *BudgetResolver) EvaluateRequest(ctx *schemas.BifrostContext, evaluation // Find the provider config that matches the request's provider and get its allowed keys for _, pc := range vk.ProviderConfigs { - if schemas.ModelProvider(pc.Provider) == evaluationRequest.Provider && len(pc.Keys) > 0 { + if schemas.ModelProvider(pc.Provider) == provider && len(pc.Keys) > 0 { includeOnlyKeys := make([]string, 0, len(pc.Keys)) for _, dbKey := range pc.Keys { includeOnlyKeys = append(includeOnlyKeys, dbKey.KeyID) @@ -192,12 +249,12 @@ func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey } // checkRateLimitHierarchy checks provider-level rate limits first, then VK rate limits using flexible approach -func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider string, model string, requestID string) *EvaluationResult { - if decision, err := r.store.CheckRateLimit(ctx, vk, schemas.ModelProvider(provider), model, requestID, nil, nil); err != nil { +func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest) *EvaluationResult { + if decision, err := r.store.CheckRateLimit(ctx, vk, schemas.ModelProvider(request.Provider), request.Model, request.RequestID, nil, nil); err != nil { // Check provider-level first (matching check order), then VK-level var rateLimitInfo *configstoreTables.TableRateLimit for _, pc := range vk.ProviderConfigs { - if pc.Provider == provider && pc.RateLimit != nil { + if pc.Provider == string(request.Provider) && pc.RateLimit != nil { rateLimitInfo = pc.RateLimit break } diff --git a/plugins/governance/store.go b/plugins/governance/store.go index c3ef36fe3..da553347d 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -17,11 +17,13 @@ import ( // LocalGovernanceStore provides in-memory cache for governance data with fast, non-blocking access type LocalGovernanceStore struct { // Core data maps using sync.Map for lock-free reads - virtualKeys sync.Map // string -> *VirtualKey (VK value -> VirtualKey with preloaded relationships) - teams sync.Map // string -> *Team (Team ID -> Team) - customers sync.Map // string -> *Customer (Customer ID -> Customer) - budgets sync.Map // string -> *Budget (Budget ID -> Budget) - rateLimits sync.Map // string -> *RateLimit (RateLimit ID -> RateLimit) + virtualKeys sync.Map // string -> *VirtualKey (VK value -> VirtualKey with preloaded relationships) + teams sync.Map // string -> *Team (Team ID -> Team) + customers sync.Map // string -> *Customer (Customer ID -> Customer) + budgets sync.Map // string -> *Budget (Budget ID -> Budget) + rateLimits sync.Map // string -> *RateLimit (RateLimit ID -> RateLimit) + modelConfigs sync.Map // string -> *ModelConfig (key: "modelName" or "modelName:provider" -> ModelConfig) + providers sync.Map // string -> *Provider (Provider name -> Provider with preloaded relationships) // Config store for refresh operations configStore configstore.ConfigStore @@ -50,16 +52,31 @@ type GovernanceData struct { type GovernanceStore interface { GetGovernanceData() *GovernanceData GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) + // Provider-level governance checks + CheckProviderBudget(ctx context.Context, provider schemas.ModelProvider, baselines map[string]float64) error + CheckProviderRateLimit(ctx context.Context, provider schemas.ModelProvider, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) + // Model-level governance checks + CheckModelBudget(ctx context.Context, model string, provider *schemas.ModelProvider, baselines map[string]float64) error + CheckModelRateLimit(ctx context.Context, model string, provider *schemas.ModelProvider, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) + // VK-level governance checks CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) - UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error - UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error + // In-memory usage updates (for VK-level) + UpdateVirtualKeyBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error + UpdateVirtualKeyRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error + // In-memory reset checks (return items that need DB sync) ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget + // DB sync for expired items ResetExpiredRateLimits(ctx context.Context, resetRateLimits []*configstoreTables.TableRateLimit) error ResetExpiredBudgets(ctx context.Context, resetBudgets []*configstoreTables.TableBudget) error + // Provider and model-level usage updates (combined) + UpdateProviderAndModelBudgetUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, cost float64) error + UpdateProviderAndModelRateLimitUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error + // Dump operations DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error DumpBudgets(ctx context.Context, baselines map[string]float64) error + // In-memory CRUD operations CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) DeleteVirtualKeyInMemory(vkID string) @@ -69,6 +86,12 @@ type GovernanceStore interface { CreateCustomerInMemory(customer *configstoreTables.TableCustomer) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) DeleteCustomerInMemory(customerID string) + // Model config in-memory operations + UpdateModelConfigInMemory(mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig + DeleteModelConfigInMemory(mcID string) + // Provider in-memory operations + UpdateProviderInMemory(provider *configstoreTables.TableProvider) *configstoreTables.TableProvider + DeleteProviderInMemory(providerName string) } // NewLocalGovernanceStore creates a new in-memory governance store @@ -216,6 +239,370 @@ func (gs *LocalGovernanceStore) CheckBudget(ctx context.Context, vk *configstore return nil } +// CheckProviderBudget performs budget checking for provider-level configs (lock-free for high performance) +func (gs *LocalGovernanceStore) CheckProviderBudget(ctx context.Context, provider schemas.ModelProvider, baselines map[string]float64) error { + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + + // Get provider config + providerKey := string(provider) + value, exists := gs.providers.Load(providerKey) + if !exists || value == nil { + // No provider config found, allow request + return nil + } + + providerTable, ok := value.(*configstoreTables.TableProvider) + if !ok || providerTable == nil || providerTable.BudgetID == nil { + // No budget configured for provider, allow request + return nil + } + + // Read from budgets map to get the latest updated budget (same source as UpdateProviderBudgetUsage) + budgetValue, exists := gs.budgets.Load(*providerTable.BudgetID) + if !exists || budgetValue == nil { + // Budget not found in cache, allow request + return nil + } + + budget, ok := budgetValue.(*configstoreTables.TableBudget) + if !ok || budget == nil { + // Invalid budget type, allow request + return nil + } + + // Check if budget needs reset (in-memory check) + if budget.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { + if time.Since(budget.LastReset) >= duration { + // Budget expired but hasn't been reset yet - treat as reset + return nil // Skip budget check for expired budgets + } + } + } + + baseline, exists := baselines[budget.ID] + if !exists { + baseline = 0 + } + + // Check if current usage (local + remote baseline) exceeds budget limit + if budget.CurrentUsage+baseline >= budget.MaxLimit { + return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", + providerKey, budget.CurrentUsage+baseline, budget.MaxLimit) + } + + return nil +} + +// CheckProviderRateLimit checks provider-level rate limits and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckProviderRateLimit(ctx context.Context, provider schemas.ModelProvider, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) { + var violations []string + + // This is to prevent nil pointer dereference + if tokensBaselines == nil { + tokensBaselines = map[string]int64{} + } + if requestsBaselines == nil { + requestsBaselines = map[string]int64{} + } + + // Get provider config + providerKey := string(provider) + value, exists := gs.providers.Load(providerKey) + if !exists || value == nil { + // No provider config found, allow request + return nil, DecisionAllow + } + + providerTable, ok := value.(*configstoreTables.TableProvider) + if !ok || providerTable == nil || providerTable.RateLimitID == nil { + // No rate limit configured for provider, allow request + return nil, DecisionAllow + } + + // Read from rateLimits map to get the latest updated rate limit (same source as UpdateProviderRateLimitUsage) + rateLimitValue, exists := gs.rateLimits.Load(*providerTable.RateLimitID) + if !exists || rateLimitValue == nil { + // Rate limit not found in cache, allow request + return nil, DecisionAllow + } + + rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + // Invalid rate limit type, allow request + return nil, DecisionAllow + } + + // Check if rate limit needs reset (in-memory check) + // Track which limits are expired so we can skip only those specific checks + tokenLimitExpired := false + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if time.Since(rateLimit.TokenLastReset) >= duration { + // Token rate limit expired but hasn't been reset yet - skip token check only + tokenLimitExpired = true + } + } + } + requestLimitExpired := false + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if time.Since(rateLimit.RequestLastReset) >= duration { + // Request rate limit expired but hasn't been reset yet - skip request check only + requestLimitExpired = true + } + } + } + + tokensBaseline, exists := tokensBaselines[rateLimit.ID] + if !exists { + tokensBaseline = 0 + } + requestsBaseline, exists := requestsBaselines[rateLimit.ID] + if !exists { + requestsBaseline = 0 + } + + // Token limits - check if total usage (local + remote baseline) exceeds limit + // Skip this check if token limit has expired + if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration + } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) + } + + // Request limits - check if total usage (local + remote baseline) exceeds limit + // Skip this check if request limit has expired + if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration + } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) + } + + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited // Default to general rate limited decision + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited // More specific violation type + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited // More specific violation type + } + } + return fmt.Errorf("rate limit violated for %s: %s", providerKey, violations), decision + } + + return nil, DecisionAllow // No rate limit violations +} + +// CheckModelBudget performs budget checking for model-level configs (lock-free for high performance) +func (gs *LocalGovernanceStore) CheckModelBudget(ctx context.Context, model string, provider *schemas.ModelProvider, baselines map[string]float64) error { + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + + // Collect model configs to check: model+provider (if exists) AND model-only (if exists) + var modelConfigsToCheck []*configstoreTables.TableModelConfig + var budgetNames []string + + // Check model+provider config first (more specific) - if provider is provided + if provider != nil { + key := fmt.Sprintf("%s:%s", model, string(*provider)) + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.Budget != nil { + modelConfigsToCheck = append(modelConfigsToCheck, mc) + budgetNames = append(budgetNames, fmt.Sprintf("Model:%s:Provider:%s", model, string(*provider))) + } + } + } + + // Always check model-only config (if exists) - regardless of whether model+provider config exists + key := model + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.Budget != nil { + modelConfigsToCheck = append(modelConfigsToCheck, mc) + budgetNames = append(budgetNames, fmt.Sprintf("Model:%s", model)) + } + } + + // Check each model budget + for i, mc := range modelConfigsToCheck { + if mc.BudgetID == nil { + continue + } + + // Read from budgets map to get the latest updated budget (same source as UpdateModelBudgetUsage) + budgetValue, exists := gs.budgets.Load(*mc.BudgetID) + if !exists || budgetValue == nil { + // Budget not found in cache, skip check + continue + } + + budget, ok := budgetValue.(*configstoreTables.TableBudget) + if !ok || budget == nil { + // Invalid budget type, skip check + continue + } + + // Check if budget needs reset (in-memory check) + if budget.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { + if time.Since(budget.LastReset) >= duration { + // Budget expired but hasn't been reset yet - treat as reset + continue // Skip budget check for expired budgets + } + } + } + + baseline, exists := baselines[budget.ID] + if !exists { + baseline = 0 + } + + // Check if current usage (local + remote baseline) exceeds budget limit + if budget.CurrentUsage+baseline >= budget.MaxLimit { + return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", + budgetNames[i], budget.CurrentUsage+baseline, budget.MaxLimit) + } + } + + return nil +} + +// CheckModelRateLimit checks model-level rate limits and returns evaluation result if violated +func (gs *LocalGovernanceStore) CheckModelRateLimit(ctx context.Context, model string, provider *schemas.ModelProvider, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (error, Decision) { + var violations []string + + // This is to prevent nil pointer dereference + if tokensBaselines == nil { + tokensBaselines = map[string]int64{} + } + if requestsBaselines == nil { + requestsBaselines = map[string]int64{} + } + + // Collect model configs to check: model+provider (if exists) AND model-only (if exists) + var modelConfigsToCheck []*configstoreTables.TableModelConfig + var rateLimitNames []string + + // Check model+provider config first (more specific) - if provider is provided + if provider != nil { + key := fmt.Sprintf("%s:%s", model, string(*provider)) + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { + modelConfigsToCheck = append(modelConfigsToCheck, mc) + rateLimitNames = append(rateLimitNames, fmt.Sprintf("Model:%s:Provider:%s", model, string(*provider))) + } + } + } + + // Always check model-only config (if exists) - regardless of whether model+provider config exists + key := model + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { + modelConfigsToCheck = append(modelConfigsToCheck, mc) + rateLimitNames = append(rateLimitNames, fmt.Sprintf("Model:%s", model)) + } + } + + // Check each model rate limit + for i, mc := range modelConfigsToCheck { + if mc.RateLimitID == nil { + continue + } + + // Read from rateLimits map to get the latest updated rate limit (same source as UpdateModelRateLimitUsage) + rateLimitValue, exists := gs.rateLimits.Load(*mc.RateLimitID) + if !exists || rateLimitValue == nil { + // Rate limit not found in cache, skip check + continue + } + + rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + // Invalid rate limit type, skip check + continue + } + + // Check if rate limit needs reset (in-memory check) + // Track which limits are expired so we can skip only those specific checks + tokenLimitExpired := false + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if time.Since(rateLimit.TokenLastReset) >= duration { + // Token rate limit expired but hasn't been reset yet - skip token check only + tokenLimitExpired = true + } + } + } + requestLimitExpired := false + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if time.Since(rateLimit.RequestLastReset) >= duration { + // Request rate limit expired but hasn't been reset yet - skip request check only + requestLimitExpired = true + } + } + } + + tokensBaseline, exists := tokensBaselines[rateLimit.ID] + if !exists { + tokensBaseline = 0 + } + requestsBaseline, exists := requestsBaselines[rateLimit.ID] + if !exists { + requestsBaseline = 0 + } + + // Token limits - check if total usage (local + remote baseline) exceeds limit + // Skip this check if token limit has expired + if !tokenLimitExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration + } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) + } + + // Request limits - check if total usage (local + remote baseline) exceeds limit + // Skip this check if request limit has expired + if !requestLimitExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration + } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) + } + + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited // Default to general rate limited decision + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited // More specific violation type + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited // More specific violation type + } + } + return fmt.Errorf("rate limit violated for %s: %s", rateLimitNames[i], violations), decision + } + } + + return nil, DecisionAllow // No rate limit violations +} + // CheckRateLimit checks a single rate limit and returns evaluation result if violated (true if violated, false if not) func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { var violations []string @@ -307,8 +694,8 @@ func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configst return DecisionAllow, nil // No rate limit violations } -// UpdateBudgetUsageInMemory performs atomic budget updates across the hierarchy (both in memory and in database) -func (gs *LocalGovernanceStore) UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { +// UpdateVirtualKeyBudgetUsageInMemory performs atomic budget updates across the hierarchy (both in memory and in database) +func (gs *LocalGovernanceStore) UpdateVirtualKeyBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { if vk == nil { return fmt.Errorf("virtual key cannot be nil") } @@ -330,7 +717,7 @@ func (gs *LocalGovernanceStore) UpdateBudgetUsageInMemory(ctx context.Context, v if now.Sub(clone.LastReset) >= duration { clone.CurrentUsage = 0 clone.LastReset = now - gs.logger.Debug("UpdateBudgetUsage: Budget %s was reset (expired, duration: %v)", budgetID, duration) + gs.logger.Debug("UpdateVirtualKeyBudgetUsageInMemory: Budget %s was reset (expired, duration: %v)", budgetID, duration) } } } @@ -338,18 +725,145 @@ func (gs *LocalGovernanceStore) UpdateBudgetUsageInMemory(ctx context.Context, v // Update the clone clone.CurrentUsage += cost gs.budgets.Store(budgetID, &clone) - gs.logger.Debug("UpdateBudgetUsage: Updated budget %s: %.4f -> %.4f (added %.4f)", + gs.logger.Debug("UpdateVirtualKeyBudgetUsageInMemory: Updated budget %s: %.4f -> %.4f (added %.4f)", budgetID, oldUsage, clone.CurrentUsage, cost) } } else { - gs.logger.Warn("UpdateBudgetUsage: Budget %s not found in local store", budgetID) + gs.logger.Warn("UpdateVirtualKeyBudgetUsageInMemory: Budget %s not found in local store", budgetID) } } return nil } -// UpdateRateLimitUsageInMemory updates rate limit counters for both provider-level and VK-level rate limits (lock-free) -func (gs *LocalGovernanceStore) UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { +// UpdateProviderAndModelBudgetUsageInMemory performs atomic budget updates for both provider-level and model-level configs (in memory) +func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, cost float64) error { + now := time.Now() + + // Helper function to update a budget by ID + updateBudget := func(budgetID string) { + if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { + if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { + // Clone FIRST to avoid race conditions + clone := *cachedBudget + // Check if budget needs reset (in-memory check) - operate on clone + if clone.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { + if now.Sub(clone.LastReset) >= duration { + clone.CurrentUsage = 0 + clone.LastReset = now + } + } + } + // Update the clone + clone.CurrentUsage += cost + gs.budgets.Store(budgetID, &clone) + } + } + } + + // 1. Update provider-level budget (if provider is set) + if provider != "" { + providerKey := string(provider) + if value, exists := gs.providers.Load(providerKey); exists && value != nil { + if providerTable, ok := value.(*configstoreTables.TableProvider); ok && providerTable != nil && providerTable.BudgetID != nil { + updateBudget(*providerTable.BudgetID) + } + } + } + + // 2. Update model-level budgets + // Check model+provider config first (more specific) - if provider is provided + if provider != "" { + key := fmt.Sprintf("%s:%s", model, string(provider)) + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.BudgetID != nil { + updateBudget(*mc.BudgetID) + } + } + } + + // Always check model-only config (if exists) - regardless of whether model+provider config exists + if value, exists := gs.modelConfigs.Load(model); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.BudgetID != nil { + updateBudget(*mc.BudgetID) + } + } + + return nil +} + +// UpdateProviderAndModelRateLimitUsageInMemory updates rate limit counters for both provider-level and model-level rate limits (lock-free) +func (gs *LocalGovernanceStore) UpdateProviderAndModelRateLimitUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { + now := time.Now() + + // Helper function to update a rate limit by ID + updateRateLimit := func(rateLimitID string) { + if cachedRateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { + if cachedRateLimit, ok := cachedRateLimitValue.(*configstoreTables.TableRateLimit); ok && cachedRateLimit != nil { + // Clone FIRST to avoid race conditions + clone := *cachedRateLimit + // Check if rate limit needs reset (in-memory check) - operate on clone + if clone.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { + if now.Sub(clone.TokenLastReset) >= duration { + clone.TokenCurrentUsage = 0 + clone.TokenLastReset = now + } + } + } + if clone.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { + if now.Sub(clone.RequestLastReset) >= duration { + clone.RequestCurrentUsage = 0 + clone.RequestLastReset = now + } + } + } + // Update the clone + if shouldUpdateTokens { + clone.TokenCurrentUsage += tokensUsed + } + if shouldUpdateRequests { + clone.RequestCurrentUsage += 1 + } + gs.rateLimits.Store(rateLimitID, &clone) + } + } + } + + // 1. Update provider-level rate limit (if provider is set) + if provider != "" { + providerKey := string(provider) + if value, exists := gs.providers.Load(providerKey); exists && value != nil { + if providerTable, ok := value.(*configstoreTables.TableProvider); ok && providerTable != nil && providerTable.RateLimitID != nil { + updateRateLimit(*providerTable.RateLimitID) + } + } + } + + // 2. Update model-level rate limits + // Check model+provider config first (more specific) - if provider is provided + if provider != "" { + key := fmt.Sprintf("%s:%s", model, string(provider)) + if value, exists := gs.modelConfigs.Load(key); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { + updateRateLimit(*mc.RateLimitID) + } + } + } + + // Always check model-only config (if exists) - regardless of whether model+provider config exists + if value, exists := gs.modelConfigs.Load(model); exists && value != nil { + if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { + updateRateLimit(*mc.RateLimitID) + } + } + + return nil +} + +// UpdateVirtualKeyRateLimitUsageInMemory updates rate limit counters for VK-level rate limits (lock-free) +func (gs *LocalGovernanceStore) UpdateVirtualKeyRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { if vk == nil { return fmt.Errorf("virtual key cannot be nil") } @@ -591,7 +1105,7 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin requestBaselines = map[string]int64{} } - // Collect unique rate limit IDs from virtual keys + // Collect unique rate limit IDs from virtual keys, model configs, and providers rateLimitIDs := make(map[string]bool) gs.virtualKeys.Range(func(key, value interface{}) bool { vk, ok := value.(*configstoreTables.TableVirtualKey) @@ -611,6 +1125,30 @@ func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselin return true // continue }) + // Collect rate limit IDs from model configs + gs.modelConfigs.Range(func(key, value interface{}) bool { + mc, ok := value.(*configstoreTables.TableModelConfig) + if !ok || mc == nil { + return true // continue + } + if mc.RateLimitID != nil { + rateLimitIDs[*mc.RateLimitID] = true + } + return true // continue + }) + + // Collect rate limit IDs from providers + gs.providers.Range(func(key, value interface{}) bool { + provider, ok := value.(*configstoreTables.TableProvider) + if !ok || provider == nil { + return true // continue + } + if provider.RateLimitID != nil { + rateLimitIDs[*provider.RateLimitID] = true + } + return true // continue + }) + // Prepare rate limit usage updates with baselines type rateLimitUpdate struct { ID string @@ -778,8 +1316,20 @@ func (gs *LocalGovernanceStore) loadFromDatabase(ctx context.Context) error { return fmt.Errorf("failed to load rate limits: %w", err) } + // Load model configs + modelConfigs, err := gs.configStore.GetModelConfigs(ctx) + if err != nil { + return fmt.Errorf("failed to load model configs: %w", err) + } + + // Load providers with governance relationships (similar to GetModelConfigs) + providers, err := gs.configStore.GetProviders(ctx) + if err != nil { + return fmt.Errorf("failed to load providers: %w", err) + } + // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits, modelConfigs, providers) return nil } @@ -805,6 +1355,66 @@ func (gs *LocalGovernanceStore) loadFromConfigMemory(ctx context.Context, config // Load rate limits rateLimits := config.RateLimits + // Load model configs + modelConfigs := config.ModelConfigs + + // Load providers + providers := config.Providers + + // Populate model configs with their relationships (Budget and RateLimit) + for i := range modelConfigs { + mc := &modelConfigs[i] + + // Populate budget + if mc.BudgetID != nil { + for j := range budgets { + if budgets[j].ID == *mc.BudgetID { + mc.Budget = &budgets[j] + break + } + } + } + + // Populate rate limit + if mc.RateLimitID != nil { + for j := range rateLimits { + if rateLimits[j].ID == *mc.RateLimitID { + mc.RateLimit = &rateLimits[j] + break + } + } + } + + modelConfigs[i] = *mc + } + + // Populate providers with their relationships (Budget and RateLimit) + for i := range providers { + provider := &providers[i] + + // Populate budget + if provider.BudgetID != nil { + for j := range budgets { + if budgets[j].ID == *provider.BudgetID { + provider.Budget = &budgets[j] + break + } + } + } + + // Populate rate limit + if provider.RateLimitID != nil { + for j := range rateLimits { + if rateLimits[j].ID == *provider.RateLimitID { + provider.RateLimit = &rateLimits[j] + break + } + } + } + + providers[i] = *provider + } + // Populate virtual keys with their relationships for i := range virtualKeys { vk := &virtualKeys[i] @@ -864,19 +1474,21 @@ func (gs *LocalGovernanceStore) loadFromConfigMemory(ctx context.Context, config } // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits, modelConfigs, providers) return nil } // rebuildInMemoryStructures rebuilds all in-memory data structures (lock-free) -func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget, rateLimits []configstoreTables.TableRateLimit) { +func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget, rateLimits []configstoreTables.TableRateLimit, modelConfigs []configstoreTables.TableModelConfig, providers []configstoreTables.TableProvider) { // Clear existing data by creating new sync.Maps gs.virtualKeys = sync.Map{} gs.teams = sync.Map{} gs.customers = sync.Map{} gs.budgets = sync.Map{} gs.rateLimits = sync.Map{} + gs.modelConfigs = sync.Map{} + gs.providers = sync.Map{} // Build customers map for i := range customers { @@ -907,6 +1519,28 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c vk := &virtualKeys[i] gs.virtualKeys.Store(vk.Value, vk) } + + // Build model configs map + // Key format: "modelName" for global configs, "modelName:provider" for provider-specific configs + for i := range modelConfigs { + mc := &modelConfigs[i] + if mc.Provider != nil { + // Store under provider-specific key + key := fmt.Sprintf("%s:%s", mc.ModelName, *mc.Provider) + gs.modelConfigs.Store(key, mc) + } else { + // Global config (applies to all providers) - store under model name only + key := mc.ModelName + gs.modelConfigs.Store(key, mc) + } + } + + // Build providers map + // Key format: provider name (e.g., "openai", "anthropic") + for i := range providers { + provider := &providers[i] + gs.providers.Store(provider.Name, provider) + } } // UTILITY FUNCTIONS @@ -1473,6 +2107,158 @@ func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { gs.customers.Delete(customerID) } +// UpdateModelConfigInMemory adds or updates a model config in the in-memory store (lock-free) +// Preserves existing usage values when updating budgets and rate limits +// Returns the updated model config with potentially modified usage values +func (gs *LocalGovernanceStore) UpdateModelConfigInMemory(mc *configstoreTables.TableModelConfig) *configstoreTables.TableModelConfig { + if mc == nil { + return nil // Nothing to update + } + + // Clone to avoid modifying the original + clone := *mc + + // Store associated budget if exists, preserving existing usage + if clone.Budget != nil { + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, 0) + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } + + // Store associated rate limit if exists, preserving existing usage + if clone.RateLimit != nil { + var existingRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(clone.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingRateLimit = erl + } + } + clone.RateLimit = checkAndUpdateRateLimit(clone.RateLimit, existingRateLimit, 0, 0) + if clone.RateLimit != nil { + gs.rateLimits.Store(clone.RateLimit.ID, clone.RateLimit) + } + } + + // Determine the key based on whether provider is specified + // Key format: "modelName" for global configs, "modelName:provider" for provider-specific configs + if clone.Provider != nil { + key := fmt.Sprintf("%s:%s", clone.ModelName, *clone.Provider) + gs.modelConfigs.Store(key, &clone) + } else { + key := clone.ModelName + gs.modelConfigs.Store(key, &clone) + } + + return &clone +} + +// DeleteModelConfigInMemory removes a model config from the in-memory store (lock-free) +func (gs *LocalGovernanceStore) DeleteModelConfigInMemory(mcID string) { + if mcID == "" { + return // Nothing to delete + } + + // Find and delete the model config by ID + gs.modelConfigs.Range(func(key, value interface{}) bool { + mc, ok := value.(*configstoreTables.TableModelConfig) + if !ok || mc == nil { + return true // continue iteration + } + + if mc.ID == mcID { + // Delete associated budget if exists + if mc.BudgetID != nil { + gs.budgets.Delete(*mc.BudgetID) + } + + // Delete associated rate limit if exists + if mc.RateLimitID != nil { + gs.rateLimits.Delete(*mc.RateLimitID) + } + + gs.modelConfigs.Delete(key) + return false // stop iteration + } + return true // continue iteration + }) +} + +// UpdateProviderInMemory adds or updates a provider in the in-memory store (lock-free) +// Preserves existing usage values when updating budgets and rate limits +// Returns the updated provider with potentially modified usage values +func (gs *LocalGovernanceStore) UpdateProviderInMemory(provider *configstoreTables.TableProvider) *configstoreTables.TableProvider { + if provider == nil { + return nil // Nothing to update + } + + // Clone to avoid modifying the original + clone := *provider + + // Store associated budget if exists, preserving existing usage + if clone.Budget != nil { + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, 0) + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } + + // Store associated rate limit if exists, preserving existing usage + if clone.RateLimit != nil { + var existingRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(clone.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingRateLimit = erl + } + } + clone.RateLimit = checkAndUpdateRateLimit(clone.RateLimit, existingRateLimit, 0, 0) + if clone.RateLimit != nil { + gs.rateLimits.Store(clone.RateLimit.ID, clone.RateLimit) + } + } + + // Store under provider name + gs.providers.Store(clone.Name, &clone) + + return &clone +} + +// DeleteProviderInMemory removes a provider from the in-memory store (lock-free) +func (gs *LocalGovernanceStore) DeleteProviderInMemory(providerName string) { + if providerName == "" { + return // Nothing to delete + } + + // Get provider to check for associated budget/rate limit + if providerValue, exists := gs.providers.Load(providerName); exists && providerValue != nil { + if provider, ok := providerValue.(*configstoreTables.TableProvider); ok && provider != nil { + // Delete associated budget if exists + if provider.BudgetID != nil { + gs.budgets.Delete(*provider.BudgetID) + } + + // Delete associated rate limit if exists + if provider.RateLimitID != nil { + gs.rateLimits.Delete(*provider.RateLimitID) + } + } + } + + gs.providers.Delete(providerName) +} + // Helper functions // updateBudgetReferences updates all VKs, teams, customers, and provider configs that reference a reset budget diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go index 1a10622a5..25e3d3c36 100644 --- a/plugins/governance/tracker.go +++ b/plugins/governance/tracker.go @@ -67,15 +67,9 @@ func NewUsageTracker(ctx context.Context, store GovernanceStore, resolver *Budge // UpdateUsage queues a usage update for async processing (main business entry point) func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { - // Get virtual key - vk, exists := t.store.GetVirtualKey(update.VirtualKey) - if !exists { - return - } - // Only process successful requests for usage tracking if !update.Success { - t.logger.Debug(fmt.Sprintf("Request was not successful, skipping usage update for VK: %s", vk.ID)) + t.logger.Debug("Request was not successful, skipping usage update") return } @@ -84,9 +78,36 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { shouldUpdateRequests := !update.IsStreaming || (update.IsStreaming && update.IsFinalChunk) shouldUpdateBudget := !update.IsStreaming || (update.IsStreaming && update.HasUsageData) + // 1. Update rate limit usage for both provider-level and model-level + // This applies even when virtual keys are disabled or not present + if err := t.store.UpdateProviderAndModelRateLimitUsageInMemory(ctx, update.Model, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + t.logger.Error("failed to update rate limit usage for model %s, provider %s: %v", update.Model, update.Provider, err) + } + + // 2. Update budget usage for both provider-level and model-level + // This applies even when virtual keys are disabled or not present + if shouldUpdateBudget && update.Cost > 0 { + if err := t.store.UpdateProviderAndModelBudgetUsageInMemory(ctx, update.Model, update.Provider, update.Cost); err != nil { + t.logger.Error("failed to update budget usage for model %s, provider %s: %v", update.Model, update.Provider, err) + } + } + + // 3. Now handle virtual key-level updates (if virtual key exists) + if update.VirtualKey == "" { + // No virtual key, provider-level and model-level updates already done above + return + } + + // Get virtual key + vk, exists := t.store.GetVirtualKey(update.VirtualKey) + if !exists { + t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey)) + return + } + // Update rate limit usage (both provider-level and VK-level) if applicable if vk.RateLimit != nil || len(vk.ProviderConfigs) > 0 { - if err := t.store.UpdateRateLimitUsageInMemory(ctx, vk, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + if err := t.store.UpdateVirtualKeyRateLimitUsageInMemory(ctx, vk, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { t.logger.Error("failed to update rate limit usage for VK %s: %v", vk.ID, err) } } @@ -95,7 +116,7 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { if shouldUpdateBudget && update.Cost > 0 { t.logger.Debug("updating budget usage for VK %s", vk.ID) // Use atomic budget update to prevent race conditions and ensure consistency - if err := t.store.UpdateBudgetUsageInMemory(ctx, vk, update.Provider, update.Cost); err != nil { + if err := t.store.UpdateVirtualKeyBudgetUsageInMemory(ctx, vk, update.Provider, update.Cost); err != nil { t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) } }