Skip to content

Commit 687d705

Browse files
feat: governance plugin refactor
1 parent bccaa4b commit 687d705

File tree

10 files changed

+1158
-466
lines changed

10 files changed

+1158
-466
lines changed

core/schemas/context.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ func (bc *BifrostContext) SetValue(key, value any) {
170170
}
171171
bc.valuesMu.Lock()
172172
defer bc.valuesMu.Unlock()
173+
if bc.userValues == nil {
174+
bc.userValues = make(map[any]any)
175+
}
173176
bc.userValues[key] = value
174177
}
175178

framework/configstore/rdb.go

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,58 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) (
12741274

12751275
// DeleteVirtualKey deletes a virtual key from the database.
12761276
func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error {
1277-
return s.db.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error
1277+
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
1278+
var virtualKey tables.TableVirtualKey
1279+
if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil {
1280+
return err
1281+
}
1282+
1283+
// Delete provider config associated budgets and rate limits first
1284+
for _, pc := range virtualKey.ProviderConfigs {
1285+
// Delete the keys join table entries
1286+
if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil {
1287+
return err
1288+
}
1289+
// Delete the budget associated with the provider config
1290+
if pc.BudgetID != nil {
1291+
if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *pc.BudgetID).Error; err != nil {
1292+
return err
1293+
}
1294+
}
1295+
// Delete the rate limit associated with the provider config
1296+
if pc.RateLimitID != nil {
1297+
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *pc.RateLimitID).Error; err != nil {
1298+
return err
1299+
}
1300+
}
1301+
}
1302+
1303+
// Delete all provider configs associated with the virtual key
1304+
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil {
1305+
return err
1306+
}
1307+
// Delete all MCP configs associated with the virtual key
1308+
if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil {
1309+
return err
1310+
}
1311+
// Delete the budget associated with the virtual key
1312+
if virtualKey.BudgetID != nil {
1313+
if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", virtualKey.BudgetID).Error; err != nil {
1314+
return err
1315+
}
1316+
}
1317+
// Delete the rate limit associated with the virtual key
1318+
if virtualKey.RateLimitID != nil {
1319+
if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", virtualKey.RateLimitID).Error; err != nil {
1320+
return err
1321+
}
1322+
}
1323+
// Delete the virtual key
1324+
return tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error
1325+
}); err != nil {
1326+
return err
1327+
}
1328+
return nil
12781329
}
12791330

12801331
// GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database.
@@ -1477,7 +1528,21 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam,
14771528

14781529
// DeleteTeam deletes a team from the database.
14791530
func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error {
1480-
return s.db.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error
1531+
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
1532+
var team tables.TableTeam
1533+
if err := tx.WithContext(ctx).First(&team, "id = ?", id).Error; err != nil {
1534+
return err
1535+
}
1536+
// Set team_id to null for all virtual keys associated with the team
1537+
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil {
1538+
return err
1539+
}
1540+
// Delete the team
1541+
return tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error
1542+
}); err != nil {
1543+
return err
1544+
}
1545+
return nil
14811546
}
14821547

14831548
// GetCustomers retrieves all customers from the database.
@@ -1534,7 +1599,37 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta
15341599

15351600
// DeleteCustomer deletes a customer from the database.
15361601
func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error {
1537-
return s.db.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error
1602+
if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
1603+
var customer tables.TableCustomer
1604+
if err := tx.WithContext(ctx).First(&customer, "id = ?", id).Error; err != nil {
1605+
return err
1606+
}
1607+
// Set customer_id to null for all virtual keys associated with the customer
1608+
if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
1609+
return err
1610+
}
1611+
// Set customer_id to null for all teams associated with the customer
1612+
if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil {
1613+
return err
1614+
}
1615+
// Delete the customer
1616+
return tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error
1617+
}); err != nil {
1618+
return err
1619+
}
1620+
return nil
1621+
}
1622+
1623+
// GetRateLimits retrieves all rate limits from the database.
1624+
func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) {
1625+
var rateLimits []tables.TableRateLimit
1626+
if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil {
1627+
if errors.Is(err, gorm.ErrRecordNotFound) {
1628+
return nil, ErrNotFound
1629+
}
1630+
return nil, err
1631+
}
1632+
return rateLimits, nil
15381633
}
15391634

15401635
// GetRateLimit retrieves a specific rate limit from the database.

framework/configstore/store.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ type ConfigStore interface {
100100
DeleteCustomer(ctx context.Context, id string) error
101101

102102
// Rate limit CRUD
103+
GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error)
103104
GetRateLimit(ctx context.Context, id string) (*tables.TableRateLimit, error)
104105
CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error
105106
UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error

framework/configstore/tables/budget.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ type TableBudget struct {
2323
func (TableBudget) TableName() string { return "governance_budgets" }
2424

2525
// BeforeSave hook for Budget to validate reset duration format and max limit
26-
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
26+
func (b *TableBudget) BeforeSave(tx *gorm.DB) error {
2727
// Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y")
2828
if d, err := ParseDuration(b.ResetDuration); err != nil {
2929
return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)
30-
}else if d <= 0 {
30+
} else if d <= 0 {
3131
return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration)
3232
}
3333
// Validate that MaxLimit is not negative (budgets should be positive)

