diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index ad0760886..33af9a863 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -438,6 +438,9 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(testfilter.HeaderBasedTestingFilterType, testfilter.HeaderBasedTestingFilterFactory) // register response received plugin for test purpose only (used in conformance tests) plugins.Register(testresponsereceived.DestinationEndpointServedVerifierType, testresponsereceived.DestinationEndpointServedVerifierFactory) + // register datalayer metrics collection plugins + plugins.Register(dlmetrics.MetricsDataSourceType, dlmetrics.MetricsDataSourceFactory) + plugins.Register(dlmetrics.MetricsExtractorType, dlmetrics.ModelServerExtractorFactory) } func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.EndpointPickerConfig, error) { @@ -476,7 +479,7 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End // Return a function that can be used in the EPP Handle to list pod names. func makePodListFunc(ds datastore.Datastore) func() []types.NamespacedName { return func() []types.NamespacedName { - pods := ds.PodList(func(_ backendmetrics.PodMetrics) bool { return true }) + pods := ds.PodList(backendmetrics.AllPodsPredicate) names := make([]types.NamespacedName, 0, len(pods)) for _, p := range pods { @@ -615,10 +618,10 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { // are to be configured), must be done before the EndpointFactory is initialized. func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) { // create and register a metrics data source and extractor. - source := dlmetrics.NewDataSource(*modelServerMetricsScheme, + source := dlmetrics.NewMetricsDataSource(*modelServerMetricsScheme, *modelServerMetricsPath, *modelServerMetricsHttpsInsecureSkipVerify) - extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric, + extractor, err := dlmetrics.NewModelServerExtractor(*totalQueuedRequestsMetric, *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, *loraInfoMetric, *cacheInfoMetric) diff --git a/pkg/epp/datalayer/datasource_test.go b/pkg/epp/datalayer/datasource_test.go index 7ac262a5c..c1f0edcf8 100644 --- a/pkg/epp/datalayer/datasource_test.go +++ b/pkg/epp/datalayer/datasource_test.go @@ -31,17 +31,17 @@ const ( ) type mockDataSource struct { - tn plugins.TypedName + typedName plugins.TypedName } -func (m *mockDataSource) TypedName() plugins.TypedName { return m.tn } +func (m *mockDataSource) TypedName() plugins.TypedName { return m.typedName } func (m *mockDataSource) Extractors() []string { return []string{} } func (m *mockDataSource) AddExtractor(_ Extractor) error { return nil } func (m *mockDataSource) Collect(_ context.Context, _ Endpoint) error { return nil } func TestRegisterAndGetSource(t *testing.T) { reg := DataSourceRegistry{} - ds := &mockDataSource{tn: plugins.TypedName{Type: testType, Name: testType}} + ds := &mockDataSource{typedName: plugins.TypedName{Type: testType, Name: testType}} err := reg.Register(ds) assert.NoError(t, err, "expected no error on first registration") diff --git a/pkg/epp/datalayer/metrics/datasource.go b/pkg/epp/datalayer/metrics/datasource.go index d5940ac65..df4b1d378 100644 --- a/pkg/epp/datalayer/metrics/datasource.go +++ b/pkg/epp/datalayer/metrics/datasource.go @@ -28,14 +28,10 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) -const ( - DataSourceType = "metrics-data-source" -) - // DataSource is a Model Server Protocol (MSP) compliant metrics data source, // returning Prometheus formatted metrics for an endpoint. type DataSource struct { - tn plugins.TypedName + typedName plugins.TypedName metricsScheme string // scheme to use in metrics URL metricsPath string // path to use in metrics URL @@ -43,10 +39,9 @@ type DataSource struct { extractors sync.Map // key: name, value: extractor } -// NewDataSource returns a new MSP compliant metrics data source, configured with -// the provided client configuration. -// The Scheme, path and certificate validation setting are command line options. -func NewDataSource(metricsScheme string, metricsPath string, skipCertVerification bool) *DataSource { +// NewMetricsDataSource returns a new MSP compliant metrics data source, configured with +// the provided scheme, path and certificate verification parameters. +func NewMetricsDataSource(metricsScheme string, metricsPath string, skipCertVerification bool) *DataSource { if metricsScheme == "https" { httpsTransport := baseTransport.Clone() httpsTransport.TLSClientConfig = &tls.Config{ @@ -56,9 +51,9 @@ func NewDataSource(metricsScheme string, metricsPath string, skipCertVerificatio } dataSrc := &DataSource{ - tn: plugins.TypedName{ - Type: DataSourceType, - Name: DataSourceType, + typedName: plugins.TypedName{ + Type: MetricsDataSourceType, + Name: MetricsDataSourceType, }, metricsScheme: metricsScheme, metricsPath: metricsPath, @@ -69,7 +64,7 @@ func NewDataSource(metricsScheme string, metricsPath string, skipCertVerificatio // TypedName returns the metrics data source type and name. func (dataSrc *DataSource) TypedName() plugins.TypedName { - return dataSrc.tn + return dataSrc.typedName } // Extractors returns a list of registered Extractor names. diff --git a/pkg/epp/datalayer/metrics/datasource_test.go b/pkg/epp/datalayer/metrics/datasource_test.go index 4d4db2a01..35255b7bc 100644 --- a/pkg/epp/datalayer/metrics/datasource_test.go +++ b/pkg/epp/datalayer/metrics/datasource_test.go @@ -28,12 +28,12 @@ import ( ) func TestDatasource(t *testing.T) { - source := NewDataSource("https", "/metrics", true) - extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric, "", "", "", "") + source := NewMetricsDataSource("https", "/metrics", true) + extractor, err := NewModelServerExtractor(defaultTotalQueuedRequestsMetric, "", "", "", "") assert.Nil(t, err, "failed to create extractor") dsType := source.TypedName().Type - assert.Equal(t, DataSourceType, dsType) + assert.Equal(t, MetricsDataSourceType, dsType) err = source.AddExtractor(extractor) assert.Nil(t, err, "failed to add extractor") diff --git a/pkg/epp/datalayer/metrics/extractor.go b/pkg/epp/datalayer/metrics/extractor.go index 9f450ee5b..5404b3b62 100644 --- a/pkg/epp/datalayer/metrics/extractor.go +++ b/pkg/epp/datalayer/metrics/extractor.go @@ -35,8 +35,6 @@ import ( ) const ( - extractorType = "model-server-protocol-metrics" - // LoRA metrics based on MSP LoraInfoRunningAdaptersMetricName = "running_lora_adapters" LoraInfoWaitingAdaptersMetricName = "waiting_lora_adapters" @@ -49,10 +47,12 @@ const ( // Extractor implements the metrics extraction based on the model // server protocol standard. type Extractor struct { - tn plugins.TypedName - mapping *Mapping + typedName plugins.TypedName + mapping *Mapping } +// Produces returns the data attributes that are provided by the datalayer.metrics +// package. func Produces() map[string]any { return map[string]any{ metrics.WaitingQueueSizeKey: int(0), @@ -64,19 +64,19 @@ func Produces() map[string]any { } } -// NewExtractor returns a new model server protocol (MSP) metrics extractor, +// NewModelServerExtractor returns a new model server protocol (MSP) metrics extractor, // configured with the given metrics' specifications. // These are mandatory metrics per the MSP specification, and are used // as the basis for the built-in scheduling plugins. -func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { +func NewModelServerExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) { mapping, err := NewMapping(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec) if err != nil { return nil, fmt.Errorf("failed to create extractor metrics Mapping - %w", err) } return &Extractor{ - tn: plugins.TypedName{ - Type: extractorType, - Name: extractorType, + typedName: plugins.TypedName{ + Type: MetricsExtractorType, + Name: MetricsExtractorType, }, mapping: mapping, }, nil @@ -84,7 +84,7 @@ func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec s // TypedName returns the type and name of the metrics.Extractor. func (ext *Extractor) TypedName() plugins.TypedName { - return ext.tn + return ext.typedName } // ExpectedType defines the type expected by the metrics.Extractor - a diff --git a/pkg/epp/datalayer/metrics/extractor_test.go b/pkg/epp/datalayer/metrics/extractor_test.go index 3900c1c98..e60847715 100644 --- a/pkg/epp/datalayer/metrics/extractor_test.go +++ b/pkg/epp/datalayer/metrics/extractor_test.go @@ -40,11 +40,11 @@ const ( func TestExtractorExtract(t *testing.T) { ctx := context.Background() - if _, err := NewExtractor("vllm: dummy", "", "", "", ""); err == nil { + if _, err := NewModelServerExtractor("vllm: dummy", "", "", "", ""); err == nil { t.Error("expected to fail to create extractor with invalid specification") } - extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric, defaultTotalRunningRequestsMetric, + extractor, err := NewModelServerExtractor(defaultTotalQueuedRequestsMetric, defaultTotalRunningRequestsMetric, defaultKvCacheUsagePercentageMetric, defaultLoraInfoMetric, defaultCacheInfoMetric) if err != nil { t.Fatalf("failed to create extractor: %v", err) @@ -54,6 +54,10 @@ func TestExtractorExtract(t *testing.T) { t.Error("empty extractor type") } + if exName := extractor.TypedName().Name; exName == "" { + t.Error("empty extractor name") + } + if inputType := extractor.ExpectedInputType(); inputType != PrometheusMetricType { t.Errorf("incorrect expected input type: %v", inputType) } diff --git a/pkg/epp/datalayer/metrics/factories.go b/pkg/epp/datalayer/metrics/factories.go new file mode 100644 index 000000000..bea231f4c --- /dev/null +++ b/pkg/epp/datalayer/metrics/factories.go @@ -0,0 +1,184 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "encoding/json" + "flag" + "fmt" + "strconv" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" +) + +const ( + MetricsDataSourceType = "metrics-data-source" + MetricsExtractorType = "model-server-protocol-metrics" +) + +// Configuration parameters for metrics data source and extractor. +type ( + // Data source configuration parameters + metricsDatasourceParams struct { + // Scheme defines the protocol scheme used in metrics retrieval (e.g., "http"). + Scheme string // `json:"scheme"` + // Path defines the URL path used in metrics retrieval (e.g., "/metrics"). + Path string // `json:"path"` + // InsecureSkipVerify defines whether model server certificate should be verified or not. + InsecureSkipVerify bool // `json:"insecureSkipVerify"` + } + + // Extractor configuration parameters + modelServerExtractorParams struct { + // QueueRequestsSpec defines the metric specification string for retrieving queued request count. + QueueRequestsSpec string // `json:"queuedRequestsSpec"` + // RunningRequestsSpec defines the metric specification string for retrieving running requests count. + RunningRequestsSpec string // `json:"runningRequestsSpec"` + // KVUsage defines the metric specification string for retrieving KV cache usage. + KVUsageSpec string // `json:"kvUsageSpec"` + // LoRASpec defines the metric specification string for retrieving LoRA availability. + LoRASpec string // `json:"loraSpec"` + // CacheInfoSpec defines the metrics specification string for retrieving KV cache configuration. + CacheInfoSpec string // `json:"cacheInfoSpec"` + } +) + +// MetricsDataSourceFactory is a factory function used to instantiate data layer's +// metrics data source plugins specified in a configuration. +func MetricsDataSourceFactory(name string, parameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + cfg, err := defaultDataSourceConfigParams() + if err != nil { + return nil, err + } + + if parameters != nil { // overlay the defaults with configured values + if err := json.Unmarshal(parameters, cfg); err != nil { + return nil, err + } + } + + ds := NewMetricsDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify) + ds.typedName.Name = name + return ds, nil +} + +// ModelServerExtractorFactory is a factory function used to instantiate data layer's metrics +// Extractor plugins specified in a configuration. +func ModelServerExtractorFactory(name string, parameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + cfg, err := defaultExtractorConfigParams() + if err != nil { + return nil, err + } + + if parameters != nil { // overlay the defaults with configured values + if err := json.Unmarshal(parameters, cfg); err != nil { + return nil, err + } + } + + extractor, err := NewModelServerExtractor(cfg.QueueRequestsSpec, cfg.RunningRequestsSpec, cfg.KVUsageSpec, + cfg.LoRASpec, cfg.CacheInfoSpec) + if err != nil { + return nil, err + } + extractor.typedName.Name = name + return extractor, nil +} + +// Names of CLI flags in main +// +// TODO: +// +// 1. Consider having a cli package with all flag names and constants? +// Can't use values from runserver as this creates an import cycle with datalayer. +// Given that relevant issues/PRs have been closed so may be able to remove the cycle? +// Comment from runserver package (regarding TestPodMetricsClient *backendmetrics.FakePodMetricsClient) +// This should only be used in tests. We won't need this once we do not inject metrics in the tests. +// TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup +// +// 2. Deprecation notice on these flags being moved to the configuration file +const ( + totalQueuedRequestsMetricSpecFlag = "total-queued-requests-metric" + totalRunningRequestsMetricSpecFlag = "total-running-requests-metric" + kvCacheUsagePercentageMetricSpecFlags = "kv-cache-usage-percentage-metric" + loraInfoMetricSpecFlag = "lora-info-metric" + cacheInfoMetricSpecFlag = "cache-info-metric" + modelServerMetricsPathFlag = "model-server-metrics-path" + modelServerMetricsSchemeFlag = "model-server-metrics-scheme" + modelServerMetricsInsecureSkipVerifyFlag = "model-server-metrics-https-insecure-skip-verify" +) + +// return the default configuration state. The defaults are populated from +// existing command line flags. +func defaultDataSourceConfigParams() (*metricsDatasourceParams, error) { + var err error + cfg := &metricsDatasourceParams{} + + if cfg.Scheme, err = fromStringFlag(modelServerMetricsSchemeFlag); err != nil { + return nil, err + } + if cfg.Path, err = fromStringFlag(modelServerMetricsPathFlag); err != nil { + return nil, err + } + if cfg.InsecureSkipVerify, err = fromBoolFlag(modelServerMetricsInsecureSkipVerifyFlag); err != nil { + return nil, err + } + return cfg, nil +} + +func defaultExtractorConfigParams() (*modelServerExtractorParams, error) { + var err error + cfg := &modelServerExtractorParams{} + + if cfg.QueueRequestsSpec, err = fromStringFlag(totalQueuedRequestsMetricSpecFlag); err != nil { + return nil, err + } + if cfg.RunningRequestsSpec, err = fromStringFlag(totalRunningRequestsMetricSpecFlag); err != nil { + return nil, err + } + if cfg.KVUsageSpec, err = fromStringFlag(kvCacheUsagePercentageMetricSpecFlags); err != nil { + return nil, err + } + if cfg.LoRASpec, err = fromStringFlag(loraInfoMetricSpecFlag); err != nil { + return nil, err + } + if cfg.CacheInfoSpec, err = fromStringFlag(cacheInfoMetricSpecFlag); err != nil { + return nil, err + } + + return cfg, nil +} + +func fromStringFlag(name string) (string, error) { + f := flag.Lookup(name) + if f == nil { + return "", fmt.Errorf("flag not found: %s", name) + } + return f.Value.String(), nil +} + +func fromBoolFlag(name string) (bool, error) { + f := flag.Lookup(name) + if f == nil { + return false, fmt.Errorf("flag not found: %s", name) + } + b, err := strconv.ParseBool(f.Value.String()) + if err != nil { + return false, fmt.Errorf("invalid bool flag %q: %w", name, err) + } + return b, nil +}