plugins/governance/main.go

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ type GovernancePlugin struct {
4444
wg sync.WaitGroup // Track active goroutines
4545

4646
// Core components with clear separation of concerns
47-
store *GovernanceStore // Pure data access layer
48-
resolver *BudgetResolver // Pure decision engine for hierarchical governance
49-
tracker *UsageTracker // Business logic owner (updates, resets, persistence)
47+
store GovernanceStore // Pure data access layer
48+
resolver *BudgetResolver // Pure decision engine for hierarchical governance
49+
tracker *UsageTracker // Business logic owner (updates, resets, persistence)
5050

5151
// Dependencies
5252
configStore configstore.ConfigStore
@@ -96,12 +96,12 @@ func Init(
9696
ctx context.Context,
9797
config *Config,
9898
logger schemas.Logger,
99-
store configstore.ConfigStore,
99+
configStore configstore.ConfigStore,
100100
governanceConfig *configstore.GovernanceConfig,
101101
modelCatalog *modelcatalog.ModelCatalog,
102102
inMemoryStore InMemoryStore,
103103
) (*GovernancePlugin, error) {
104-
if store == nil {
104+
if configStore == nil {
105105
logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
106106
}
107107
if modelCatalog == nil {
@@ -114,7 +114,7 @@ func Init(
114114
isVkMandatory = config.IsVkMandatory
115115
}
116116

117-
governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig)
117+
governanceStore, err := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig)
118118
if err != nil {
119119
return nil, fmt.Errorf("failed to initialize governance store: %w", err)
120120
}
@@ -123,10 +123,10 @@ func Init(
123123
resolver := NewBudgetResolver(governanceStore, logger)
124124

125125
// 3. Tracker (business logic owner, depends on store and resolver)
126-
tracker := NewUsageTracker(ctx, governanceStore, resolver, store, logger)
126+
tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)
127127

128128
// 4. Perform startup reset check for any expired limits from downtime
129-
if store != nil {
129+
if configStore != nil {
130130
if err := tracker.PerformStartupResets(ctx); err != nil {
131131
logger.Warn("startup reset failed: %v", err)
132132
// Continue initialization even if startup reset fails (non-critical)
@@ -139,7 +139,7 @@ func Init(
139139
store: governanceStore,
140140
resolver: resolver,
141141
tracker: tracker,
142-
configStore: store,
142+
configStore: configStore,
143143
modelCatalog: modelCatalog,
144144
logger: logger,
145145
isVkMandatory: isVkMandatory,
@@ -148,6 +148,54 @@ func Init(
148148
return plugin, nil
149149
}
150150

151+
func InitFromStore(
152+
ctx context.Context,
153+
config *Config,
154+
logger schemas.Logger,
155+
governanceStore GovernanceStore,
156+
configStore configstore.ConfigStore,
157+
modelCatalog *modelcatalog.ModelCatalog,
158+
inMemoryStore InMemoryStore,
159+
) (*GovernancePlugin, error) {
160+
if configStore == nil {
161+
logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
162+
}
163+
if modelCatalog == nil {
164+
logger.Warn("governance plugin requires model catalog to calculate cost, all cost calculations will be skipped.")
165+
}
166+
if governanceStore == nil {
167+
return nil, fmt.Errorf("governance store is nil")
168+
}
169+
// Handle nil config - use safe default for IsVkMandatory
170+
var isVkMandatory *bool
171+
if config != nil {
172+
isVkMandatory = config.IsVkMandatory
173+
}
174+
resolver := NewBudgetResolver(governanceStore, logger)
175+
tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)
176+
// Perform startup reset check for any expired limits from downtime
177+
if configStore != nil {
178+
if err := tracker.PerformStartupResets(ctx); err != nil {
179+
logger.Warn("startup reset failed: %v", err)
180+
// Continue initialization even if startup reset fails (non-critical)
181+
}
182+
}
183+
ctx, cancelFunc := context.WithCancel(ctx)
184+
plugin := &GovernancePlugin{
185+
ctx: ctx,
186+
cancelFunc: cancelFunc,
187+
store: governanceStore,
188+
resolver: resolver,
189+
tracker: tracker,
190+
configStore: configStore,
191+
modelCatalog: modelCatalog,
192+
logger: logger,
193+
inMemoryStore: inMemoryStore,
194+
isVkMandatory: isVkMandatory,
195+
}
196+
return plugin, nil
197+
}
198+
151199
// GetName returns the name of the plugin
152200
func (p *GovernancePlugin) GetName() string {
153201
return PluginName
@@ -596,6 +644,6 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi
596644
}
597645

598646
// GetGovernanceStore returns the governance store
599-
func (p *GovernancePlugin) GetGovernanceStore() *GovernanceStore {
647+
func (p *GovernancePlugin) GetGovernanceStore() GovernanceStore {
600648
return p.store
601649
}

0 commit comments

Comments
 (0)