diff --git a/Makefile b/Makefile index 373171265d..112a4806d3 100644 --- a/Makefile +++ b/Makefile @@ -176,7 +176,7 @@ ui-test: ui-build ./hack/test-ui.sh .PHONY: image -image: clean ui-build dist/$(BINARY_NAME)-linux-$(HOST_ARCH) +image: clean dist/$(BINARY_NAME)-linux-$(HOST_ARCH) ifdef GITHUB_ACTIONS # The binary will be built in a separate Github Actions job cp -pv numaflow-rs-linux-amd64 dist/numaflow-rs-linux-amd64 diff --git a/cmd/commands/isbsvc_create.go b/cmd/commands/isbsvc_create.go index 86c1b1438c..6db33b2432 100644 --- a/cmd/commands/isbsvc_create.go +++ b/cmd/commands/isbsvc_create.go @@ -36,11 +36,11 @@ import ( func NewISBSvcCreateCommand() *cobra.Command { var ( - isbSvcType string - buffers []string - buckets []string - sideInputsStore string - servingSourceStreams []string + isbSvcType string + buffers []string + buckets []string + sideInputsStore string + servingSourceStore string ) command := &cobra.Command{ @@ -89,7 +89,7 @@ func NewISBSvcCreateCommand() *cobra.Command { return fmt.Errorf("unsupported isb service type %q", isbSvcType) } - if err = isbsClient.CreateBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStreams, opts...); err != nil { + if err = isbsClient.CreateBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStore, opts...); err != nil { logger.Errorw("Failed to create buffers, buckets and side inputs store.", zap.Error(err)) return err } @@ -102,6 +102,6 @@ func NewISBSvcCreateCommand() *cobra.Command { command.Flags().StringSliceVar(&buffers, "buffers", []string{}, "Buffers to create") // --buffers=a,b, --buffers=c command.Flags().StringSliceVar(&buckets, "buckets", []string{}, "Buckets to create") // --buckets=xxa,xxb --buckets=xxc command.Flags().StringVar(&sideInputsStore, "side-inputs-store", "", "Name of the side inputs store") - command.Flags().StringSliceVar(&servingSourceStreams, "serving-source-streams", []string{}, "Serving source streams to create") // --serving-source-streams=a,b, --serving-source-streams=c + command.Flags().StringVar(&servingSourceStore, "serving-source-store", "", "Serving source streams to create") // --serving-source-store=a return command } diff --git a/cmd/commands/isbsvc_delete.go b/cmd/commands/isbsvc_delete.go index c6da2418e4..0997099bc5 100644 --- a/cmd/commands/isbsvc_delete.go +++ b/cmd/commands/isbsvc_delete.go @@ -33,11 +33,11 @@ import ( func NewISBSvcDeleteCommand() *cobra.Command { var ( - isbSvcType string - buffers []string - buckets []string - sideInputsStore string - servingSourceStreams []string + isbSvcType string + buffers []string + buckets []string + sideInputsStore string + servingSourceStore string ) command := &cobra.Command{ @@ -74,7 +74,7 @@ func NewISBSvcDeleteCommand() *cobra.Command { cmd.HelpFunc()(cmd, args) return fmt.Errorf("unsupported isb service type %q", isbSvcType) } - if err = isbsClient.DeleteBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStreams); err != nil { + if err = isbsClient.DeleteBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStore); err != nil { logger.Errorw("Failed on buffers, buckets and side inputs store deletion.", zap.Error(err)) return err } @@ -86,6 +86,6 @@ func NewISBSvcDeleteCommand() *cobra.Command { command.Flags().StringSliceVar(&buffers, "buffers", []string{}, "Buffers to delete") // --buffers=a,b, --buffers=c command.Flags().StringSliceVar(&buckets, "buckets", []string{}, "Buckets to delete") // --buckets=xxa,xxb --buckets=xxc return command command.Flags().StringVar(&sideInputsStore, "side-inputs-store", "", "Name of the side inputs store") - command.Flags().StringSliceVar(&servingSourceStreams, "serving-source-streams", []string{}, "Serving source streams to delete") // --serving-source-streams=a,b, --serving-source-streams=c + command.Flags().StringVar(&servingSourceStore, "serving-source-store", "", "Serving source store to delete") // --serving-source-store=a return command } diff --git a/cmd/commands/isbsvc_validate.go b/cmd/commands/isbsvc_validate.go index 8f8bc1b892..2b7aaa90c9 100644 --- a/cmd/commands/isbsvc_validate.go +++ b/cmd/commands/isbsvc_validate.go @@ -36,11 +36,11 @@ import ( func NewISBSvcValidateCommand() *cobra.Command { var ( - isbSvcType string - buffers []string - buckets []string - sideInputsStore string - servingSourceStreams []string + isbSvcType string + buffers []string + buckets []string + sideInputsStore string + servingSourceStore string ) command := &cobra.Command{ @@ -77,7 +77,7 @@ func NewISBSvcValidateCommand() *cobra.Command { return fmt.Errorf("unsupported isb service type") } _ = wait.ExponentialBackoffWithContext(ctx, sharedutil.DefaultRetryBackoff, func(_ context.Context) (bool, error) { - if err = isbsClient.ValidateBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStreams); err != nil { + if err = isbsClient.ValidateBuffersAndBuckets(ctx, buffers, buckets, sideInputsStore, servingSourceStore); err != nil { logger.Infow("Buffers, buckets and side inputs store might have not been created yet, will retry if the limit is not reached", zap.Error(err)) return false, nil } @@ -95,7 +95,7 @@ func NewISBSvcValidateCommand() *cobra.Command { command.Flags().StringSliceVar(&buffers, "buffers", []string{}, "Buffers to validate") // --buffers=a,b, --buffers=c command.Flags().StringSliceVar(&buckets, "buckets", []string{}, "Buckets to validate") // --buckets=xxa,xxb --buckets=xxc command.Flags().StringVar(&sideInputsStore, "side-inputs-store", "", "Name of the side inputs store") - command.Flags().StringSliceVar(&servingSourceStreams, "serving-source-streams", []string{}, "Serving source streams to validate") // --serving-source-streams=a,b, --serving-source-streams=c + command.Flags().StringVar(&servingSourceStore, "serving-source-store", "", "Serving source store to validate") // --serving-source-store=a return command } diff --git a/examples/1-simple-pipeline.yaml b/examples/1-simple-pipeline.yaml index 42e9d9e095..637e2193d2 100644 --- a/examples/1-simple-pipeline.yaml +++ b/examples/1-simple-pipeline.yaml @@ -16,8 +16,9 @@ spec: scale: min: 1 udf: - builtin: - name: cat # A built-in UDF which simply cats the message + container: + image: quay.io/numaio/numaflow-go/map-flatmap-stream:stable + imagePullPolicy: Never - name: out scale: min: 1 diff --git a/examples/15-serving-source-pipeline.yaml b/examples/15-serving-source-pipeline.yaml index fe9b1e5ef8..3b37db2bae 100644 --- a/examples/15-serving-source-pipeline.yaml +++ b/examples/15-serving-source-pipeline.yaml @@ -3,34 +3,42 @@ kind: Pipeline metadata: name: simple-pipeline spec: - templates: - vertex: - metadata: - annotations: - numaflow.numaproj.io/callback: "true" vertices: - - name: in + - name: serving-in scale: min: 1 source: serving: service: true - msgIDHeaderKey: "X-Request-ID" + msgIDHeaderKey: "X-Numaflow-Id" store: url: "redis://redis:6379" + - name: cat scale: min: 1 udf: - builtin: - name: cat # A built-in UDF which simply cats the message - - name: out + container: + image: quay.io/numaio/numaflow-go/map-forward-message:stable + env: + - name: RUST_BACKTRACE + value: "1" + + - name: serve-sink scale: min: 1 sink: - log: {} + udsink: + container: + image: docker.intuit.com/quay-rmt/numaio/servesink:v1.5.0-alpha1 + env: + - name: NUMAFLOW_CALLBACK_URL_KEY + value: "X-Numaflow-Callback-Url" + - name: NUMAFLOW_MSG_ID_HEADER_KEY + value: "X-Numaflow-Id" + edges: - - from: in + - from: serving-in to: cat - from: cat - to: out \ No newline at end of file + to: serve-sink \ No newline at end of file diff --git a/pkg/apis/numaflow/v1alpha1/pipeline_types.go b/pkg/apis/numaflow/v1alpha1/pipeline_types.go index 9d44901573..c34d2d7274 100644 --- a/pkg/apis/numaflow/v1alpha1/pipeline_types.go +++ b/pkg/apis/numaflow/v1alpha1/pipeline_types.go @@ -213,14 +213,8 @@ func (p Pipeline) GetSideInputsStoreName() string { return fmt.Sprintf("%s-%s", p.Namespace, p.Name) } -func (p Pipeline) GetServingSourceStreamNames() []string { - var servingSourceNames []string - for _, srcVertex := range p.Spec.Vertices { - if srcVertex.IsASource() && srcVertex.Source.Serving != nil { - servingSourceNames = append(servingSourceNames, fmt.Sprintf("%s-%s-serving-source", p.Name, srcVertex.Name)) - } - } - return servingSourceNames +func (p Pipeline) GetServingSourceStoreName() string { + return fmt.Sprintf("%s-%s", p.Namespace, p.Name) } func (p Pipeline) GetSideInputsManagerDeployments(req GetSideInputDeploymentReq) ([]*appv1.Deployment, error) { diff --git a/pkg/apis/numaflow/v1alpha1/pipeline_types_test.go b/pkg/apis/numaflow/v1alpha1/pipeline_types_test.go index 6a1c90566b..2caafb136d 100644 --- a/pkg/apis/numaflow/v1alpha1/pipeline_types_test.go +++ b/pkg/apis/numaflow/v1alpha1/pipeline_types_test.go @@ -527,7 +527,7 @@ func TestGetServingSourceStreamNames(t *testing.T) { }, } var expected []string - assert.Equal(t, expected, p.GetServingSourceStreamNames()) + assert.Equal(t, expected, p.GetServingSourceStoreName()) }) t.Run("with serving sources", func(t *testing.T) { @@ -545,7 +545,7 @@ func TestGetServingSourceStreamNames(t *testing.T) { }, } expected := []string{"test-pipeline-v1-serving-source", "test-pipeline-v2-serving-source"} - assert.Equal(t, expected, p.GetServingSourceStreamNames()) + assert.Equal(t, expected, p.GetServingSourceStoreName()) }) } diff --git a/pkg/apis/proto/serving/v1/store.proto b/pkg/apis/proto/serving/v1/store.proto new file mode 100644 index 0000000000..da62e17c43 --- /dev/null +++ b/pkg/apis/proto/serving/v1/store.proto @@ -0,0 +1,62 @@ +/* +Copyright 2022 The Numaproj Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +syntax = "proto3"; + +option go_package = "github.com/numaproj/numaflow-go/pkg/apis/proto/serving/v1"; +option java_package = "io.numaproj.numaflow.serving.v1"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +package map.v1; + +service ServingStore { + rpc Put(PutRequest) returns (PutResponse); + + rpc Get(GetRequest) returns (GetResponse); + + rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); +} + +message Payload { + string id = 1; + bytes value = 2; +} + + +message PutRequest { + repeated Payload payload = 1; +} + +message PutResponse { + bool success = 1; +} + +message GetRequest { + string id = 1; +} + +message GetResponse { + repeated Payload payload = 1; +} + +/** + * ReadyResponse is the health check result. + */ +message ReadyResponse { + bool ready = 1; +} \ No newline at end of file diff --git a/pkg/daemon/server/service/pipeline_metrics_query_test.go b/pkg/daemon/server/service/pipeline_metrics_query_test.go index 25dd47bdfc..bb0027dcd1 100644 --- a/pkg/daemon/server/service/pipeline_metrics_query_test.go +++ b/pkg/daemon/server/service/pipeline_metrics_query_test.go @@ -58,15 +58,15 @@ func (ms *mockIsbSvcClient) GetBufferInfo(ctx context.Context, buffer string) (* }, nil } -func (ms *mockIsbSvcClient) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string, opts ...isbsvc.CreateOption) error { +func (ms *mockIsbSvcClient) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string, opts ...isbsvc.CreateOption) error { return nil } -func (ms *mockIsbSvcClient) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (ms *mockIsbSvcClient) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { return nil } -func (ms *mockIsbSvcClient) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (ms *mockIsbSvcClient) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { return nil } diff --git a/pkg/isbsvc/interface.go b/pkg/isbsvc/interface.go index 82cfed134b..6aeeb9b79d 100644 --- a/pkg/isbsvc/interface.go +++ b/pkg/isbsvc/interface.go @@ -25,11 +25,11 @@ import ( // ISBService is an interface used to do the operations on ISBSvc type ISBService interface { // CreateBuffersAndBuckets creates buffers and buckets - CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string, opts ...CreateOption) error + CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string, opts ...CreateOption) error // DeleteBuffersAndBuckets deletes buffers and buckets - DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error + DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error // ValidateBuffersAndBuckets validates buffers and buckets - ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceSTreams []string) error + ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error // GetBufferInfo returns buffer info for the given buffer GetBufferInfo(ctx context.Context, buffer string) (*BufferInfo, error) // CreateWatermarkStores creates watermark stores diff --git a/pkg/isbsvc/jetstream_service.go b/pkg/isbsvc/jetstream_service.go index 867d8a655e..3870ef457c 100644 --- a/pkg/isbsvc/jetstream_service.go +++ b/pkg/isbsvc/jetstream_service.go @@ -52,7 +52,7 @@ func NewISBJetStreamSvc(pipelineName string, jsClient *jsclient.Client) (ISBServ return j, nil } -func (jss *jetStreamSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string, opts ...CreateOption) error { +func (jss *jetStreamSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string, opts ...CreateOption) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } @@ -90,26 +90,24 @@ func (jss *jetStreamSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, b } } - if len(servingSourceStreams) > 0 { - for _, servingSourceStream := range servingSourceStreams { - _, err := jss.js.StreamInfo(servingSourceStream) - if err != nil { - if !errors.Is(err, nats.ErrStreamNotFound) { - return fmt.Errorf("failed to query information of stream %q during buffer creating, %w", servingSourceStream, err) - } - if _, err := jss.js.AddStream(&nats.StreamConfig{ - Name: servingSourceStream, - Subjects: []string{servingSourceStream}, // Use the stream name as the only subject - Storage: nats.StorageType(v.GetInt("stream.storage")), - Replicas: v.GetInt("stream.replicas"), - Retention: nats.WorkQueuePolicy, // we can delete the message immediately after it's consumed and acked - MaxMsgs: -1, // unlimited messages - MaxBytes: -1, // unlimited bytes - Duplicates: v.GetDuration("stream.duplicates"), - }); err != nil { - return fmt.Errorf("failed to create serving source stream %q, %w", servingSourceStream, err) - } + if servingSourceStore != "" { + kvName := JetStreamServingSourceStoreKVName(sideInputsStore) + if _, err := jss.js.KeyValue(kvName); err != nil { + if !errors.Is(err, nats.ErrBucketNotFound) && !errors.Is(err, nats.ErrStreamNotFound) { + return fmt.Errorf("failed to query information of KV %q, %w", kvName, err) } + if _, err := jss.js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: kvName, + MaxValueSize: 0, + History: 64, // No history + TTL: time.Hour * 24 * 1, // 1 day + MaxBytes: 0, + Storage: nats.FileStorage, + Replicas: v.GetInt("stream.replicas"), + }); err != nil { + return fmt.Errorf("failed to create serving source KV %q, %w", kvName, err) + } + log.Infow("Succeeded to create a serving source KV", zap.String("kvName", kvName)) } } @@ -216,7 +214,7 @@ func (jss *jetStreamSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, b return nil } -func (jss *jetStreamSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (jss *jetStreamSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } @@ -249,18 +247,17 @@ func (jss *jetStreamSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, b log.Infow("Succeeded to delete a side inputs KV", zap.String("kvName", sideInputsKVName)) } - if len(servingSourceStreams) > 0 { - for _, servingSourceStream := range servingSourceStreams { - if err := jss.js.DeleteStream(servingSourceStream); err != nil && !errors.Is(err, nats.ErrStreamNotFound) { - return fmt.Errorf("failed to delete serving source stream %q, %w", servingSourceStream, err) - } - log.Infow("Succeeded to delete the serving source stream", zap.String("stream", servingSourceStream)) + if servingSourceStore != "" { + servingSourceStoreKVName := JetStreamServingSourceStoreKVName(sideInputsStore) + if err := jss.js.DeleteKeyValue(servingSourceStoreKVName); err != nil && !errors.Is(err, nats.ErrBucketNotFound) && !errors.Is(err, nats.ErrStreamNotFound) { + return fmt.Errorf("failed to serving source store %q, %w", servingSourceStoreKVName, err) } + log.Infow("Succeeded to delete a serving source store", zap.String("kvName", servingSourceStoreKVName)) } return nil } -func (jss *jetStreamSvc) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (jss *jetStreamSvc) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } @@ -288,11 +285,10 @@ func (jss *jetStreamSvc) ValidateBuffersAndBuckets(ctx context.Context, buffers, return fmt.Errorf("failed to query side inputs store KV %q, %w", sideInputsKVName, err) } } - if len(servingSourceStreams) > 0 { - for _, servingSourceStream := range servingSourceStreams { - if _, err := jss.js.StreamInfo(servingSourceStream); err != nil { - return fmt.Errorf("failed to query information of stream %q, %w", servingSourceStream, err) - } + if servingSourceStore != "" { + servingSourceStoreKVName := JetStreamServingSourceStoreKVName(servingSourceStore) + if _, err := jss.js.KeyValue(servingSourceStoreKVName); err != nil { + return fmt.Errorf("failed to query serving source store KV %q, %w", servingSourceStoreKVName, err) } } return nil @@ -348,3 +344,7 @@ func JetStreamName(bufferName string) string { func JetStreamSideInputsStoreKVName(sideInputStoreName string) string { return fmt.Sprintf("%s_SIDE_INPUTS", sideInputStoreName) } + +func JetStreamServingSourceStoreKVName(servingSourceStoreName string) string { + return fmt.Sprintf("%s_SERVING_STORE", servingSourceStoreName) +} diff --git a/pkg/isbsvc/redis_service.go b/pkg/isbsvc/redis_service.go index 6fadc9814a..0f6b68d2dd 100644 --- a/pkg/isbsvc/redis_service.go +++ b/pkg/isbsvc/redis_service.go @@ -39,7 +39,7 @@ func NewISBRedisSvc(client *redisclient.RedisClient) ISBService { } // CreateBuffersAndBuckets is used to create the inter-step redis buffers. -func (r *isbsRedisSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string, opts ...CreateOption) error { +func (r *isbsRedisSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string, opts ...CreateOption) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } @@ -67,7 +67,7 @@ func (r *isbsRedisSvc) CreateBuffersAndBuckets(ctx context.Context, buffers, buc } // DeleteBuffersAndBuckets is used to delete the inter-step redis buffers. -func (r *isbsRedisSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (r *isbsRedisSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } @@ -103,7 +103,7 @@ func (r *isbsRedisSvc) DeleteBuffersAndBuckets(ctx context.Context, buffers, buc } // ValidateBuffersAndBuckets is used to validate inter-step redis buffers to see if the stream/stream group exist -func (r *isbsRedisSvc) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStreams []string) error { +func (r *isbsRedisSvc) ValidateBuffersAndBuckets(ctx context.Context, buffers, buckets []string, sideInputsStore string, servingSourceStore string) error { if len(buffers) == 0 && len(buckets) == 0 { return nil } diff --git a/pkg/reconciler/pipeline/controller.go b/pkg/reconciler/pipeline/controller.go index c22b5702e0..f974a3c943 100644 --- a/pkg/reconciler/pipeline/controller.go +++ b/pkg/reconciler/pipeline/controller.go @@ -313,13 +313,13 @@ func (r *pipelineReconciler) reconcileFixedResources(ctx context.Context, pl *df } args := []string{fmt.Sprintf("--buffers=%s", strings.Join(bfs, ",")), fmt.Sprintf("--buckets=%s", strings.Join(bks, ","))} args = append(args, fmt.Sprintf("--side-inputs-store=%s", pl.GetSideInputsStoreName())) - args = append(args, fmt.Sprintf("--serving-source-streams=%s", strings.Join(pl.GetServingSourceStreamNames(), ","))) + args = append(args, fmt.Sprintf("--serving-source-store=%s", pl.GetServingSourceStoreName())) batchJob := buildISBBatchJob(pl, r.image, isbSvc.Status.Config, "isbsvc-create", args, "cre") if err := r.client.Create(ctx, batchJob); err != nil && !apierrors.IsAlreadyExists(err) { r.recorder.Eventf(pl, corev1.EventTypeWarning, "CreateJobForISBCeationFailed", "Failed to create a Job: %w", err.Error()) return fmt.Errorf("failed to create ISB creating job, err: %w", err) } - log.Infow("Created a job successfully for ISB creating", zap.Any("buffers", bfs), zap.Any("buckets", bks), zap.Any("servingStreams", pl.GetServingSourceStreamNames())) + log.Infow("Created a job successfully for ISB creating", zap.Any("buffers", bfs), zap.Any("buckets", bks), zap.Any("servingStreams", pl.GetServingSourceStoreName())) r.recorder.Eventf(pl, corev1.EventTypeNormal, "CreateJobForISBCeationSuccessful", "Create ISB creation job successfully") } @@ -614,7 +614,7 @@ func (r *pipelineReconciler) cleanUpBuffers(ctx context.Context, pl *dfv1.Pipeli args = append(args, fmt.Sprintf("--buffers=%s", strings.Join(allBuffers, ","))) args = append(args, fmt.Sprintf("--buckets=%s", strings.Join(allBuckets, ","))) args = append(args, fmt.Sprintf("--side-inputs-store=%s", pl.GetSideInputsStoreName())) - args = append(args, fmt.Sprintf("--serving-source-streams=%s", strings.Join(pl.GetServingSourceStreamNames(), ","))) + args = append(args, fmt.Sprintf("--serving-source-store=%s", pl.GetServingSourceStoreName())) batchJob := buildISBBatchJob(pl, r.image, isbSvc.Status.Config, "isbsvc-delete", args, "cln") batchJob.OwnerReferences = []metav1.OwnerReference{} diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml index 36c419bb3e..6ac06e1bcd 100644 --- a/rust/.rustfmt.toml +++ b/rust/.rustfmt.toml @@ -1 +1,2 @@ -edition = "2021" \ No newline at end of file +edition = "2021" +group_imports = "StdExternalCrate" \ No newline at end of file diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8c3b7ef521..140a7bde9a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -66,9 +66,9 @@ dependencies = [ [[package]] name = "async-nats" -version = "0.38.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76433c4de73442daedb3a59e991d94e85c14ebfc33db53dfcd347a21cd6ef4f8" +checksum = "a798aab0c0203b31d67d501e5ed1f3ac6c36a329899ce47fc93c3bea53f3ae89" dependencies = [ "base64 0.22.1", "bytes", @@ -1028,41 +1028,6 @@ dependencies = [ "syn 2.0.90", ] -[[package]] -name = "darling" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.90", -] - -[[package]] -name = "darling_macro" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.90", -] - [[package]] name = "data-encoding" version = "2.6.0" @@ -1090,37 +1055,6 @@ dependencies = [ "serde", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.90", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn 2.0.90", -] - [[package]] name = "diff" version = "0.1.13" @@ -1857,12 +1791,6 @@ dependencies = [ "syn 2.0.90", ] -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -2357,7 +2285,6 @@ dependencies = [ "base64 0.22.1", "bytes", "chrono", - "derive_builder", "futures", "http 1.2.0", "hyper-util", @@ -3561,6 +3488,7 @@ dependencies = [ name = "serving" version = "0.1.0" dependencies = [ + "async-nats", "axum 0.8.2", "axum-macros", "axum-server", @@ -3581,6 +3509,7 @@ dependencies = [ "serde_json", "thiserror 1.0.69", "tokio", + "tokio-stream", "tower 0.5.2", "tower-http", "tracing", @@ -3709,12 +3638,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - [[package]] name = "subtle" version = "2.6.1" @@ -3971,9 +3894,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -4007,11 +3930,11 @@ dependencies = [ "httparse", "rand", "ring", - "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", "tokio-util", + "webpki-roots 0.26.7", ] [[package]] diff --git a/rust/Cargo.toml b/rust/Cargo.toml index da5a3f74c4..b80cc834fb 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -41,8 +41,8 @@ verbose_file_reads = "warn" # This profile optimizes for runtime performance and small binary size at the expense of longer build times. # Compared to default release profile, this profile reduced binary size from 29MB to 21MB # and increased build time (with only one line change in code) from 12 seconds to 133 seconds (tested on Mac M2 Max). -[profile.release] -lto = "fat" +#[profile.release] +#lto = "fat" # This profile optimizes for short build times at the expense of larger binary size and slower runtime performance. # If you have to rebuild image often, in Dockerfile you may replace `--release` passed to cargo command with `--profile quick-release` diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index c431d528b7..e969eff790 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -48,8 +48,7 @@ kube = "0.95.0" futures = "0.3.30" pin-project = "1.1.5" rand = "0.8.5" -async-nats = "0.38.0" -derive_builder = "0.20.2" +async-nats = "0.39.0" [dev-dependencies] tempfile = "3.11.0" diff --git a/rust/numaflow-core/src/config.rs b/rust/numaflow-core/src/config.rs index a07ffff147..307b1d9e98 100644 --- a/rust/numaflow-core/src/config.rs +++ b/rust/numaflow-core/src/config.rs @@ -25,6 +25,8 @@ pub(crate) mod pipeline; pub const NUMAFLOW_MONO_VERTEX_NAME: &str = "NUMAFLOW_MONO_VERTEX_NAME"; const NUMAFLOW_VERTEX_NAME: &str = "NUMAFLOW_VERTEX_NAME"; const NUMAFLOW_REPLICA: &str = "NUMAFLOW_REPLICA"; +const NUMAFLOW_PIPELINE_NAME: &str = "NUMAFLOW_PIPELINE_NAME"; +const NUMAFLOW_NAMESPACE: &str = "NUMAFLOW_NAMESPACE"; static VERTEX_NAME: OnceLock = OnceLock::new(); /// fetch the vertex name from the environment variable @@ -59,7 +61,7 @@ pub(crate) fn get_component_type() -> &'static str { static PIPELINE_NAME: OnceLock = OnceLock::new(); pub(crate) fn get_pipeline_name() -> &'static str { - PIPELINE_NAME.get_or_init(|| env::var("NUMAFLOW_PIPELINE_NAME").unwrap_or_default()) + PIPELINE_NAME.get_or_init(|| env::var(NUMAFLOW_PIPELINE_NAME).unwrap_or_default()) } static VERTEX_REPLICA: OnceLock = OnceLock::new(); @@ -74,6 +76,13 @@ pub(crate) fn get_vertex_replica() -> &'static u16 { }) } +static NAMESPACE: OnceLock = OnceLock::new(); + +/// fetch the namespace from the environment variable +pub(crate) fn get_namespace() -> &'static str { + NAMESPACE.get_or_init(|| env::var(NUMAFLOW_NAMESPACE).unwrap_or_default()) +} + /// Exposes the [Settings] via lazy loading. pub fn config() -> &'static Settings { static CONF: OnceLock = OnceLock::new(); diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index 8dfdd96a95..d5eee4146f 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -13,6 +13,7 @@ pub(crate) mod source { use numaflow_pulsar::source::{PulsarAuth, PulsarSourceConfig}; use tracing::warn; + use crate::config::{get_namespace, get_pipeline_name, get_vertex_name}; use crate::error::Error; use crate::Result; @@ -126,7 +127,7 @@ pub(crate) mod source { .map_err(|e| Error::Config(format!("Reading API auth token secret: {e:?}")))?; settings.api_auth_token = Some(secret); } else { - tracing::warn!("Authentication token for Serving API is specified, but the secret is empty"); + warn!("Authentication token for Serving API is specified, but the secret is empty"); }; } @@ -146,6 +147,8 @@ pub(crate) mod source { } settings.redis.ttl_secs = Some(ttl_secs); } + settings.js_store = + format!("{}-{}_SERVING_STORE", get_namespace(), get_pipeline_name(),); settings.redis.addr = cfg.store.url; settings.drain_timeout_secs = cfg.request_timeout_seconds.unwrap_or(120).max(1) as u64; // Ensure timeout is atleast 1 second diff --git a/rust/numaflow-core/src/config/monovertex.rs b/rust/numaflow-core/src/config/monovertex.rs index edbbec3997..9b36d58504 100644 --- a/rust/numaflow-core/src/config/monovertex.rs +++ b/rust/numaflow-core/src/config/monovertex.rs @@ -7,7 +7,10 @@ use numaflow_models::models::MonoVertex; use serde_json::from_slice; use super::pipeline::ServingCallbackConfig; -use super::{DEFAULT_CALLBACK_CONCURRENCY, ENV_CALLBACK_CONCURRENCY, ENV_CALLBACK_ENABLED}; +use super::{ + get_namespace, get_pipeline_name, get_vertex_name, DEFAULT_CALLBACK_CONCURRENCY, + ENV_CALLBACK_CONCURRENCY, ENV_CALLBACK_ENABLED, +}; use crate::config::components::metrics::MetricsConfig; use crate::config::components::sink::SinkConfig; use crate::config::components::source::{GeneratorConfig, SourceConfig}; @@ -158,6 +161,11 @@ impl MonovertexConfig { )) })?; callback_config = Some(ServingCallbackConfig { + callback_store: Box::leak(Box::new(format!( + "{}-{}_SERVING_STORE", + get_namespace(), + get_pipeline_name(), + ))), callback_concurrency, }); } diff --git a/rust/numaflow-core/src/config/pipeline.rs b/rust/numaflow-core/src/config/pipeline.rs index 1d272ae9cc..5f63acb927 100644 --- a/rust/numaflow-core/src/config/pipeline.rs +++ b/rust/numaflow-core/src/config/pipeline.rs @@ -60,6 +60,7 @@ pub(crate) struct PipelineConfig { #[derive(Debug, Clone, PartialEq)] pub(crate) struct ServingCallbackConfig { + pub(crate) callback_store: &'static str, pub(crate) callback_concurrency: usize, } @@ -439,6 +440,9 @@ impl PipelineConfig { )) })?; callback_config = Some(ServingCallbackConfig { + callback_store: Box::leak( + format!("{}-{}_SERVING_STORE", namespace, pipeline_name).into_boxed_str(), + ), callback_concurrency, }); } diff --git a/rust/numaflow-core/src/mapper/map.rs b/rust/numaflow-core/src/mapper/map.rs index 3c12e63dd7..c41d501ab0 100644 --- a/rust/numaflow-core/src/mapper/map.rs +++ b/rust/numaflow-core/src/mapper/map.rs @@ -1,14 +1,6 @@ use std::sync::Arc; use std::time::Duration; -use crate::config::pipeline::map::MapMode; -use crate::error; -use crate::error::Error; -use crate::mapper::map::user_defined::{ - UserDefinedBatchMap, UserDefinedStreamMap, UserDefinedUnaryMap, -}; -use crate::message::{Message, Offset}; -use crate::tracker::TrackerHandle; use numaflow_pb::clients::map::map_client::MapClient; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; use tokio::task::JoinHandle; @@ -17,6 +9,15 @@ use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; use tonic::transport::Channel; use tracing::{info, warn}; + +use crate::config::pipeline::map::MapMode; +use crate::error; +use crate::error::Error; +use crate::mapper::map::user_defined::{ + UserDefinedBatchMap, UserDefinedStreamMap, UserDefinedUnaryMap, +}; +use crate::message::{Message, Offset}; +use crate::tracker::TrackerHandle; pub(super) mod user_defined; /// UnaryActorMessage is a message that is sent to the UnaryMapperActor. @@ -629,17 +630,18 @@ impl MapHandle { mod tests { use std::time::Duration; + use numaflow::{batchmap, map, mapstream}; + use numaflow_pb::clients::map::map_client::MapClient; + use tempfile::TempDir; + use tokio::sync::{mpsc::Sender, oneshot}; + use tokio::time::sleep; + use super::*; use crate::{ message::{MessageID, Offset, StringOffset}, shared::grpc::create_rpc_channel, Result, }; - use numaflow::{batchmap, map, mapstream}; - use numaflow_pb::clients::map::map_client::MapClient; - use tempfile::TempDir; - use tokio::sync::{mpsc::Sender, oneshot}; - use tokio::time::sleep; struct SimpleMapper; diff --git a/rust/numaflow-core/src/mapper/map/user_defined.rs b/rust/numaflow-core/src/mapper/map/user_defined.rs index 7a30a25a5b..d6a728838a 100644 --- a/rust/numaflow-core/src/mapper/map/user_defined.rs +++ b/rust/numaflow-core/src/mapper/map/user_defined.rs @@ -1,10 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::config::get_vertex_name; -use crate::error::{Error, Result}; -use crate::message::{Message, MessageID, Offset}; -use crate::shared::grpc::prost_timestamp_from_utc; use chrono::{DateTime, Utc}; use numaflow_pb::clients::map::{self, map_client::MapClient, MapRequest, MapResponse}; use tokio::sync::Mutex; @@ -14,6 +10,11 @@ use tonic::transport::Channel; use tonic::{Request, Streaming}; use tracing::error; +use crate::config::get_vertex_name; +use crate::error::{Error, Result}; +use crate::message::{Message, MessageID, Offset}; +use crate::shared::grpc::prost_timestamp_from_utc; + type ResponseSenderMap = Arc>>)>>>; diff --git a/rust/numaflow-core/src/monovertex.rs b/rust/numaflow-core/src/monovertex.rs index 7f5b17c9c1..21740b5b87 100644 --- a/rust/numaflow-core/src/monovertex.rs +++ b/rust/numaflow-core/src/monovertex.rs @@ -1,4 +1,3 @@ -use serving::callback::CallbackHandler; use tokio_util::sync::CancellationToken; use tracing::info; @@ -24,11 +23,12 @@ pub(crate) async fn start_forwarder( cln_token: CancellationToken, config: &MonovertexConfig, ) -> error::Result<()> { - let callback_handler = config - .callback_config - .as_ref() - .map(|cb_cfg| CallbackHandler::new(config.name.clone(), cb_cfg.callback_concurrency)); - let tracker_handle = TrackerHandle::new(None, callback_handler); + // FIXME: (serving) + // let callback_handler = config + // .callback_config + // .as_ref() + // .map(|cb_cfg| CallbackHandler::new(config.name.clone(), cb_cfg.callback_concurrency)); + let tracker_handle = TrackerHandle::new(None, None); let (transformer, transformer_grpc_client) = create_components::create_transformer( config.batch_size, @@ -39,6 +39,7 @@ pub(crate) async fn start_forwarder( .await?; let (source, source_grpc_client) = create_components::create_source( + None, config.batch_size, config.read_timeout, &config.source_config, diff --git a/rust/numaflow-core/src/pipeline.rs b/rust/numaflow-core/src/pipeline.rs index f816b4d5eb..d3a05fef54 100644 --- a/rust/numaflow-core/src/pipeline.rs +++ b/rust/numaflow-core/src/pipeline.rs @@ -152,9 +152,19 @@ async fn start_source_forwarder( source_config: SourceVtxConfig, source_watermark_handle: Option, ) -> Result<()> { - let serving_callback_handler = config.callback_config.as_ref().map(|cb_cfg| { - CallbackHandler::new(config.vertex_name.to_string(), cb_cfg.callback_concurrency) - }); + let serving_callback_handler = if let Some(cb_cfg) = &config.callback_config { + Some( + CallbackHandler::new( + config.vertex_name.to_string(), + js_context.clone(), + cb_cfg.callback_store, + cb_cfg.callback_concurrency, + ) + .await, + ) + } else { + None + }; let tracker_handle = TrackerHandle::new(None, serving_callback_handler); let buffer_writer = create_buffer_writer( @@ -175,6 +185,7 @@ async fn start_source_forwarder( .await?; let (source, source_grpc_client) = create_components::create_source( + Some(js_context.clone()), config.batch_size, config.read_timeout, &source_config.source_config, @@ -227,9 +238,19 @@ async fn start_map_forwarder( let mut mapper_grpc_client = None; let mut isb_lag_readers = vec![]; - let serving_callback_handler = config.callback_config.as_ref().map(|cb_cfg| { - CallbackHandler::new(config.vertex_name.to_string(), cb_cfg.callback_concurrency) - }); + let serving_callback_handler = if let Some(cb_cfg) = &config.callback_config { + Some( + CallbackHandler::new( + config.vertex_name.to_string(), + js_context.clone(), + cb_cfg.callback_store, + cb_cfg.callback_concurrency, + ) + .await, + ) + } else { + None + }; // create tracker and buffer writer, they can be shared across all forwarders let tracker_handle = @@ -326,9 +347,19 @@ async fn start_sink_forwarder( .ok_or_else(|| error::Error::Config("No from vertex config found".to_string()))? .reader_config; - let serving_callback_handler = config.callback_config.as_ref().map(|cb_cfg| { - CallbackHandler::new(config.vertex_name.to_string(), cb_cfg.callback_concurrency) - }); + let serving_callback_handler = if let Some(cb_cfg) = &config.callback_config { + Some( + CallbackHandler::new( + config.vertex_name.to_string(), + js_context.clone(), + cb_cfg.callback_store, + cb_cfg.callback_concurrency, + ) + .await, + ) + } else { + None + }; // Create sink writers and buffer readers for each stream let mut sink_writers = vec![]; diff --git a/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs index 027129cdf8..0cfadd0a23 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/map_forwarder.rs @@ -1,10 +1,11 @@ +use tokio_util::sync::CancellationToken; +use tracing::error; + use crate::error::Error; use crate::mapper::map::MapHandle; use crate::pipeline::isb::jetstream::reader::JetStreamReader; use crate::pipeline::isb::jetstream::writer::JetstreamWriter; use crate::Result; -use tokio_util::sync::CancellationToken; -use tracing::error; /// Map forwarder is a component which starts a streaming reader, a mapper, and a writer /// and manages the lifecycle of these components. diff --git a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs index e14ff626c4..e496c54a42 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs @@ -1,9 +1,10 @@ +use tokio_util::sync::CancellationToken; +use tracing::error; + use crate::error::Error; use crate::pipeline::isb::jetstream::reader::JetStreamReader; use crate::sink::SinkWriter; use crate::Result; -use tokio_util::sync::CancellationToken; -use tracing::error; /// Sink forwarder is a component which starts a streaming reader and a sink writer /// and manages the lifecycle of these components. diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs index 5b18358dbd..d7257f99b2 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs @@ -1,19 +1,7 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; -use tracing::warn; -use crate::config::get_vertex_name; -use crate::config::pipeline::isb::{BufferReaderConfig, Stream}; -use crate::error::Error; -use crate::message::{IntOffset, Message, MessageID, MessageType, Metadata, Offset, ReadAck}; -use crate::metrics::{ - pipeline_forward_metric_labels, pipeline_isb_metric_labels, pipeline_metrics, -}; -use crate::shared::grpc::utc_from_timestamp; -use crate::tracker::TrackerHandle; -use crate::watermark::isb::ISBWatermarkHandle; -use crate::{metrics, Result}; use async_nats::jetstream::{ consumer::PullConsumer, AckKind, Context, Message as JetstreamMessage, }; @@ -26,8 +14,21 @@ use tokio::time::{self, Instant}; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; +use tracing::warn; use tracing::{error, info}; +use crate::config::get_vertex_name; +use crate::config::pipeline::isb::{BufferReaderConfig, Stream}; +use crate::error::Error; +use crate::message::{IntOffset, Message, MessageID, MessageType, Metadata, Offset, ReadAck}; +use crate::metrics::{ + pipeline_forward_metric_labels, pipeline_isb_metric_labels, pipeline_metrics, +}; +use crate::shared::grpc::utc_from_timestamp; +use crate::tracker::TrackerHandle; +use crate::watermark::isb::ISBWatermarkHandle; +use crate::{metrics, Result}; + const ACK_RETRY_INTERVAL: u64 = 100; const ACK_RETRY_ATTEMPTS: usize = usize::MAX; const MAX_ACK_PENDING: usize = 25000; diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index ff82256dde..9525cf6d45 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::time::Duration; +use async_nats::jetstream::Context; use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; @@ -268,6 +269,7 @@ pub(crate) async fn create_mapper( /// Creates a source type based on the configuration pub async fn create_source( + js_context: Option, batch_size: usize, read_timeout: Duration, source_config: &SourceConfig, @@ -355,9 +357,14 @@ pub async fn create_source( // for serving we use batch size as 1 as we are not batching the messages // and read ahead is enabled as it supports it. SourceType::Serving(config) => { - let serving = - ServingSource::new(Arc::clone(config), 1, read_timeout, *get_vertex_replica()) - .await?; + let serving = ServingSource::new( + js_context.expect("Jetstream context is required for serving source"), + Arc::clone(config), + 1, + read_timeout, + *get_vertex_replica(), + ) + .await?; Ok(( Source::new( 1, diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 130300e4ae..12ba7f6719 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -5,7 +5,6 @@ //! [Watermark]: https://numaflow.numaproj.io/core-concepts/watermarks/ use std::sync::Arc; -use tracing::warn; use numaflow_pulsar::source::PulsarSource; use numaflow_sqs::source::SQSSource; @@ -16,6 +15,7 @@ use tokio::task::JoinHandle; use tokio::time::Instant; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; +use tracing::warn; use tracing::{error, info}; use crate::config::{get_vertex_name, is_mono_vertex}; diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index d83f0a0cc6..dcdbef95e9 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -87,6 +87,7 @@ impl super::LagReader for ServingSource { mod tests { use std::{collections::HashMap, sync::Arc, time::Duration}; + use async_nats::jetstream; use bytes::Bytes; use serving::{ServingSource, Settings}; @@ -139,19 +140,24 @@ mod tests { } } - #[cfg(feature = "redis-tests")] + #[cfg(all(feature = "redis-tests", feature = "nats-tests"))] #[tokio::test] async fn test_serving_source_reader_acker() -> Result<()> { let settings = Settings { app_listen_port: 2000, ..Default::default() }; + + let client = async_nats::connect("localhost:4222").await.unwrap(); + let js_context = jetstream::new(client); + let settings = Arc::new(settings); // Set up the CryptoProvider (controls core cryptography used by rustls) for the process // ServingSource starts an Axum HTTPS server in the background. Rustls is used to generate // self-signed certs when starting the server. let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let mut serving_source = ServingSource::new( + js_context, Arc::clone(&settings), 10, Duration::from_millis(1), diff --git a/rust/numaflow-core/src/tracker.rs b/rust/numaflow-core/src/tracker.rs index 7b83221193..c23c8298f6 100644 --- a/rust/numaflow-core/src/tracker.rs +++ b/rust/numaflow-core/src/tracker.rs @@ -14,16 +14,17 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::error::Error; -use crate::message::{Message, Offset, ReadAck}; -use crate::watermark::isb::ISBWatermarkHandle; -use crate::Result; use chrono::{DateTime, Utc}; use serving::callback::CallbackHandler; -use serving::{DEFAULT_CALLBACK_URL_HEADER_KEY, DEFAULT_ID_HEADER}; +use serving::DEFAULT_ID_HEADER; use tokio::sync::{mpsc, oneshot}; use tracing::error; +use crate::error::Error; +use crate::message::{Message, Offset, ReadAck}; +use crate::watermark::isb::ISBWatermarkHandle; +use crate::Result; + /// TrackerEntry represents the state of a tracked message. #[derive(Debug)] struct TrackerEntry { @@ -82,7 +83,6 @@ struct Tracker { #[derive(Debug)] struct ServingCallbackInfo { id: String, - callback_url: String, from_vertex: String, /// at the moment these are just tags. responses: Vec>>, @@ -92,15 +92,6 @@ impl TryFrom<&Message> for ServingCallbackInfo { type Error = Error; fn try_from(message: &Message) -> std::result::Result { - let callback_url = message - .headers - .get(DEFAULT_CALLBACK_URL_HEADER_KEY) - .ok_or_else(|| { - Error::Source(format!( - "{DEFAULT_CALLBACK_URL_HEADER_KEY} header is not present in the message headers", - )) - })? - .to_owned(); let uuid = message .headers .get(DEFAULT_ID_HEADER) @@ -120,7 +111,6 @@ impl TryFrom<&Message> for ServingCallbackInfo { Ok(ServingCallbackInfo { id: uuid, - callback_url, from_vertex, responses: vec![None], }) @@ -319,7 +309,6 @@ impl Tracker { let result = callback_handler .callback( callback_info.id, - callback_info.callback_url, callback_info.from_vertex, callback_info.responses, ) @@ -464,11 +453,10 @@ impl TrackerHandle { #[cfg(test)] mod tests { - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use axum::routing::{get, post}; - use axum::{http::StatusCode, Router}; + use async_nats::jetstream; + use async_nats::jetstream::kv::Config; use bytes::Bytes; use tokio::sync::oneshot; use tokio::time::{timeout, Duration}; @@ -502,13 +490,10 @@ mod tests { assert!(callback_info.is_err()); const CALLBACK_URL: &str = "https://localhost/v1/process/callback"; - let headers = [ - (DEFAULT_CALLBACK_URL_HEADER_KEY, CALLBACK_URL), - (DEFAULT_ID_HEADER, "1234"), - ] - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); + let headers = [(DEFAULT_ID_HEADER, "1234")] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); message.headers = headers; const FROM_VERTEX_NAME: &str = "source-vertex"; @@ -518,7 +503,6 @@ mod tests { let callback_info: ServingCallbackInfo = TryFrom::try_from(&message).unwrap(); assert_eq!(callback_info.id, "1234"); - assert_eq!(callback_info.callback_url, CALLBACK_URL); assert_eq!(callback_info.from_vertex, FROM_VERTEX_NAME); assert_eq!(callback_info.responses, vec![None]); } @@ -688,66 +672,29 @@ mod tests { assert!(handle.is_empty().await.unwrap(), "Tracker should be empty"); } + #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_tracker_with_callback_handler() -> Result<()> { - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let port = listener - .local_addr() - .map_err(|e| format!("Failed to bind to 127.0.0.1:0: error={e:?}"))? - .port(); - - let server_addr = format!("127.0.0.1:{port}"); - let callback_url = format!("http://{server_addr}/v1/process/callback"); - - let request_count = Arc::new(AtomicUsize::new(0)); - let router = Router::new() - .route("/livez", get(|| async { StatusCode::OK })) - .route( - "/v1/process/callback", - post({ - let req_count = Arc::clone(&request_count); - || async move { - req_count.fetch_add(1, Ordering::Relaxed); - StatusCode::OK - } - }), - ); - - let server = tokio::spawn(async move { - axum::serve(listener, router).await.unwrap(); - }); + let store_name = "test_tracker_with_callback_handler"; + let client = async_nats::connect("localhost:4222").await.unwrap(); + let js_context = jetstream::new(client); + let callback_bucket = js_context + .create_key_value(Config { + bucket: store_name.to_string(), + history: 1, + ..Default::default() + }) + .await + .unwrap(); - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .build()?; - - // Wait for the server to be ready - let mut server_ready = false; - let health_url = format!("http://{server_addr}/livez"); - for _ in 0..10 { - let Ok(resp) = client.get(&health_url).send().await else { - tokio::time::sleep(Duration::from_millis(5)).await; - continue; - }; - if resp.status().is_success() { - server_ready = true; - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - assert!(server_ready, "Server is not ready"); + let callback_handler = + CallbackHandler::new("test".into(), js_context.clone(), store_name, 10).await; - let callback_handler = CallbackHandler::new("test".into(), 10); let handle = TrackerHandle::new(None, Some(callback_handler)); let (ack_send, ack_recv) = oneshot::channel(); - let headers = [ - (DEFAULT_CALLBACK_URL_HEADER_KEY, callback_url), - (DEFAULT_ID_HEADER, "1234".into()), - ] - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); + let mut headers = HashMap::new(); + headers.insert(DEFAULT_ID_HEADER.to_string(), "1234".to_string()); let offset = Offset::String(StringOffset::new("offset1".to_string(), 0)); let message = Message { @@ -779,17 +726,16 @@ mod tests { assert_eq!(result.unwrap(), ReadAck::Ack); assert!(handle.is_empty().await.unwrap(), "Tracker should be empty"); - // Callback request is made after sending data on ack_send channel. - let mut received_callback_request = false; - for _ in 0..5 { - tokio::time::sleep(Duration::from_millis(10)).await; - received_callback_request = request_count.load(Ordering::Relaxed) == 1; - if received_callback_request { - break; - } - } - assert!(received_callback_request, "Expected one callback request"); - server.abort(); + // Verify that the callback was written to the KV store + let callback_key = "1234"; + let callback_value = callback_bucket.get(callback_key).await.unwrap(); + assert!( + callback_value.is_some(), + "Callback should be written to the KV store" + ); + + // Clean up the KV store + js_context.delete_key_value(store_name).await.unwrap(); Ok(()) } } diff --git a/rust/numaflow-core/src/watermark/processor/manager.rs b/rust/numaflow-core/src/watermark/processor/manager.rs index 094416c7b0..b7140232f0 100644 --- a/rust/numaflow-core/src/watermark/processor/manager.rs +++ b/rust/numaflow-core/src/watermark/processor/manager.rs @@ -335,7 +335,6 @@ impl ProcessorManager { processors: Arc>>, ) { let mut hb_watcher = Self::create_watcher(hb_bucket.clone()).await; - loop { let Some(val) = hb_watcher.next().await else { warn!("HB watcher stopped, recreating watcher"); @@ -416,7 +415,6 @@ impl ProcessorManager { /// creates a watcher for the given bucket, will retry infinitely until it succeeds async fn create_watcher(bucket: async_nats::jetstream::kv::Store) -> Watch { const RECONNECT_INTERVAL: u64 = 1000; - // infinite retry let interval = fixed::Interval::from_millis(RECONNECT_INTERVAL).take(usize::MAX); diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 9e70b8bad3..fb487c656c 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -25,7 +25,7 @@ axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } serde_json = "1.0.120" tower-http = { version = "0.5.2", features = ["trace", "timeout"] } -uuid = { version = "1.10.0", features = ["std","v7"] } +uuid = { version = "1.10.0", features = ["std", "v7"] } redis = { version = "0.26.0", features = [ "tokio-comp", "aio", @@ -40,6 +40,8 @@ prometheus-client = "0.22.3" thiserror = "1.0.63" reqwest = { workspace = true, features = ["rustls-tls", "json"] } http = "1.2.0" +async-nats = "0.39.0" +tokio-stream = "0.1.17" [dev-dependencies] reqwest = { workspace = true, features = ["json"] } diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index 9e904ba949..7aec44431e 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -19,10 +19,11 @@ use tracing::{info, info_span, Span}; use uuid::Uuid; use self::{ - callback::callback_handler, direct_proxy::direct_proxy, jetstream_proxy::jetstream_proxy, - message_path::get_message_path, + direct_proxy::direct_proxy, jetstream_proxy::jetstream_proxy, message_path::get_message_path, }; -use crate::app::callback::store::Store; +use crate::app::callback::callback_handler; +use crate::app::callback::cbstore::CallbackStore; +use crate::app::callback::datumstore::DatumStore; use crate::metrics::capture_metrics; use crate::AppState; use crate::Error::InitError; @@ -44,12 +45,13 @@ pub(crate) mod tracker; // - [ ] outer fallback for /v1/direct /// Start the main application Router and the axum server. -pub(crate) async fn start_main_server( - app: AppState, +pub(crate) async fn start_main_server( + app: AppState, tls_config: RustlsConfig, ) -> crate::Result<()> where - T: Clone + Send + Sync + Store + 'static, + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, { let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) .parse() @@ -72,9 +74,10 @@ where Ok(()) } -pub(crate) async fn router_with_auth(app: AppState) -> crate::Result +pub(crate) async fn router_with_auth(app: AppState) -> crate::Result where - T: Clone + Send + Sync + Store + 'static, + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, { let layers = ServiceBuilder::new() // Add tracing to all requests @@ -208,8 +211,11 @@ async fn auth_middleware( } } -async fn setup_app( - app: AppState, +async fn setup_app< + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, +>( + app: AppState, ) -> crate::Result { let parent = Router::new() .route("/health", get(health_check)) @@ -241,8 +247,11 @@ async fn livez() -> impl IntoResponse { StatusCode::NO_CONTENT } -async fn readyz( - State(app): State>, +async fn readyz< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + State(app): State>, ) -> impl IntoResponse { if app.callback_state.clone().ready().await { StatusCode::NO_CONTENT @@ -251,8 +260,11 @@ async fn readyz( } } -async fn routes( - app_state: AppState, +async fn routes< + T: Clone + Send + Sync + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + app_state: AppState, ) -> crate::Result { let state = app_state.callback_state.clone(); let jetstream_proxy = jetstream_proxy(app_state.clone()).await?; @@ -266,109 +278,109 @@ async fn routes( .merge(message_path_handler)) } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use axum::http::StatusCode; - use callback::state::State as CallbackState; - use tokio::sync::mpsc; - use tower::ServiceExt; - use tracker::MessageGraph; - - use super::*; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::Settings; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - type Result = core::result::Result; - type Error = Box; - - #[tokio::test] - async fn test_setup_app() -> Result<()> { - let settings = Arc::new(Settings::default()); - - let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - let (tx, _) = mpsc::channel(10); - let app = AppState { - message: tx, - settings, - callback_state, - }; - - let result = setup_app(app).await; - assert!(result.is_ok()); - Ok(()) - } - - #[tokio::test] - async fn test_health_check_endpoints() -> Result<()> { - let settings = Arc::new(Settings::default()); - - let mem_store = InMemoryStore::new(); - let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec)?; - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - - let (messages_tx, _messages_rx) = mpsc::channel(10); - let app = AppState { - message: messages_tx, - settings, - callback_state, - }; - - let router = setup_app(app).await.unwrap(); - - let request = Request::builder().uri("/livez").body(Body::empty())?; - let response = router.clone().oneshot(request).await?; - assert_eq!(response.status(), StatusCode::NO_CONTENT); - - let request = Request::builder().uri("/readyz").body(Body::empty())?; - let response = router.clone().oneshot(request).await?; - assert_eq!(response.status(), StatusCode::NO_CONTENT); - - let request = Request::builder().uri("/health").body(Body::empty())?; - let response = router.clone().oneshot(request).await?; - assert_eq!(response.status(), StatusCode::OK); - Ok(()) - } - - #[tokio::test] - async fn test_auth_middleware() -> Result<()> { - let settings = Settings { - api_auth_token: Some("test-token".into()), - ..Default::default() - }; - - let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - - let (messages_tx, _messages_rx) = mpsc::channel(10); - - let app_state = AppState { - message: messages_tx, - settings: Arc::new(settings), - callback_state, - }; - - let router = router_with_auth(app_state).await.unwrap(); - let res = router - .oneshot( - axum::extract::Request::builder() - .method("POST") - .uri("/v1/process/sync") - .body(Body::empty()) - .unwrap(), - ) - .await?; - - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); - Ok(()) - } -} +// #[cfg(test)] +// mod tests { +// use std::sync::Arc; +// +// use axum::http::StatusCode; +// use callback::state::State as CallbackState; +// use tokio::sync::mpsc; +// use tower::ServiceExt; +// use tracker::MessageGraph; +// +// use super::*; +// use crate::app::callback::datumstore::memstore::InMemoryStore; +// use crate::Settings; +// +// const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; +// +// type Result = core::result::Result; +// type Error = Box; +// +// #[tokio::test] +// async fn test_setup_app() -> Result<()> { +// let settings = Arc::new(Settings::default()); +// +// let mem_store = InMemoryStore::new(); +// let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; +// +// let callback_state = CallbackState::new(msg_graph, mem_store).await?; +// let (tx, _) = mpsc::channel(10); +// let app = AppState { +// message: tx, +// settings, +// callback_state, +// }; +// +// let result = setup_app(app).await; +// assert!(result.is_ok()); +// Ok(()) +// } +// +// #[tokio::test] +// async fn test_health_check_endpoints() -> Result<()> { +// let settings = Arc::new(Settings::default()); +// +// let mem_store = InMemoryStore::new(); +// let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec)?; +// let callback_state = CallbackState::new(msg_graph, mem_store).await?; +// +// let (messages_tx, _messages_rx) = mpsc::channel(10); +// let app = AppState { +// message: messages_tx, +// settings, +// callback_state, +// }; +// +// let router = setup_app(app).await.unwrap(); +// +// let request = Request::builder().uri("/livez").body(Body::empty())?; +// let response = router.clone().oneshot(request).await?; +// assert_eq!(response.status(), StatusCode::NO_CONTENT); +// +// let request = Request::builder().uri("/readyz").body(Body::empty())?; +// let response = router.clone().oneshot(request).await?; +// assert_eq!(response.status(), StatusCode::NO_CONTENT); +// +// let request = Request::builder().uri("/health").body(Body::empty())?; +// let response = router.clone().oneshot(request).await?; +// assert_eq!(response.status(), StatusCode::OK); +// Ok(()) +// } +// +// #[tokio::test] +// async fn test_auth_middleware() -> Result<()> { +// let settings = Settings { +// api_auth_token: Some("test-token".into()), +// ..Default::default() +// }; +// +// let mem_store = InMemoryStore::new(); +// let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; +// let callback_state = CallbackState::new(msg_graph, mem_store).await?; +// +// let (messages_tx, _messages_rx) = mpsc::channel(10); +// +// let app_state = AppState { +// message: messages_tx, +// settings: Arc::new(settings), +// callback_state, +// }; +// +// let router = router_with_auth(app_state).await.unwrap(); +// let res = router +// .oneshot( +// axum::extract::Request::builder() +// .method("POST") +// .uri("/v1/process/sync") +// .body(Body::empty()) +// .unwrap(), +// ) +// .await?; +// +// assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +// Ok(()) +// } +// } diff --git a/rust/serving/src/app/callback.rs b/rust/serving/src/app/callback.rs index f8a216437f..badb29d91c 100644 --- a/rust/serving/src/app/callback.rs +++ b/rust/serving/src/app/callback.rs @@ -1,39 +1,51 @@ -use axum::{body::Bytes, extract::State, http::HeaderMap, routing, Json, Router}; -use tracing::{error, info}; +use axum::extract::State; +use axum::routing; +use axum::Router; +use bytes::Bytes; +use http::HeaderMap; +use state::State as CallbackState; +use tracing::error; -use self::store::Store; +use crate::app::callback::cbstore::CallbackStore; +use crate::app::callback::datumstore::DatumStore; use crate::app::response::ApiError; use crate::callback::Callback; /// in-memory state store including connection tracking pub(crate) mod state; -use state::State as CallbackState; -/// store for storing the state -pub(crate) mod store; +pub(crate) mod cbstore; +/// datumstore for storing the state +pub(crate) mod datumstore; #[derive(Clone)] -struct CallbackAppState { +struct CallbackAppState { tid_header: String, - callback_state: CallbackState, + callback_state: CallbackState, } -pub fn callback_handler( +pub fn callback_handler< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( tid_header: String, - callback_state: CallbackState, + callback_state: CallbackState, ) -> Router { let app_state = CallbackAppState { tid_header, callback_state, }; Router::new() - .route("/callback", routing::post(callback)) + // .route("/callback", routing::post(callback)) .route("/callback_save", routing::post(callback_save)) .with_state(app_state) } -async fn callback_save( - State(app_state): State>, +async fn callback_save< + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, +>( + State(app_state): State>, headers: HeaderMap, body: Bytes, ) -> Result<(), ApiError> { @@ -56,157 +68,21 @@ async fn callback_save( Ok(()) } - -async fn callback( - State(app_state): State>, - Json(payload): Json>, -) -> Result<(), ApiError> { - info!(?payload, "Received callback request"); - app_state - .callback_state - .clone() - .insert_callback_requests(payload) - .await - .map_err(|e| { - error!(error=?e, "Inserting callback requests"); - ApiError::InternalServerError("Failed to insert callback requests".to_string()) - })?; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use axum::body::Body; - use axum::extract::Request; - use axum::http::header::CONTENT_TYPE; - use axum::http::StatusCode; - use tower::ServiceExt; - - use super::*; - use crate::app::callback::state::State as CallbackState; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::tracker::MessageGraph; - use crate::callback::Response; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_callback_failure() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![Callback { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - // id is not registered, so it should return 500 - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[tokio::test] - async fn test_callback_success() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let mut state = CallbackState::new(msg_graph, store).await.unwrap(); - - let (_, x) = state.register(Some("test_id".to_string())).await.unwrap(); - // spawn a task which will be awaited later - let handle = tokio::spawn(async move { - let _ = x.await.unwrap(); - }); - - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![ - Callback { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: "test_id".to_string(), - vertex: "cat".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: "test_id".to_string(), - vertex: "out".to_string(), - cb_time: 12345, - from_vertex: "cat".to_string(), - responses: vec![Response { tags: None }], - }, - ]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_callback_save() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .header(CONTENT_TYPE, "application/json") - .header("id", "test_id") - .body(Body::from("test_body")) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[tokio::test] - async fn test_without_id_header() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .body(Body::from("test_body")) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } -} +// +// async fn callback( +// State(app_state): State>, +// Json(payload): Json>, +// ) -> Result<(), ApiError> { +// info!(?payload, "Received callback request"); +// app_state +// .callback_state +// .clone() +// .insert_callback_requests(payload) +// .await +// .map_err(|e| { +// error!(error=?e, "Inserting callback requests"); +// ApiError::InternalServerError("Failed to insert callback requests".to_string()) +// })?; +// +// Ok(()) +// } diff --git a/rust/serving/src/app/callback/cbstore.rs b/rust/serving/src/app/callback/cbstore.rs new file mode 100644 index 0000000000..953a05c231 --- /dev/null +++ b/rust/serving/src/app/callback/cbstore.rs @@ -0,0 +1,74 @@ +use std::sync::Arc; + +use bytes::Bytes; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; + +use crate::app::callback::datumstore::Result as StoreResult; +use crate::callback::Callback; + +pub(crate) mod jetstreamstore; +pub(crate) mod memstore; + +#[derive(Debug, PartialEq)] +/// Represents the current processing status of a request id in the `Store`. +pub(crate) enum ProcessingStatus { + InProgress, + Completed(String), // Store subgraph string + Failed(String), // Store error string +} + +impl From for ProcessingStatus { + fn from(value: Bytes) -> Self { + let in_progress = Bytes::from_static(b"inprogress"); + let completed_prefix = b"completed:"; + let failed_prefix = b"failed:"; + + if value == in_progress { + ProcessingStatus::InProgress + } else if value.starts_with(completed_prefix) { + let subgraph = String::from_utf8(value[completed_prefix.len()..].to_vec()) + .expect("Invalid UTF-8 sequence"); + ProcessingStatus::Completed(subgraph) + } else if value.starts_with(failed_prefix) { + let error = String::from_utf8(value[failed_prefix.len()..].to_vec()) + .expect("Invalid UTF-8 sequence"); + ProcessingStatus::Failed(error) + } else { + panic!("Invalid processing status") + } + } +} + +impl From for Bytes { + fn from(value: ProcessingStatus) -> Self { + match value { + ProcessingStatus::InProgress => Bytes::from_static(b"inprogress"), + ProcessingStatus::Completed(subgraph) => Bytes::from(format!("completed:{}", subgraph)), + ProcessingStatus::Failed(error) => Bytes::from(format!("failed:{}", error)), + } + } +} + +/// Store trait to store the callback information. +#[trait_variant::make(CallbackStore: Send)] +#[allow(dead_code)] +pub(crate) trait LocalCallbackStore { + /// Register a request id in the store. If user provides a request id, the same should be returned + /// if it doesn't already exist in the store. An error should be returned if the user-specified request id + /// already exists in the store. If the `id` is `None`, the store should generate a new unique request id. + async fn register(&mut self, id: &str) -> StoreResult<()>; + /// This method will be called when processing is completed for a request id. + async fn deregister(&mut self, id: &str, sub_graph: &str) -> StoreResult<()>; + /// This method will be called when processing is failed for a request id. + async fn mark_as_failed(&mut self, id: &str, error: &str) -> StoreResult<()>; + /// retrieve the callback payloads for a given request id. + async fn watch_callbacks( + &mut self, + id: &str, + ) -> StoreResult<(ReceiverStream>, JoinHandle<()>)>; + /// retrieve the processing status of a request id. + async fn status(&mut self, id: &str) -> StoreResult; + /// check if the store is ready + async fn ready(&mut self) -> bool; +} diff --git a/rust/serving/src/app/callback/cbstore/jetstreamstore.rs b/rust/serving/src/app/callback/cbstore/jetstreamstore.rs new file mode 100644 index 0000000000..ceccdb8925 --- /dev/null +++ b/rust/serving/src/app/callback/cbstore/jetstreamstore.rs @@ -0,0 +1,234 @@ +use std::sync::Arc; + +use async_nats::jetstream::kv::Store; +use async_nats::jetstream::Context; +use bytes::Bytes; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; + +use crate::app::callback::cbstore::ProcessingStatus; +use crate::app::callback::datumstore::{Error as StoreError, Result as StoreResult}; +use crate::app::callback::Callback; + +#[derive(Clone)] +pub(crate) struct JSCallbackStore { + kv_store: Store, +} + +impl JSCallbackStore { + pub(crate) async fn new(js_context: Context, bucket_name: &str) -> StoreResult { + let kv_store = js_context + .get_key_value(bucket_name) + .await + .map_err(|e| StoreError::Connection(format!("Failed to get kv store: {e:?}")))?; + Ok(Self { kv_store }) + } +} + +impl super::CallbackStore for JSCallbackStore { + async fn register(&mut self, id: &str) -> StoreResult<()> { + let key = format!("{}=status", id); + println!("Registering key: {}", key); + let exists = self.kv_store.get(&key).await.map_err(|e| { + StoreError::StoreRead(format!("Failed to get request id in kv store: {e:?}")) + })?; + if exists.is_some() { + return Err(StoreError::DuplicateRequest(id.to_string())); + } + + self.kv_store + .put(&key, ProcessingStatus::InProgress.into()) + .await + .map_err(|e| { + StoreError::StoreWrite(format!( + "Failed to register request id {key} in kv store: {e:?}" + )) + })?; + + self.kv_store + .put(id, Bytes::from_static(b"")) + .await + .map_err(|e| { + StoreError::StoreWrite(format!( + "Failed to register request id {id} in kv store: {e:?}" + )) + })?; + + Ok(()) + } + + async fn deregister(&mut self, id: &str, sub_graph: &str) -> StoreResult<()> { + let key = format!("{}=status", id); + let completed_value = format!("completed:{}", sub_graph); + self.kv_store + .put( + key, + ProcessingStatus::Completed(completed_value.to_string()).into(), + ) + .await + .map_err(|e| { + StoreError::StoreWrite(format!("Failed to mark request as done in kv store: {e:?}")) + })?; + Ok(()) + } + + async fn mark_as_failed(&mut self, id: &str, error: &str) -> StoreResult<()> { + let key = format!("{}=status", id); + let failed_value = format!("failed:{}", error); + self.kv_store + .put( + key, + ProcessingStatus::Failed(failed_value.to_string()).into(), + ) + .await + .map_err(|e| { + StoreError::StoreWrite(format!( + "Failed to mark request as failed in kv store: {e:?}" + )) + })?; + Ok(()) + } + + async fn watch_callbacks( + &mut self, + id: &str, + ) -> StoreResult<(ReceiverStream>, JoinHandle<()>)> { + let mut watcher = self.kv_store.watch_with_history(id).await.map_err(|e| { + StoreError::StoreRead(format!("Failed to watch request id in kv store: {e:?}")) + })?; + let (tx, rx) = tokio::sync::mpsc::channel(10); + + let handle = tokio::spawn(async move { + // TODO: handle watch errors + while let Some(Ok(entry)) = watcher.next().await { + // all callbacks received + if entry.operation == async_nats::jetstream::kv::Operation::Delete { + break; + } + + if entry.value.is_empty() { + continue; + } + + let cbr: Callback = serde_json::from_slice(entry.value.as_ref()) + .map_err(|e| { + StoreError::StoreRead(format!("Parsing payload from bytes - {}", e)) + }) + .expect("Failed to parse callback from bytes"); + tx.send(Arc::new(cbr)) + .await + .expect("Failed to send callback"); + } + }); + + Ok((ReceiverStream::new(rx), handle)) + } + + async fn status(&mut self, id: &str) -> StoreResult { + let key = format!("{}=status", id); + let status = self.kv_store.get(&key).await.map_err(|e| { + StoreError::StoreRead(format!("Failed to get status for request id: {e:?}")) + })?; + let Some(status) = status else { + return Err(StoreError::InvalidRequestId(id.to_string())); + }; + Ok(status.into()) + } + + async fn ready(&mut self) -> bool { + // Implement a health check for the JetStream connection if possible + true + } +} + +#[cfg(test)] +mod tests { + use async_nats::jetstream; + use async_nats::jetstream::kv::Config; + use bytes::Bytes; + + use super::*; + use crate::app::callback::cbstore::CallbackStore; + + #[tokio::test] + async fn test_register() { + let js_url = "localhost:4222"; + let client = async_nats::connect(js_url).await.unwrap(); + let context = jetstream::new(client); + let serving_store = "test_serving_store"; + + context + .create_key_value(Config { + bucket: serving_store.to_string(), + history: 1, + ..Default::default() + }) + .await + .unwrap(); + + let mut store = JSCallbackStore::new(context.clone(), serving_store) + .await + .unwrap(); + + let id = "AFA7E0A1-3F0A-4C1B-AB94-BDA57694648D"; + let result = store.register(id).await; + assert!(result.is_ok()); + + // delete store + context.delete_key_value(serving_store).await.unwrap(); + } + + #[tokio::test] + async fn test_watch_callbacks() { + let js_url = "localhost:4222"; + let client = async_nats::connect(js_url).await.unwrap(); + let context = jetstream::new(client); + let serving_store = "test_watch_callbacks"; + + context + .create_key_value(Config { + bucket: serving_store.to_string(), + history: 1, + ..Default::default() + }) + .await + .unwrap(); + + let mut store = JSCallbackStore::new(context.clone(), serving_store) + .await + .unwrap(); + + let id = "test_watch_id_two"; + store.register(id).await.unwrap(); + + let (mut rx, handle) = store.watch_callbacks(id).await.unwrap(); + + // Simulate a callback being added to the store + let callback = Callback { + id: id.to_string(), + vertex: "test_vertex".to_string(), + cb_time: 12345, + from_vertex: "test_from_vertex".to_string(), + responses: vec![], + }; + store + .kv_store + .put(id, Bytes::from(serde_json::to_vec(&callback).unwrap())) + .await + .unwrap(); + + // Verify that the callback is received + let received_callback = rx.next().await.unwrap(); + assert_eq!(received_callback.id, callback.id); + assert_eq!(received_callback.vertex, callback.vertex); + assert_eq!(received_callback.cb_time, callback.cb_time); + assert_eq!(received_callback.from_vertex, callback.from_vertex); + assert_eq!(received_callback.responses.len(), callback.responses.len()); + + handle.abort(); + + // delete store + context.delete_key_value(serving_store).await.unwrap(); + } +} diff --git a/rust/serving/src/app/callback/cbstore/memstore.rs b/rust/serving/src/app/callback/cbstore/memstore.rs new file mode 100644 index 0000000000..53b8049354 --- /dev/null +++ b/rust/serving/src/app/callback/cbstore/memstore.rs @@ -0,0 +1,73 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::mpsc; +use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; + +use crate::app::callback::cbstore::LocalCallbackStore; +use crate::app::callback::datumstore::{Error as StoreError, Result as StoreResult}; +use crate::app::callback::Callback; + +#[derive(Clone)] +pub(crate) struct InMemoryCallbackStore { + data: Arc>>>>, +} + +impl InMemoryCallbackStore { + pub(crate) fn new() -> Self { + Self { + data: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl LocalCallbackStore for InMemoryCallbackStore { + async fn register(&mut self, id: &str) -> StoreResult<()> { + let mut data = self.data.lock().await; + if data.contains_key(id) { + return Err(StoreError::DuplicateRequest(id.to_string())); + } + data.insert(id.to_string(), Vec::new()); + Ok(()) + } + + async fn deregister(&mut self, id: &str, sub_graph: &str) -> StoreResult<()> { + let mut data = self.data.lock().await; + if data.remove(id).is_none() { + return Err(StoreError::InvalidRequestId(id.to_string())); + } + Ok(()) + } + + async fn mark_as_failed(&mut self, _id: &str, _error: &str) -> StoreResult<()> { + Ok(()) + } + + async fn watch_callbacks( + &mut self, + id: &str, + ) -> StoreResult<(ReceiverStream>, JoinHandle<()>)> { + let (tx, rx) = mpsc::channel(10); + let data = self.data.lock().await; + if let Some(callbacks) = data.get(id) { + for callback in callbacks { + tx.send(Arc::clone(callback)) + .await + .map_err(|_| StoreError::StoreRead("Failed to send callback".to_string()))?; + } + } else { + return Err(StoreError::InvalidRequestId(id.to_string())); + } + Ok((ReceiverStream::new(rx), tokio::spawn(async {}))) + } + + async fn ready(&mut self) -> bool { + true + } + + async fn status(&mut self, id: &str) -> StoreResult { + unimplemented!(); + } +} diff --git a/rust/serving/src/app/callback/datumstore.rs b/rust/serving/src/app/callback/datumstore.rs new file mode 100644 index 0000000000..26f4f53153 --- /dev/null +++ b/rust/serving/src/app/callback/datumstore.rs @@ -0,0 +1,43 @@ +use thiserror::Error; + +// redis as the store +pub(crate) mod redisstore; +// in-memory store +pub(crate) mod memstore; + +#[derive(Error, Debug, Clone)] +pub(crate) enum Error { + #[error("Connecting to the store: {0}")] + Connection(String), + + #[error("Request id {0} doesn't exist in store")] + InvalidRequestId(String), + + #[error("Request id {0} already exists in the store")] + DuplicateRequest(String), + + #[error("Reading from the store: {0}")] + StoreRead(String), + + #[error("Writing payload to the store: {0}")] + StoreWrite(String), +} + +impl From for crate::Error { + fn from(value: Error) -> Self { + crate::Error::Store(value.to_string()) + } +} + +pub(crate) type Result = std::result::Result; + +/// Store trait to datastore the callback information. +#[trait_variant::make(DatumStore: Send)] +#[allow(dead_code)] +pub(crate) trait LocalDatumStore { + async fn save(&mut self, id: &str, payload: bytes::Bytes) -> Result<()>; + /// retrieve the data from the store + async fn retrieve_datum(&mut self, id: &str) -> Result>>>; + /// check if the store is ready + async fn ready(&mut self) -> bool; +} diff --git a/rust/serving/src/app/callback/datumstore/memstore.rs b/rust/serving/src/app/callback/datumstore/memstore.rs new file mode 100644 index 0000000000..570c4753ae --- /dev/null +++ b/rust/serving/src/app/callback/datumstore/memstore.rs @@ -0,0 +1,178 @@ +use std::collections::HashMap; +use std::future::Future; +use std::sync::Arc; + +use bytes::Bytes; + +use crate::app::callback::datumstore::{Error as StoreError, Result as StoreResult}; +const STORE_KEY_SUFFIX: &str = "saved"; + +/// `InMemoryStore` is an in-memory implementation of the `Store` trait. +/// It uses a `HashMap` to store data in memory. +#[derive(Clone)] +pub(crate) struct InMemoryStore { + /// The data field is a `HashMap` where the key is a `String` and the value is a `Vec>`. + /// It is wrapped in an `Arc>` to allow shared ownership and thread safety. + pub(crate) data: Arc>>>>, +} + +impl InMemoryStore { + /// Creates a new `InMemoryStore` with an empty `HashMap`. + #[allow(dead_code)] + pub(crate) fn new() -> Self { + Self { + data: Arc::new(std::sync::Mutex::new(HashMap::new())), + } + } +} + +impl super::DatumStore for InMemoryStore { + async fn save(&mut self, id: &str, payload: Bytes) -> StoreResult<()> { + todo!() + } + + /// Retrieves data for a given id from the `HashMap`. + /// Each piece of data is deserialized from bytes into a `String`. + async fn retrieve_datum(&mut self, id: &str) -> StoreResult>>> { + let id = format!("{id}_{STORE_KEY_SUFFIX}"); + let data = self.data.lock().unwrap(); + match data.get(&id) { + Some(result) => Ok(Some(result.to_vec())), + None => Err(StoreError::InvalidRequestId(format!( + "No entry found for id: {}", + id + ))), + } + } + + async fn ready(&mut self) -> bool { + true + } +} + +// #[cfg(test)] +// mod tests { +// use std::sync::Arc; +// +// use super::*; +// use crate::app::callback::datumstore::DatumStore; +// use crate::callback::{Callback, Response}; +// +// #[tokio::test] +// async fn test_save_and_retrieve_callbacks() { +// let mut store = InMemoryStore::new(); +// let key = "test_key".to_string(); +// let value = Arc::new(Callback { +// id: "test_id".to_string(), +// vertex: "in".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { tags: None }], +// }); +// +// // Save a callback +// store +// .save(vec![PayloadToSave::Callback { +// key: key.clone(), +// value: Arc::clone(&value), +// }]) +// .await +// .unwrap(); +// +// // Retrieve the callback +// let retrieved = store.retrieve_callbacks(&key).await.unwrap(); +// +// // Check that the retrieved callback is the same as the one we saved +// assert_eq!(retrieved.len(), 1); +// assert_eq!(retrieved[0].id, "test_id".to_string()) +// } +// +// #[tokio::test] +// async fn test_save_and_retrieve_datum() { +// let mut store = InMemoryStore::new(); +// let key = "test_key".to_string(); +// let value = "test_value".to_string(); +// +// // Save a datum +// store +// .save(vec![PayloadToSave::DatumFromPipeline { +// key: key.clone(), +// value: value.clone().into(), +// }]) +// .await +// .unwrap(); +// +// // Retrieve the datum +// let retrieved = store.retrieve_datum(&key).await.unwrap(); +// let ProcessingStatus::Completed(retrieved) = retrieved else { +// panic!("Expected pipeline processing to be completed"); +// }; +// +// // Check that the retrieved datum is the same as the one we saved +// assert_eq!(retrieved.len(), 1); +// assert_eq!(retrieved[0], value.as_bytes()); +// } +// +// #[tokio::test] +// async fn test_retrieve_callbacks_no_entry() { +// let mut store = InMemoryStore::new(); +// let key = "nonexistent_key".to_string(); +// +// // Try to retrieve a callback for a key that doesn't exist +// let result = store.retrieve_callbacks(&key).await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// +// #[tokio::test] +// async fn test_retrieve_datum_no_entry() { +// let mut store = InMemoryStore::new(); +// let key = "nonexistent_key".to_string(); +// +// // Try to retrieve a datum for a key that doesn't exist +// let result = store.retrieve_datum(&key).await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// +// #[tokio::test] +// async fn test_save_invalid_callback() { +// let mut store = InMemoryStore::new(); +// let value = Arc::new(Callback { +// id: "test_id".to_string(), +// vertex: "in".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { tags: None }], +// }); +// +// // Try to save a callback with an invalid key +// let result = store +// .save(vec![PayloadToSave::Callback { +// key: "".to_string(), +// value: Arc::clone(&value), +// }]) +// .await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// +// #[tokio::test] +// async fn test_save_invalid_datum() { +// let mut store = InMemoryStore::new(); +// +// // Try to save a datum with an invalid key +// let result = store +// .save(vec![PayloadToSave::DatumFromPipeline { +// key: "".to_string(), +// value: "test_value".into(), +// }]) +// .await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// } diff --git a/rust/serving/src/app/callback/datumstore/redisstore.rs b/rust/serving/src/app/callback/datumstore/redisstore.rs new file mode 100644 index 0000000000..5688894496 --- /dev/null +++ b/rust/serving/src/app/callback/datumstore/redisstore.rs @@ -0,0 +1,274 @@ +use std::sync::Arc; + +use backoff::retry::Retry; +use backoff::strategy::fixed; +use bytes::Bytes; +use redis::aio::ConnectionManager; +use redis::RedisError; +use tokio::sync::Semaphore; +use tracing::info; + +use crate::app::callback::datumstore::{DatumStore, Error as StoreError, Result as StoreResult}; +use crate::config::RedisConfig; + +const LPUSH: &str = "LPUSH"; +const LRANGE: &str = "LRANGE"; +const EXPIRE: &str = "EXPIRE"; + +const STATUS_PROCESSING: &str = "processing"; +const STATUS_COMPLETED: &str = "completed"; + +// Handle to the Redis actor. +#[derive(Clone)] +pub(crate) struct RedisConnection { + conn_manager: ConnectionManager, + config: RedisConfig, +} + +impl RedisConnection { + /// Creates a new RedisConnection with concurrent operations on Redis set by max_tasks. + pub(crate) async fn new(config: RedisConfig) -> crate::Result { + let client = redis::Client::open(config.addr.as_str()) + .map_err(|e| StoreError::Connection(format!("Creating Redis client: {e:?}")))?; + let conn = client + .get_connection_manager() + .await + .map_err(|e| StoreError::Connection(format!("Connecting to Redis server: {e:?}")))?; + Ok(Self { + conn_manager: conn, + config, + }) + } + + async fn execute_redis_cmd( + conn_manager: &mut ConnectionManager, + ttl_secs: Option, + key: &str, + val: &Vec, + ) -> Result<(), RedisError> { + let mut pipe = redis::pipe(); + pipe.cmd(LPUSH).arg(key).arg(val); + + // if the ttl is configured, add the EXPIRE command to the pipeline + if let Some(ttl) = ttl_secs { + pipe.cmd(EXPIRE).arg(key).arg(ttl); + } + + // Execute the pipeline + pipe.exec_async(conn_manager).await + } + + // write to Redis with retries + async fn write_to_redis(&self, key: &str, value: &Vec) -> StoreResult<()> { + let interval = fixed::Interval::from_millis(self.config.retries_duration_millis.into()) + .take(self.config.retries); + + Retry::retry( + interval, + || async { + // https://hackmd.io/@compiler-errors/async-closures + Self::execute_redis_cmd( + &mut self.conn_manager.clone(), + self.config.ttl_secs, + key, + value, + ) + .await + }, + |e: &RedisError| !e.is_unrecoverable_error(), + ) + .await + .map_err(|err| StoreError::StoreWrite(format!("Saving to redis: {}", err).to_string())) + } +} + +async fn handle_write_requests( + redis_conn: RedisConnection, + key: &str, + value: Bytes, +) -> StoreResult<()> { + // Write the byte array to Redis + // we have to differentiate between the saved responses and the callback requests + // saved responses are stored in "id_SAVED", callback requests are stored in "id" + let key = format!("request:{key}:results"); + info!(?key, "Writing to Redis"); + let value: Vec = value.into(); + redis_conn.write_to_redis(&key, &value).await +} + +// It is possible to move the methods defined here to be methods on the Redis actor and communicate through channels. +// With that, all public APIs defined on RedisConnection can be on &self (immutable). +impl DatumStore for RedisConnection { + async fn retrieve_datum(&mut self, id: &str) -> StoreResult>>> { + let key = format!("request:{id}:results"); + info!(?key, "Reading from Redis"); + let result: Result>, RedisError> = redis::cmd(LRANGE) + .arg(key) + .arg(0) + .arg(-1) + .query_async(&mut self.conn_manager) + .await; + + match result { + Ok(result) => { + if result.is_empty() { + info!("No results found in Redis"); + Ok(None) + } else { + info!("Results found in Redis"); + Ok(Some(result)) + } + } + Err(e) => Err(StoreError::StoreRead(format!( + "Failed to read from redis: {:?}", + e + ))), + } + } + + // Attempt to save all payloads. Returns error if we fail to save at least one message. + async fn save(&mut self, key: &str, payload: Bytes) -> StoreResult<()> { + // This is put in place not to overload Redis and also way some kind of + // flow control. + handle_write_requests(self.clone(), key, payload).await?; + Ok(()) + } + + // Check if the Redis connection is healthy + async fn ready(&mut self) -> bool { + let mut conn = self.conn_manager.clone(); + match redis::cmd("PING").query_async::(&mut conn).await { + Ok(response) => response == "PONG", + Err(_) => false, + } + } +} + +// #[cfg(feature = "redis-tests")] +// #[cfg(test)] +// mod tests { +// use axum::body::Bytes; +// use redis::AsyncCommands; +// +// use super::*; +// use crate::app::callback::datumstore::LocalDatumStore; +// use crate::callback::Response; +// +// #[tokio::test] +// async fn test_redis_store() { +// let redis_config = RedisConfig { +// addr: "no_such_redis://127.0.0.1:6379".to_owned(), +// max_tasks: 10, +// ..Default::default() +// }; +// let redis_connection = RedisConnection::new(redis_config).await; +// assert!(redis_connection.is_err()); +// +// // Test Redis connection +// let redis_connection = RedisConnection::new(RedisConfig::default()).await; +// assert!(redis_connection.is_ok()); +// +// let key = uuid::Uuid::new_v4().to_string(); +// +// let ps_cb = PayloadToSave::Callback { +// key: key.clone(), +// value: Arc::new(Callback { +// id: String::from("1234"), +// vertex: String::from("prev_vertex"), +// cb_time: 1234, +// from_vertex: String::from("next_vertex"), +// responses: vec![Response { tags: None }], +// }), +// }; +// +// let mut redis_conn = redis_connection.unwrap(); +// redis_conn.register(key.as_str()).await.unwrap(); +// +// // Test Redis save +// assert!(redis_conn.save(vec![ps_cb]).await.is_ok()); +// +// let ps_datum = PayloadToSave::DatumFromPipeline { +// key: key.clone(), +// value: Bytes::from("hello world"), +// }; +// +// assert!(redis_conn.save(vec![ps_datum]).await.is_ok()); +// +// // Test Redis retrieve callbacks +// let callbacks = redis_conn.retrieve_callbacks(&key).await; +// assert!(callbacks.is_ok()); +// +// let callbacks = callbacks.unwrap(); +// assert_eq!(callbacks.len(), 1); +// +// // Additional validations +// let callback = callbacks.first().unwrap(); +// assert_eq!(callback.id, "1234"); +// assert_eq!(callback.vertex, "prev_vertex"); +// assert_eq!(callback.cb_time, 1234); +// assert_eq!(callback.from_vertex, "next_vertex"); +// +// // Test Redis retrieve datum +// let datums = redis_conn.retrieve_datum(&key).await; +// assert!(datums.is_ok()); +// +// assert_eq!(datums.unwrap(), ProcessingStatus::InProgress); +// +// redis_conn.deregister(key.as_str()).await.unwrap(); +// let datums = redis_conn.retrieve_datum(&key).await.unwrap(); +// let ProcessingStatus::Completed(datums) = datums else { +// panic!("Expected completed results"); +// }; +// assert_eq!(datums.len(), 1); +// +// let datum = datums.first().unwrap(); +// assert_eq!(datum, "hello world".as_bytes()); +// +// // Test Redis retrieve callbacks error +// let result = redis_conn.retrieve_callbacks("non_existent_key").await; +// assert!(matches!(result, Err(StoreError::StoreRead(_)))); +// +// // Test Redis retrieve datum error +// let result = redis_conn.retrieve_datum("non_existent_key").await; +// assert!(matches!(result, Err(StoreError::InvalidRequestId(_)))); +// } +// +// #[tokio::test] +// async fn test_redis_ttl() { +// let redis_config = RedisConfig { +// max_tasks: 10, +// ..Default::default() +// }; +// let redis_connection = RedisConnection::new(redis_config) +// .await +// .expect("Failed to connect to Redis"); +// +// let key = uuid::Uuid::new_v4().to_string(); +// let value = Arc::new(Callback { +// id: String::from("test-redis-ttl"), +// vertex: String::from("vertex"), +// cb_time: 1234, +// from_vertex: String::from("next_vertex"), +// responses: vec![Response { tags: None }], +// }); +// +// // Save with TTL of 1 second +// redis_connection +// .write_to_redis(&key, &serde_json::to_vec(&*value).unwrap()) +// .await +// .expect("Failed to write to Redis"); +// +// let mut conn_manager = redis_connection.conn_manager.clone(); +// +// let exists: bool = conn_manager +// .exists(&key) +// .await +// .expect("Failed to check existence immediately"); +// +// // if the key exists, the TTL should be set to 1 second +// if exists { +// let ttl: isize = conn_manager.ttl(&key).await.expect("Failed to check TTL"); +// assert_eq!(ttl, 86400, "TTL should be set to 1 second"); +// } +// } +// } diff --git a/rust/serving/src/app/callback/state.rs b/rust/serving/src/app/callback/state.rs index 28bb004405..e71b207003 100644 --- a/rust/serving/src/app/callback/state.rs +++ b/rust/serving/src/app/callback/state.rs @@ -1,79 +1,118 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; +use std::sync::Arc; use tokio::sync::oneshot; +use tokio_stream::StreamExt; +use tracing::{error, info}; -use super::store::{ProcessingStatus, Store}; -use crate::app::callback::store::Error as StoreError; -use crate::app::callback::store::Result as StoreResult; -use crate::app::callback::{store::PayloadToSave, Callback}; +use super::datumstore::DatumStore; +use crate::app::callback::cbstore::{CallbackStore, ProcessingStatus}; +use crate::app::callback::datumstore::Error as StoreError; +use crate::app::callback::datumstore::Result as StoreResult; use crate::app::tracker::MessageGraph; use crate::Error; -struct RequestState { - // Channel to notify when all callbacks for a message is received - tx: oneshot::Sender>, - // CallbackRequest is immutable, while vtx_visited can grow. - vtx_visited: Vec>, -} - #[derive(Clone)] -pub(crate) struct State { - // hashmap of vertex infos keyed by ID - // it also contains tx to trigger to response to the syncHTTP call - callbacks: Arc>>, +pub(crate) struct State { // generator to generate subgraph msg_graph_generator: Arc, // conn is to be used while reading and writing to redis. - store: T, + datum_store: T, + callback_store: C, } -impl State +impl State where - T: Store, + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, { /// Create a new State to track connections and callback data - pub(crate) async fn new(msg_graph: MessageGraph, store: T) -> crate::Result { + pub(crate) async fn new( + msg_graph: MessageGraph, + store: T, + callback_store: C, + ) -> crate::Result { Ok(Self { - callbacks: Arc::new(Mutex::new(HashMap::new())), msg_graph_generator: Arc::new(msg_graph), - store, + datum_store: store, + callback_store, }) } /// register a new connection /// The oneshot receiver will be notified when all callbacks for this connection is received from /// the numaflow pipeline. - pub(crate) async fn register( + pub(crate) async fn process_request( &mut self, - id: Option, - ) -> StoreResult<(String, oneshot::Receiver>)> { - // TODO: add an entry in Redis to note that the entry has been registered + id: &str, + ) -> StoreResult>> { + let (tx, rx) = oneshot::channel(); + let sub_graph_generator = Arc::clone(&self.msg_graph_generator); + let msg_id = id.to_string(); + let mut subgraph = None; + + // register the request in the store + self.callback_store.register(id).await?; + + // start watching for callbacks + let (mut callbacks_stream, watch_handle) = self.callback_store.watch_callbacks(id).await?; + + let mut cb_store = self.callback_store.clone(); + tokio::spawn(async move { + let _handle = watch_handle; + let mut callbacks = Vec::new(); + + while let Some(cb) = callbacks_stream.next().await { + info!(?cb, ?msg_id, "Received callback"); + callbacks.push(cb); + subgraph = match sub_graph_generator + .generate_subgraph_from_callbacks(msg_id.clone(), callbacks.clone()) + { + Ok(subgraph) => subgraph, + Err(e) => { + error!(?e, "Failed to generate subgraph"); + break; + } + }; + if subgraph.is_some() { + break; + } + } - let id = self.store.register(id).await?; + if let Some(subgraph) = subgraph { + tx.send(Ok(subgraph.clone())) + .expect("Failed to send subgraph"); + + cb_store + .deregister(&msg_id, &subgraph) + .await + .expect("Failed to deregister"); + } else { + error!("Subgraph could not be generated for the given ID"); + + tx.send(Err(Error::SubGraphNotFound( + "Subgraph could not be generated for the given ID", + ))) + .expect("Failed to send subgraph"); + + cb_store + .mark_as_failed(&msg_id, "Subgraph could not be generated") + .await + .expect("Failed to mark as failed"); + } + }); - let (tx, rx) = oneshot::channel(); - { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - guard.insert( - id.clone(), - RequestState { - tx, - vtx_visited: Vec::new(), - }, - ); - } - Ok((id, rx)) + Ok(rx) } /// Retrieves the output of the numaflow pipeline pub(crate) async fn retrieve_saved( &mut self, id: &str, - ) -> Result { - self.store.retrieve_datum(id).await.map_err(Into::into) + ) -> Result>>, StoreError> { + self.datum_store + .retrieve_datum(id) + .await + .map_err(Into::into) } pub(crate) async fn save_response( @@ -83,309 +122,189 @@ where ) -> crate::Result<()> { // we have to differentiate between the saved responses and the callback requests // saved responses are stored in "id_SAVED", callback requests are stored in "id" - self.store - .save(vec![PayloadToSave::DatumFromPipeline { - key: id, - value: body, - }]) + self.datum_store + .save(id.as_str(), body) .await .map_err(Into::into) } - /// insert_callback_requests is used to insert the callback requests. - pub(crate) async fn insert_callback_requests( - &mut self, - cb_requests: Vec, - ) -> Result<(), Error> { - /* - TODO: should we consider batching the requests and then processing them? - that way algorithm can be invoked only once for a batch of requests - instead of invoking it for each request. - */ - let cb_requests: Vec> = cb_requests.into_iter().map(Arc::new).collect(); - let redis_payloads: Vec = cb_requests - .iter() - .cloned() - .map(|cbr| PayloadToSave::Callback { - key: cbr.id.clone(), - value: Arc::clone(&cbr), - }) - .collect(); - - self.store.save(redis_payloads).await?; - - for cbr in cb_requests { - let id = cbr.id.clone(); - { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - let req_state = guard.get_mut(&id).ok_or_else(|| { - Error::SubGraphInvalidInput(format!("request id {id} doesn't exist in-memory")) - })?; - req_state.vtx_visited.push(cbr); - } - - // check if the sub graph can be generated - match self.get_subgraph_from_memory(&id) { - Ok(_) => { - // if the sub graph is generated, then we can send the response - self.deregister(&id).await? - } - Err(e) => { - match e { - Error::SubGraphNotFound(_) => { - // if the sub graph is not generated, then we can continue - continue; - } - err => { - tracing::error!(?err, "Failed to generate subgraph"); - // if there is an error, deregister with the error - self.deregister(&id).await? - } - } - } - } - } - Ok(()) - } - - /// Get the subgraph for the given ID from in-memory. - fn get_subgraph_from_memory(&self, id: &str) -> Result { - let callbacks = self.get_callbacks_from_memory(id).ok_or(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - ))?; - - self.get_subgraph(id.to_string(), callbacks) - } - /// Get the subgraph for the given ID from persistent store. This is used querying for the status from the service endpoint even after the /// request has been completed. pub(crate) async fn retrieve_subgraph_from_storage( &mut self, id: &str, ) -> Result { - // If the id is not found in the in-memory store, fetch from Redis - let callbacks: Vec> = match self.retrieve_callbacks_from_storage(id).await { - Ok(callbacks) => callbacks, - Err(e) => { - return Err(e); + let status = self.callback_store.status(id).await?; + match status { + ProcessingStatus::InProgress => Ok("Request In Progress".to_string()), + ProcessingStatus::Completed(sub_graph) => Ok(sub_graph), + ProcessingStatus::Failed(error) => { + error!(?error, "Request failed"); + Err(Error::SubGraphGenerator(error)) } - }; - // check if the sub graph can be generated - self.get_subgraph(id.to_string(), callbacks) - } - - // Generate subgraph from the given callbacks - fn get_subgraph(&self, id: String, callbacks: Vec>) -> Result { - match self - .msg_graph_generator - .generate_subgraph_from_callbacks(id, callbacks) - { - Ok(Some(sub_graph)) => Ok(sub_graph), - Ok(None) => Err(Error::SubGraphNotFound( - "Subgraph could not be generated for the given ID", - )), - Err(e) => Err(e), } } - /// deregister is called to trigger response and delete all the data persisted for that ID - pub(crate) async fn deregister(&mut self, id: &str) -> Result<(), Error> { - let state = { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - // we do not require the data stored in HashMap anymore - guard.remove(id) - }; - - let Some(state) = state else { - return Err(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - )); - }; - - self.store.done(id.to_string()).await?; - - state - .tx - .send(Ok(id.to_string())) - .map_err(|_| Error::Other("Application bug - Receiver is already dropped".to_string())) - } - - // Get the Callback value for the given ID - // TODO: Generate json serialized data here itself to avoid cloning. - fn get_callbacks_from_memory(&self, id: &str) -> Option>> { - let guard = self.callbacks.lock().expect("Getting lock on State"); - guard.get(id).map(|state| state.vtx_visited.clone()) - } - - // Get the Callback value for the given ID from persistent store - async fn retrieve_callbacks_from_storage( - &mut self, - id: &str, - ) -> Result>, Error> { - // If the id is not found in the in-memory store, fetch from Redis - Ok(self - .store - .retrieve_callbacks(id) - .await? - .into_iter() - .collect()) + pub(crate) async fn mark_as_failed(&mut self, id: &str, error: &str) -> Result<(), Error> { + self.callback_store.mark_as_failed(id, error).await?; + Ok(()) } // Check if the store is ready pub(crate) async fn ready(&mut self) -> bool { - self.store.ready().await + self.datum_store.ready().await } } -#[cfg(test)] -mod tests { - use axum::body::Bytes; - - use super::*; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::callback::Response; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_state() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - // Test register - let id = "test_id".to_string(); - let (id, rx) = state.register(Some(id.clone())).await.unwrap(); - - let xid = id.clone(); - - // spawn a task to listen on the receiver, once we have received all the callbacks for the message - // we will get a response from the receiver with the message id - let handle = tokio::spawn(async move { - let result = rx.await.unwrap(); - // Tests deregister, and fetching the subgraph from the memory - assert_eq!(result.unwrap(), xid); - }); - - // Test save_response - let body = Bytes::from("Test Message"); - state.save_response(id.clone(), body).await.unwrap(); - - // Test retrieve_saved - let saved = state.retrieve_saved(&id).await.unwrap(); - assert_eq!( - saved, - ProcessingStatus::Completed(vec!["Test Message".as_bytes().to_vec()]) - ); - - // Test insert_callback_requests - let cbs = vec![ - Callback { - id: id.clone(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.clone(), - vertex: "planner".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { - tags: Some(vec!["tiger".to_owned(), "asciiart".to_owned()]), - }], - }, - Callback { - id: id.clone(), - vertex: "tiger".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.clone(), - vertex: "asciiart".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "tiger".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "asciiart".to_string(), - responses: vec![Response { tags: None }], - }, - ]; - state.insert_callback_requests(cbs).await.unwrap(); - - let sub_graph = state.retrieve_subgraph_from_storage(&id).await; - assert!(sub_graph.is_ok()); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_retrieve_saved_no_entry() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve saved data for an ID that doesn't exist - let result = state.retrieve_saved(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_insert_callback_requests_invalid_id() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let cbs = vec![Callback { - id: "nonexistent_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }]; - - // Try to insert callback requests for an ID that hasn't been registered - let result = state.insert_callback_requests(cbs).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_subgraph_from_storage_no_entry() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve a subgraph for an ID that doesn't exist - let result = state.retrieve_subgraph_from_storage(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } -} +// #[cfg(test)] +// mod tests { +// use axum::body::Bytes; +// +// use super::*; +// use crate::app::callback::datumstore::memstore::InMemoryStore; +// use crate::callback::Response; +// use crate::pipeline::PipelineDCG; +// +// const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; +// +// #[tokio::test] +// async fn test_state() { +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// let datumstore = InMemoryStore::new(); +// let mut state = State::new(msg_graph, datumstore).await.unwrap(); +// +// // Test register +// let id = "test_id".to_string(); +// let rx = state.register(id.as_str()).await.unwrap(); +// +// let xid = id.clone(); +// +// // spawn a task to listen on the receiver, once we have received all the callbacks for the message +// // we will get a response from the receiver with the message id +// let handle = tokio::spawn(async move { +// let result = rx.await.unwrap(); +// // Tests deregister, and fetching the subgraph from the memory +// assert_eq!(result.unwrap(), xid); +// }); +// +// // Test save_response +// let body = Bytes::from("Test Message"); +// state.save_response(id.clone(), body).await.unwrap(); +// +// // Test retrieve_saved +// let saved = state.retrieve_saved(&id).await.unwrap(); +// assert_eq!( +// saved, +// ProcessingStatus::Completed(vec!["Test Message".as_bytes().to_vec()]) +// ); +// +// // Test insert_callback_requests +// let cbs = vec![ +// Callback { +// id: id.clone(), +// vertex: "in".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.clone(), +// vertex: "planner".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { +// tags: Some(vec!["tiger".to_owned(), "asciiart".to_owned()]), +// }], +// }, +// Callback { +// id: id.clone(), +// vertex: "tiger".to_string(), +// cb_time: 12345, +// from_vertex: "planner".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.clone(), +// vertex: "asciiart".to_string(), +// cb_time: 12345, +// from_vertex: "planner".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.clone(), +// vertex: "serve-sink".to_string(), +// cb_time: 12345, +// from_vertex: "tiger".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.clone(), +// vertex: "serve-sink".to_string(), +// cb_time: 12345, +// from_vertex: "asciiart".to_string(), +// responses: vec![Response { tags: None }], +// }, +// ]; +// state.insert_callback_requests(cbs).await.unwrap(); +// +// let sub_graph = state.retrieve_subgraph_from_storage(&id).await; +// assert!(sub_graph.is_ok()); +// +// handle.await.unwrap(); +// } +// +// #[tokio::test] +// async fn test_retrieve_saved_no_entry() { +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// let datumstore = InMemoryStore::new(); +// let mut state = State::new(msg_graph, datumstore).await.unwrap(); +// +// let id = "nonexistent_id".to_string(); +// +// // Try to retrieve saved data for an ID that doesn't exist +// let result = state.retrieve_saved(&id).await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// +// #[tokio::test] +// async fn test_insert_callback_requests_invalid_id() { +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// let datumstore = InMemoryStore::new(); +// let mut state = State::new(msg_graph, datumstore).await.unwrap(); +// +// let cbs = vec![Callback { +// id: "nonexistent_id".to_string(), +// vertex: "in".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { tags: None }], +// }]; +// +// // Try to insert callback requests for an ID that hasn't been registered +// let result = state.insert_callback_requests(cbs).await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// +// #[tokio::test] +// async fn test_retrieve_subgraph_from_storage_no_entry() { +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// let datumstore = InMemoryStore::new(); +// let mut state = State::new(msg_graph, datumstore).await.unwrap(); +// +// let id = "nonexistent_id".to_string(); +// +// // Try to retrieve a subgraph for an ID that doesn't exist +// let result = state.retrieve_subgraph_from_storage(&id).await; +// +// // Check that an error is returned +// assert!(result.is_err()); +// } +// } diff --git a/rust/serving/src/app/callback/store.rs b/rust/serving/src/app/callback/store.rs deleted file mode 100644 index de2b7382e8..0000000000 --- a/rust/serving/src/app/callback/store.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use thiserror::Error; - -use crate::app::callback::Callback; - -// in-memory store -pub(crate) mod memstore; -// redis as the store -pub(crate) mod redisstore; - -pub(crate) enum PayloadToSave { - /// Callback as sent by Numaflow to track the progression - Callback { key: String, value: Arc }, - /// Data sent by the Numaflow pipeline which is to be delivered as the response - DatumFromPipeline { - key: String, - value: axum::body::Bytes, - }, -} - -#[derive(Debug, PartialEq)] -/// Represents the current processing status of a request id in the `Store`. -pub(crate) enum ProcessingStatus { - InProgress, - Completed(Vec>), -} - -#[derive(Error, Debug, Clone)] -pub(crate) enum Error { - #[error("Connecting to the store: {0}")] - Connection(String), - - #[error("Request id {0} doesn't exist in store")] - InvalidRequestId(String), - - #[error("Request id {0} already exists in the store")] - DuplicateRequest(String), - - #[error("Reading from the store: {0}")] - StoreRead(String), - - #[error("Writing payload to the store: {0}")] - StoreWrite(String), -} - -impl From for crate::Error { - fn from(value: Error) -> Self { - crate::Error::Store(value.to_string()) - } -} - -pub(crate) type Result = std::result::Result; - -/// Store trait to store the callback information. -#[trait_variant::make(Store: Send)] -#[allow(dead_code)] -pub(crate) trait LocalStore { - /// Register a request id in the store. If user provides a request id, the same should be returned - /// if it doesn't already exist in the store. An error should be returned if the user-specified request id - /// already exists in the store. If the `id` is `None`, the store should generate a new unique request id. - async fn register(&mut self, id: Option) -> Result; - /// This method will be called when processing is completed for a request id. - async fn done(&mut self, id: String) -> Result<()>; - async fn save(&mut self, messages: Vec) -> Result<()>; - /// retrieve the callback payloads - async fn retrieve_callbacks(&mut self, id: &str) -> Result>>; - async fn retrieve_datum(&mut self, id: &str) -> Result; - async fn ready(&mut self) -> bool; -} diff --git a/rust/serving/src/app/callback/store/memstore.rs b/rust/serving/src/app/callback/store/memstore.rs deleted file mode 100644 index 25a125ae67..0000000000 --- a/rust/serving/src/app/callback/store/memstore.rs +++ /dev/null @@ -1,236 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use uuid::Uuid; - -use super::{Error as StoreError, Result as StoreResult}; -use super::{PayloadToSave, ProcessingStatus}; -use crate::app::callback::Callback; - -const STORE_KEY_SUFFIX: &str = "saved"; - -/// `InMemoryStore` is an in-memory implementation of the `Store` trait. -/// It uses a `HashMap` to store data in memory. -#[derive(Clone)] -pub(crate) struct InMemoryStore { - /// The data field is a `HashMap` where the key is a `String` and the value is a `Vec>`. - /// It is wrapped in an `Arc>` to allow shared ownership and thread safety. - pub(crate) data: Arc>>>>, -} - -impl InMemoryStore { - /// Creates a new `InMemoryStore` with an empty `HashMap`. - #[allow(dead_code)] - pub(crate) fn new() -> Self { - Self { - data: Arc::new(std::sync::Mutex::new(HashMap::new())), - } - } -} - -impl super::Store for InMemoryStore { - async fn register(&mut self, id: Option) -> StoreResult { - Ok(id.unwrap_or_else(|| Uuid::now_v7().to_string())) - } - async fn done(&mut self, _id: String) -> StoreResult<()> { - Ok(()) - } - /// Saves a vector of `PayloadToSave` into the `HashMap`. - /// Each `PayloadToSave` is serialized into bytes and stored in the `HashMap` under its key. - async fn save(&mut self, messages: Vec) -> StoreResult<()> { - let mut data = self.data.lock().unwrap(); - for msg in messages { - match msg { - PayloadToSave::Callback { key, value } => { - if key.is_empty() { - return Err(StoreError::StoreWrite("Key cannot be empty".to_string())); - } - let bytes = serde_json::to_vec(&*value).map_err(|e| { - StoreError::StoreWrite(format!("Serializing to bytes - {}", e)) - })?; - data.entry(key).or_default().push(bytes); - } - PayloadToSave::DatumFromPipeline { key, value } => { - if key.is_empty() { - return Err(StoreError::StoreWrite("Key cannot be empty".to_string())); - } - data.entry(format!("{key}_{STORE_KEY_SUFFIX}")) - .or_default() - .push(value.into()); - } - } - } - Ok(()) - } - - /// Retrieves callbacks for a given id from the `HashMap`. - /// Each callback is deserialized from bytes into a `CallbackRequest`. - async fn retrieve_callbacks(&mut self, id: &str) -> StoreResult>> { - let data = self.data.lock().unwrap(); - match data.get(id) { - Some(result) => { - let messages: Result, _> = result - .iter() - .map(|msg| { - let cbr: Callback = serde_json::from_slice(msg).map_err(|_| { - StoreError::StoreRead( - "Failed to parse CallbackRequest from bytes".to_string(), - ) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - messages - } - None => Err(StoreError::StoreRead(format!( - "No entry found for id: {}", - id - ))), - } - } - - /// Retrieves data for a given id from the `HashMap`. - /// Each piece of data is deserialized from bytes into a `String`. - async fn retrieve_datum(&mut self, id: &str) -> StoreResult { - let id = format!("{id}_{STORE_KEY_SUFFIX}"); - let data = self.data.lock().unwrap(); - match data.get(&id) { - Some(result) => Ok(ProcessingStatus::Completed(result.to_vec())), - None => Err(StoreError::InvalidRequestId(format!( - "No entry found for id: {}", - id - ))), - } - } - - async fn ready(&mut self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use crate::app::callback::store::{PayloadToSave, Store}; - use crate::callback::{Callback, Response}; - - #[tokio::test] - async fn test_save_and_retrieve_callbacks() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = Arc::new(Callback { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }); - - // Save a callback - store - .save(vec![PayloadToSave::Callback { - key: key.clone(), - value: Arc::clone(&value), - }]) - .await - .unwrap(); - - // Retrieve the callback - let retrieved = store.retrieve_callbacks(&key).await.unwrap(); - - // Check that the retrieved callback is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0].id, "test_id".to_string()) - } - - #[tokio::test] - async fn test_save_and_retrieve_datum() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = "test_value".to_string(); - - // Save a datum - store - .save(vec![PayloadToSave::DatumFromPipeline { - key: key.clone(), - value: value.clone().into(), - }]) - .await - .unwrap(); - - // Retrieve the datum - let retrieved = store.retrieve_datum(&key).await.unwrap(); - let ProcessingStatus::Completed(retrieved) = retrieved else { - panic!("Expected pipeline processing to be completed"); - }; - - // Check that the retrieved datum is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0], value.as_bytes()); - } - - #[tokio::test] - async fn test_retrieve_callbacks_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a callback for a key that doesn't exist - let result = store.retrieve_callbacks(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_datum_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a datum for a key that doesn't exist - let result = store.retrieve_datum(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_callback() { - let mut store = InMemoryStore::new(); - let value = Arc::new(Callback { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }); - - // Try to save a callback with an invalid key - let result = store - .save(vec![PayloadToSave::Callback { - key: "".to_string(), - value: Arc::clone(&value), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_datum() { - let mut store = InMemoryStore::new(); - - // Try to save a datum with an invalid key - let result = store - .save(vec![PayloadToSave::DatumFromPipeline { - key: "".to_string(), - value: "test_value".into(), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } -} diff --git a/rust/serving/src/app/callback/store/redisstore.rs b/rust/serving/src/app/callback/store/redisstore.rs deleted file mode 100644 index d3f0c9087e..0000000000 --- a/rust/serving/src/app/callback/store/redisstore.rs +++ /dev/null @@ -1,432 +0,0 @@ -use std::sync::Arc; - -use backoff::retry::Retry; -use backoff::strategy::fixed; -use redis::aio::ConnectionManager; -use redis::{AsyncCommands, RedisError}; -use tokio::sync::Semaphore; -use uuid::Uuid; - -use super::{Error as StoreError, Result as StoreResult}; -use super::{PayloadToSave, ProcessingStatus}; -use crate::app::callback::Callback; -use crate::config::RedisConfig; - -const LPUSH: &str = "LPUSH"; -const LRANGE: &str = "LRANGE"; -const EXPIRE: &str = "EXPIRE"; - -const STATUS_PROCESSING: &str = "processing"; -const STATUS_COMPLETED: &str = "completed"; - -// Handle to the Redis actor. -#[derive(Clone)] -pub(crate) struct RedisConnection { - conn_manager: ConnectionManager, - config: RedisConfig, -} - -impl RedisConnection { - /// Creates a new RedisConnection with concurrent operations on Redis set by max_tasks. - pub(crate) async fn new(config: RedisConfig) -> crate::Result { - let client = redis::Client::open(config.addr.as_str()) - .map_err(|e| StoreError::Connection(format!("Creating Redis client: {e:?}")))?; - let conn = client - .get_connection_manager() - .await - .map_err(|e| StoreError::Connection(format!("Connecting to Redis server: {e:?}")))?; - Ok(Self { - conn_manager: conn, - config, - }) - } - - async fn execute_redis_cmd( - conn_manager: &mut ConnectionManager, - ttl_secs: Option, - key: &str, - val: &Vec, - ) -> Result<(), RedisError> { - let mut pipe = redis::pipe(); - pipe.cmd(LPUSH).arg(key).arg(val); - - // if the ttl is configured, add the EXPIRE command to the pipeline - if let Some(ttl) = ttl_secs { - pipe.cmd(EXPIRE).arg(key).arg(ttl); - } - - // Execute the pipeline - pipe.exec_async(conn_manager).await - } - - // write to Redis with retries - async fn write_to_redis(&self, key: &str, value: &Vec) -> StoreResult<()> { - let interval = fixed::Interval::from_millis(self.config.retries_duration_millis.into()) - .take(self.config.retries); - - Retry::retry( - interval, - || async { - // https://hackmd.io/@compiler-errors/async-closures - Self::execute_redis_cmd( - &mut self.conn_manager.clone(), - self.config.ttl_secs, - key, - value, - ) - .await - }, - |e: &RedisError| !e.is_unrecoverable_error(), - ) - .await - .map_err(|err| StoreError::StoreWrite(format!("Saving to redis: {}", err).to_string())) - } -} - -async fn handle_write_requests(redis_conn: RedisConnection, msg: PayloadToSave) -> StoreResult<()> { - match msg { - PayloadToSave::Callback { key, value } => { - // Convert the CallbackRequest to a byte array - let value = serde_json::to_vec(&*value) - .map_err(|e| StoreError::StoreWrite(format!("Serializing payload - {}", e)))?; - - let key = format!("request:{key}:callbacks"); - redis_conn.write_to_redis(&key, &value).await - } - - // Write the byte array to Redis - PayloadToSave::DatumFromPipeline { key, value } => { - // we have to differentiate between the saved responses and the callback requests - // saved responses are stored in "id_SAVED", callback requests are stored in "id" - let key = format!("request:{key}:results"); - let value: Vec = value.into(); - - redis_conn.write_to_redis(&key, &value).await - } - } -} - -// It is possible to move the methods defined here to be methods on the Redis actor and communicate through channels. -// With that, all public APIs defined on RedisConnection can be on &self (immutable). -impl super::Store for RedisConnection { - async fn register(&mut self, id: Option) -> StoreResult { - match id { - Some(id) => { - let mut pipe = redis::pipe(); - let key = format!("request:{id}:status"); - pipe.cmd("SET").arg(&key).arg(STATUS_PROCESSING).arg("NX"); - // if the ttl is configured, add the EXPIRE command to the pipeline - if let Some(ttl) = self.config.ttl_secs { - pipe.arg("EX").arg(ttl); - } - - let (status,): (bool,) = - pipe.query_async(&mut self.conn_manager) - .await - .map_err(|e| { - StoreError::StoreWrite(format!( - "Registering request_id={id} in Redis: {e:?}" - )) - })?; - - if !status { - // The user specified request id already exists - return Err(StoreError::DuplicateRequest(id)); - } - Ok(id) - } - None => { - // We use UUID v7 as the request id. Attempt for a maxium of 5 times to generate an - // id that doesn't currently exist in the Store. - for _ in 0..5 { - let id = Uuid::now_v7().to_string(); - let mut pipe = redis::pipe(); - let key = format!("request:{id}:status"); - pipe.cmd("SET").arg(&key).arg(STATUS_PROCESSING).arg("NX"); - - // if the ttl is configured, add the EXPIRE command to the pipeline - if let Some(ttl) = self.config.ttl_secs { - pipe.arg("EX").arg(ttl); - } - - let (status,): (bool,) = pipe - .query_async(&mut self.conn_manager) - .await - .map_err(|e| { - StoreError::StoreWrite(format!( - "Registering request_id={id} in Redis: {e:?}" - )) - })?; - - if !status { - continue; - } - return Ok(id); - } - Err(StoreError::StoreWrite( - "Could not generate a unique request id".to_string(), - )) - } - } - } - - // Updates the processing status for the specified request id as completed. - async fn done(&mut self, id: String) -> StoreResult<()> { - let key = format!("request:{id}:status"); - let status: bool = redis::cmd("SET") - .arg(&key) - .arg(STATUS_COMPLETED) - .arg("XX") - .arg("KEEPTTL") - .query_async(&mut self.conn_manager) - .await - .map_err(|e| { - StoreError::StoreWrite(format!( - "Setting processing status as completed in Redis for request_id={id}: {e:?}" - )) - })?; - if !status { - return Err(StoreError::StoreWrite(format!( - "Key {key} is not present in Redis for updating processing status as completed" - ))); - } - Ok(()) - } - // Attempt to save all payloads. Returns error if we fail to save at least one message. - async fn save(&mut self, messages: Vec) -> StoreResult<()> { - let mut tasks = vec![]; - // This is put in place not to overload Redis and also way some kind of - // flow control. - let sem = Arc::new(Semaphore::new(self.config.max_tasks)); - for msg in messages { - let permit = Arc::clone(&sem).acquire_owned().await; - let redis_conn = self.clone(); - let task = tokio::spawn(async move { - let _permit = permit; - handle_write_requests(redis_conn, msg).await - }); - tasks.push(task); - } - for task in tasks { - task.await.unwrap()?; - } - Ok(()) - } - - async fn retrieve_callbacks(&mut self, id: &str) -> StoreResult>> { - let redis_key = format!("request:{id}:callbacks"); - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(redis_key) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(StoreError::StoreRead(format!( - "No entry found for id: {}", - id - ))); - } - - let messages: Result, _> = result - .into_iter() - .map(|msg| { - let cbr: Callback = serde_json::from_slice(&msg).map_err(|e| { - StoreError::StoreRead(format!("Parsing payload from bytes - {}", e)) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - - messages - } - Err(e) => Err(StoreError::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - async fn retrieve_datum(&mut self, id: &str) -> StoreResult { - let redis_status_key = format!("request:{id}:status"); - let status: Option = self - .conn_manager - .get(redis_status_key) - .await - .map_err(|e| StoreError::StoreRead(format!("Reading request status: {e:?}")))?; - - let Some(status) = status else { - return Err(StoreError::InvalidRequestId(id.to_string())); - }; - - if status == STATUS_PROCESSING { - return Ok(ProcessingStatus::InProgress); - } - - let key = format!("request:{id}:results"); - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(key) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(StoreError::StoreRead(format!( - "No entry found for id: {}", - id - ))); - } - - Ok(ProcessingStatus::Completed(result)) - } - Err(e) => Err(StoreError::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - // Check if the Redis connection is healthy - async fn ready(&mut self) -> bool { - let mut conn = self.conn_manager.clone(); - match redis::cmd("PING").query_async::(&mut conn).await { - Ok(response) => response == "PONG", - Err(_) => false, - } - } -} - -#[cfg(feature = "redis-tests")] -#[cfg(test)] -mod tests { - use axum::body::Bytes; - use redis::AsyncCommands; - - use super::*; - use crate::app::callback::store::{LocalStore, ProcessingStatus}; - use crate::callback::Response; - - #[tokio::test] - async fn test_redis_store() { - let redis_config = RedisConfig { - addr: "no_such_redis://127.0.0.1:6379".to_owned(), - max_tasks: 10, - ..Default::default() - }; - let redis_connection = RedisConnection::new(redis_config).await; - assert!(redis_connection.is_err()); - - // Test Redis connection - let redis_connection = RedisConnection::new(RedisConfig::default()).await; - assert!(redis_connection.is_ok()); - - let key = uuid::Uuid::new_v4().to_string(); - - let ps_cb = PayloadToSave::Callback { - key: key.clone(), - value: Arc::new(Callback { - id: String::from("1234"), - vertex: String::from("prev_vertex"), - cb_time: 1234, - from_vertex: String::from("next_vertex"), - responses: vec![Response { tags: None }], - }), - }; - - let mut redis_conn = redis_connection.unwrap(); - redis_conn.register(Some(key.clone())).await.unwrap(); - - // Test Redis save - assert!(redis_conn.save(vec![ps_cb]).await.is_ok()); - - let ps_datum = PayloadToSave::DatumFromPipeline { - key: key.clone(), - value: Bytes::from("hello world"), - }; - - assert!(redis_conn.save(vec![ps_datum]).await.is_ok()); - - // Test Redis retrieve callbacks - let callbacks = redis_conn.retrieve_callbacks(&key).await; - assert!(callbacks.is_ok()); - - let callbacks = callbacks.unwrap(); - assert_eq!(callbacks.len(), 1); - - // Additional validations - let callback = callbacks.first().unwrap(); - assert_eq!(callback.id, "1234"); - assert_eq!(callback.vertex, "prev_vertex"); - assert_eq!(callback.cb_time, 1234); - assert_eq!(callback.from_vertex, "next_vertex"); - - // Test Redis retrieve datum - let datums = redis_conn.retrieve_datum(&key).await; - assert!(datums.is_ok()); - - assert_eq!(datums.unwrap(), ProcessingStatus::InProgress); - - redis_conn.done(key.clone()).await.unwrap(); - let datums = redis_conn.retrieve_datum(&key).await.unwrap(); - let ProcessingStatus::Completed(datums) = datums else { - panic!("Expected completed results"); - }; - assert_eq!(datums.len(), 1); - - let datum = datums.first().unwrap(); - assert_eq!(datum, "hello world".as_bytes()); - - // Test Redis retrieve callbacks error - let result = redis_conn.retrieve_callbacks("non_existent_key").await; - assert!(matches!(result, Err(StoreError::StoreRead(_)))); - - // Test Redis retrieve datum error - let result = redis_conn.retrieve_datum("non_existent_key").await; - assert!(matches!(result, Err(StoreError::InvalidRequestId(_)))); - } - - #[tokio::test] - async fn test_redis_ttl() { - let redis_config = RedisConfig { - max_tasks: 10, - ..Default::default() - }; - let redis_connection = RedisConnection::new(redis_config) - .await - .expect("Failed to connect to Redis"); - - let key = uuid::Uuid::new_v4().to_string(); - let value = Arc::new(Callback { - id: String::from("test-redis-ttl"), - vertex: String::from("vertex"), - cb_time: 1234, - from_vertex: String::from("next_vertex"), - responses: vec![Response { tags: None }], - }); - - // Save with TTL of 1 second - redis_connection - .write_to_redis(&key, &serde_json::to_vec(&*value).unwrap()) - .await - .expect("Failed to write to Redis"); - - let mut conn_manager = redis_connection.conn_manager.clone(); - - let exists: bool = conn_manager - .exists(&key) - .await - .expect("Failed to check existence immediately"); - - // if the key exists, the TTL should be set to 1 second - if exists { - let ttl: isize = conn_manager.ttl(&key).await.expect("Failed to check TTL"); - assert_eq!(ttl, 86400, "TTL should be set to 1 second"); - } - } -} diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index 12315d7b9d..f0429763e0 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -12,44 +12,30 @@ use serde::Deserialize; use serde_json::json; use tokio::sync::{mpsc, oneshot}; use tracing::error; +use uuid::Uuid; -use super::{ - callback::store::{ProcessingStatus, Store}, - AppState, -}; -use crate::app::callback::store::Error as StoreError; +use super::{callback::datumstore::DatumStore, AppState}; +use crate::app::callback::cbstore::CallbackStore; +use crate::app::callback::datumstore::Error as StoreError; use crate::app::response::{ApiError, ServeResponse}; use crate::{app::callback::state, Message, MessageWrapper}; -// TODO: -// - [ ] better health check -// - [ ] jetstream connection pooling -// - [ ] make use of proper url capture! perhaps we have to rewrite the nesting at app level -// *async* -// curl -H 'Content-Type: text/plain' -X POST -d "test-$(date +'%s')" -v http://localhost:3000/v1/process/async | jq -// *sync* -// curl -H 'ID: foobar' -H 'Content-Type: text/plain' -X POST -d "test-$(date +'%s')" http://localhost:3000/v1/process/sync -// curl -H 'Content-Type: application/json' -X POST -d '{"id": "foobar"}' http://localhost:3000/v1/process/callback -// { -// "id": "foobar", -// "vertex": "b", -// "cb_time": 12345, -// "from_vertex": "a" -// } - const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; -struct ProxyState { +struct ProxyState { message: mpsc::Sender, tid_header: String, /// Lets the HTTP handlers know whether they are in a Monovertex or a Pipeline monovertex: bool, - callback: state::State, + callback: state::State, } -pub(crate) async fn jetstream_proxy( - state: AppState, +pub(crate) async fn jetstream_proxy< + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, +>( + state: AppState, ) -> crate::Result { let proxy_state = Arc::new(ProxyState { message: state.message.clone(), @@ -71,8 +57,11 @@ struct ServeQueryParams { id: String, } -async fn fetch( - State(proxy_state): State>>, +async fn fetch< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + State(proxy_state): State>>, Query(ServeQueryParams { id }): Query, ) -> Response { let pipeline_result = match proxy_state.callback.clone().retrieve_saved(&id).await { @@ -92,7 +81,7 @@ async fn fetch( } }; - let ProcessingStatus::Completed(result) = pipeline_result else { + let Some(result) = pipeline_result else { return Json(json!({"status": "in-progress"})).into_response(); }; @@ -109,7 +98,7 @@ async fn fetch( let arr_idx_header_val = match HeaderValue::from_str(response_arr_len.as_str()) { Ok(val) => val, Err(e) => { - tracing::error!(?e, "Encoding response array length"); + error!(?e, "Encoding response array length"); return ApiError::InternalServerError(format!( "Encoding response array len failed: {}", e @@ -126,14 +115,18 @@ async fn fetch( (header_map, body).into_response() } -async fn sync_publish( - State(proxy_state): State>>, +async fn sync_publish< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + State(proxy_state): State>>, headers: HeaderMap, body: Bytes, ) -> impl IntoResponse { let id = headers .get(&proxy_state.tid_header) - .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()); + .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()) + .unwrap_or_else(|| Uuid::now_v7().to_string()); let mut msg_headers: HashMap = HashMap::new(); for (key, value) in headers.iter() { @@ -144,7 +137,12 @@ async fn sync_publish( } // Register the ID in the callback proxy state - let (id, notify) = match proxy_state.callback.clone().register(id).await { + let notify = match proxy_state + .callback + .clone() + .process_request(id.as_str()) + .await + { Ok(result) => result, Err(e) => { error!(error = ?e, "Registering request in data store"); @@ -168,11 +166,20 @@ async fn sync_publish( headers: msg_headers, }, }; - proxy_state.message.send(message).await.unwrap(); // FIXME: + proxy_state + .message + .send(message) + .await + .expect("failed to send message"); if let Err(e) = confirm_save_rx.await { // Deregister the ID in the callback proxy state if waiting for ack fails - let _ = proxy_state.callback.clone().deregister(&id).await; + let _ = proxy_state + .callback + .clone() + .mark_as_failed(&id, e.to_string().as_str()) + .await; + error!(error = ?e, "Publishing message to Jetstream for sync request"); return Err(ApiError::BadGateway( "Failed to write message to Jetstream".to_string(), @@ -202,7 +209,8 @@ async fn sync_publish( HeaderName::from_str(&proxy_state.tid_header).unwrap(), HeaderValue::from_str(&id).unwrap(), ); - let ProcessingStatus::Completed(result) = result else { + + let Some(result) = result else { return Ok(Json(json!({"status": "processing"})).into_response()); }; @@ -228,14 +236,19 @@ async fn sync_publish( Ok((header_map, body).into_response()) } -async fn async_publish( - State(proxy_state): State>>, +async fn async_publish< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + State(proxy_state): State>>, headers: HeaderMap, body: Bytes, ) -> Result, ApiError> { let id = headers .get(&proxy_state.tid_header) - .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()); + .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()) + .unwrap_or_else(|| Uuid::now_v7().to_string()); + let mut msg_headers: HashMap = HashMap::new(); for (key, value) in headers.iter() { // Exclude request ID @@ -249,7 +262,12 @@ async fn async_publish( } // Register request in Redis - let (id, notify) = match proxy_state.callback.clone().register(id).await { + let notify = match proxy_state + .callback + .clone() + .process_request(id.as_str()) + .await + { Ok(result) => result, Err(e) => { error!(error = ?e, "Registering request in data store"); @@ -304,340 +322,326 @@ async fn async_publish( } } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use axum::body::{to_bytes, Body}; - use axum::extract::Request; - use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; - use serde_json::{json, Value}; - use tower::ServiceExt; - use uuid::Uuid; - - use super::*; - use crate::app::callback::state::State as CallbackState; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::callback::store::PayloadToSave; - use crate::app::callback::store::Result as StoreResult; - use crate::app::tracker::MessageGraph; - use crate::callback::{Callback, Response}; - use crate::config::DEFAULT_ID_HEADER; - use crate::pipeline::PipelineDCG; - use crate::Settings; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[derive(Clone)] - struct MockStore; - - impl Store for MockStore { - async fn register(&mut self, id: Option) -> StoreResult { - Ok(id.unwrap_or_else(|| Uuid::now_v7().to_string())) - } - async fn done(&mut self, _id: String) -> StoreResult<()> { - Ok(()) - } - async fn save(&mut self, _messages: Vec) -> StoreResult<()> { - Ok(()) - } - async fn retrieve_callbacks(&mut self, _id: &str) -> StoreResult>> { - Ok(vec![]) - } - async fn retrieve_datum(&mut self, _id: &str) -> StoreResult { - Ok(ProcessingStatus::Completed(vec![])) - } - async fn ready(&mut self) -> bool { - true - } - } - - #[tokio::test] - async fn test_async_publish() -> Result<(), Box> { - const ID_HEADER: &str = "X-Numaflow-Id"; - const ID_VALUE: &str = "foobar"; - let settings = Settings { - tid_header: ID_HEADER.into(), - ..Default::default() - }; - - let mock_store = MockStore {}; - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - let callback_state = CallbackState::new(msg_graph, mock_store).await?; - - let (messages_tx, mut messages_rx) = mpsc::channel::(10); - let response_collector = tokio::spawn(async move { - let message = messages_rx.recv().await.unwrap(); - let MessageWrapper { - confirm_save, - message, - } = message; - confirm_save.send(()).unwrap(); - message - }); - - let app_state = AppState { - message: messages_tx, - settings: Arc::new(settings), - callback_state, - }; - - let app = jetstream_proxy(app_state).await?; - let res = Request::builder() - .method("POST") - .uri("/async") - .header(CONTENT_TYPE, "text/plain") - .header(ID_HEADER, ID_VALUE) - .body(Body::from("Test Message")) - .unwrap(); - - let response = app.oneshot(res).await.unwrap(); - let message = response_collector.await.unwrap(); - assert_eq!(message.id, ID_VALUE); - assert_eq!(response.status(), StatusCode::OK); - - let result = extract_response_from_body(response.into_body()).await; - assert_eq!( - result, - json!({ - "message": "Successfully published message", - "id": ID_VALUE, - "code": 200 - }) - ); - Ok(()) - } - - async fn extract_response_from_body(body: Body) -> Value { - let bytes = to_bytes(body, usize::MAX).await.unwrap(); - let mut resp: Value = serde_json::from_slice(&bytes).unwrap(); - let _ = resp.as_object_mut().unwrap().remove("timestamp").unwrap(); - resp - } - - fn create_default_callbacks(id: &str) -> Vec { - vec![ - Callback { - id: id.to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.to_string(), - vertex: "planner".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - responses: vec![Response { - tags: Some(vec!["tiger".into()]), - }], - }, - Callback { - id: id.to_string(), - vertex: "tiger".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - responses: vec![Response { tags: None }], - }, - Callback { - id: id.to_string(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "tiger".to_string(), - responses: vec![Response { tags: None }], - }, - ] - } - - #[tokio::test] - async fn test_sync_publish() { - const ID_HEADER: &str = "X-Numaflow-ID"; - const ID_VALUE: &str = "foobar"; - let settings = Settings { - tid_header: ID_HEADER.into(), - ..Default::default() - }; - - let mem_store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - - let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); - - let (messages_tx, mut messages_rx) = mpsc::channel(10); - - let response_collector = tokio::spawn(async move { - let message = messages_rx.recv().await.unwrap(); - let MessageWrapper { - confirm_save, - message, - } = message; - confirm_save.send(()).unwrap(); - message - }); - - let app_state = AppState { - message: messages_tx, - settings: Arc::new(settings), - callback_state: callback_state.clone(), - }; - - let app = jetstream_proxy(app_state).await.unwrap(); - - tokio::spawn(async move { - let mut retries = 0; - callback_state - .save_response(ID_VALUE.into(), Bytes::from_static(b"test-output")) - .await - .unwrap(); - loop { - let cbs = create_default_callbacks(ID_VALUE); - match callback_state.insert_callback_requests(cbs).await { - Ok(_) => break, - Err(e) => { - retries += 1; - if retries > 10 { - panic!("Failed to insert callback requests: {:?}", e); - } - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - } - } - } - }); - - let res = Request::builder() - .method("POST") - .uri("/sync") - .header("Content-Type", "text/plain") - .header(ID_HEADER, ID_VALUE) - .body(Body::from("Test Message")) - .unwrap(); - - let response = app.clone().oneshot(res).await.unwrap(); - let message = response_collector.await.unwrap(); - assert_eq!(message.id, ID_VALUE); - assert_eq!(response.status(), StatusCode::OK); - - let result = to_bytes(response.into_body(), 10 * 1024).await.unwrap(); - assert_eq!(result, Bytes::from_static(b"test-output")); - } - - #[tokio::test] - async fn test_sync_publish_serve() { - const ID_VALUE: &str = "foobar"; - let settings = Arc::new(Settings::default()); - - let mem_store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - - let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); - - let (messages_tx, mut messages_rx) = mpsc::channel(10); - - let response_collector = tokio::spawn(async move { - let message = messages_rx.recv().await.unwrap(); - let MessageWrapper { - confirm_save, - message, - } = message; - confirm_save.send(()).unwrap(); - message - }); - - let app_state = AppState { - message: messages_tx, - settings, - callback_state: callback_state.clone(), - }; - - let app = jetstream_proxy(app_state).await.unwrap(); - - // pipeline is in -> cat -> out, so we will have 3 callback requests - - // spawn a tokio task which will insert the callback requests to the callback state - // if it fails, sleep for 10ms and retry - tokio::spawn(async move { - let mut retries = 0; - loop { - let cbs = create_default_callbacks(ID_VALUE); - match callback_state.insert_callback_requests(cbs).await { - Ok(_) => { - // save a test message, we should get this message when serve is invoked - // with foobar id - callback_state - .save_response("foobar".to_string(), Bytes::from("Test Message 1")) - .await - .unwrap(); - callback_state - .save_response( - "foobar".to_string(), - Bytes::from("Another Test Message 2"), - ) - .await - .unwrap(); - break; - } - Err(e) => { - retries += 1; - if retries > 10 { - panic!("Failed to insert callback requests: {:?}", e); - } - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - } - } - } - }); - - let req = Request::builder() - .method("POST") - .uri("/sync") - .header("Content-Type", "text/plain") - .header(DEFAULT_ID_HEADER, ID_VALUE) - .body(Body::from("Test Message")) - .unwrap(); - - let response = app.clone().oneshot(req).await.unwrap(); - let message = response_collector.await.unwrap(); - assert_eq!(message.id, ID_VALUE); - - assert_eq!(response.status(), StatusCode::OK); - - let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); - assert_eq!(content_len.as_bytes(), b"36"); - - let hval_response_len = response.headers().get(NUMAFLOW_RESP_ARRAY_LEN).unwrap(); - assert_eq!(hval_response_len.as_bytes(), b"2"); - - let hval_response_array_len = response.headers().get(NUMAFLOW_RESP_ARRAY_IDX_LEN).unwrap(); - assert_eq!(hval_response_array_len.as_bytes(), b"14,22"); - - let result = to_bytes(response.into_body(), usize::MAX).await.unwrap(); - assert_eq!(result, "Test Message 1Another Test Message 2".as_bytes()); - - // Get result for the request id using /fetch endpoint - let req = Request::builder() - .method("GET") - .uri(format!("/fetch?id={ID_VALUE}")) - .body(axum::body::Body::empty()) - .unwrap(); - let response = app.clone().oneshot(req).await.unwrap(); - assert_eq!(response.status(), StatusCode::OK); - let serve_resp = to_bytes(response.into_body(), usize::MAX).await.unwrap(); - assert_eq!( - serve_resp, - Bytes::from_static(b"Test Message 1Another Test Message 2") - ); - - // Request for an id that doesn't exist in the store - let req = Request::builder() - .method("GET") - .uri("/fetch?id=unknown") - .body(axum::body::Body::empty()) - .unwrap(); - let response = app.oneshot(req).await.unwrap(); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - } -} +// #[cfg(test)] +// mod tests { +// use std::sync::Arc; +// +// use axum::body::{to_bytes, Body}; +// use axum::extract::Request; +// use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; +// use serde_json::{json, Value}; +// use tower::ServiceExt; +// +// use super::*; +// use crate::app::callback::datumstore::memstore::InMemoryStore; +// use crate::app::callback::datumstore::Result as StoreResult; +// use crate::app::callback::state::State as CallbackState; +// use crate::app::tracker::MessageGraph; +// use crate::callback::{Callback, Response}; +// use crate::config::DEFAULT_ID_HEADER; +// use crate::pipeline::PipelineDCG; +// use crate::Settings; +// +// const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; +// +// #[derive(Clone)] +// struct MockStore; +// +// impl DatumStore for MockStore { +// async fn retrieve_datum(&mut self, _id: &str) -> StoreResult>>> { +// Ok(Some(vec![])) +// } +// async fn ready(&mut self) -> bool { +// true +// } +// } +// +// #[tokio::test] +// async fn test_async_publish() -> Result<(), Box> { +// const ID_HEADER: &str = "X-Numaflow-Id"; +// const ID_VALUE: &str = "foobar"; +// let settings = Settings { +// tid_header: ID_HEADER.into(), +// ..Default::default() +// }; +// +// let mock_store = MockStore {}; +// let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; +// let callback_state = CallbackState::new(msg_graph, mock_store).await?; +// +// let (messages_tx, mut messages_rx) = mpsc::channel::(10); +// let response_collector = tokio::spawn(async move { +// let message = messages_rx.recv().await.unwrap(); +// let MessageWrapper { +// confirm_save, +// message, +// } = message; +// confirm_save.send(()).unwrap(); +// message +// }); +// +// let app_state = AppState { +// message: messages_tx, +// settings: Arc::new(settings), +// callback_state, +// }; +// +// let app = jetstream_proxy(app_state).await?; +// let res = Request::builder() +// .method("POST") +// .uri("/async") +// .header(CONTENT_TYPE, "text/plain") +// .header(ID_HEADER, ID_VALUE) +// .body(Body::from("Test Message")) +// .unwrap(); +// +// let response = app.oneshot(res).await.unwrap(); +// let message = response_collector.await.unwrap(); +// assert_eq!(message.id, ID_VALUE); +// assert_eq!(response.status(), StatusCode::OK); +// +// let result = extract_response_from_body(response.into_body()).await; +// assert_eq!( +// result, +// json!({ +// "message": "Successfully published message", +// "id": ID_VALUE, +// "code": 200 +// }) +// ); +// Ok(()) +// } +// +// async fn extract_response_from_body(body: Body) -> Value { +// let bytes = to_bytes(body, usize::MAX).await.unwrap(); +// let mut resp: Value = serde_json::from_slice(&bytes).unwrap(); +// let _ = resp.as_object_mut().unwrap().remove("timestamp").unwrap(); +// resp +// } +// +// fn create_default_callbacks(id: &str) -> Vec { +// vec![ +// Callback { +// id: id.to_string(), +// vertex: "in".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.to_string(), +// vertex: "planner".to_string(), +// cb_time: 12345, +// from_vertex: "in".to_string(), +// responses: vec![Response { +// tags: Some(vec!["tiger".into()]), +// }], +// }, +// Callback { +// id: id.to_string(), +// vertex: "tiger".to_string(), +// cb_time: 12345, +// from_vertex: "planner".to_string(), +// responses: vec![Response { tags: None }], +// }, +// Callback { +// id: id.to_string(), +// vertex: "serve-sink".to_string(), +// cb_time: 12345, +// from_vertex: "tiger".to_string(), +// responses: vec![Response { tags: None }], +// }, +// ] +// } +// +// #[tokio::test] +// async fn test_sync_publish() { +// const ID_HEADER: &str = "X-Numaflow-ID"; +// const ID_VALUE: &str = "foobar"; +// let settings = Settings { +// tid_header: ID_HEADER.into(), +// ..Default::default() +// }; +// +// let mem_store = InMemoryStore::new(); +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// +// let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); +// +// let (messages_tx, mut messages_rx) = mpsc::channel(10); +// +// let response_collector = tokio::spawn(async move { +// let message = messages_rx.recv().await.unwrap(); +// let MessageWrapper { +// confirm_save, +// message, +// } = message; +// confirm_save.send(()).unwrap(); +// message +// }); +// +// let app_state = AppState { +// message: messages_tx, +// settings: Arc::new(settings), +// callback_state: callback_state.clone(), +// }; +// +// let app = jetstream_proxy(app_state).await.unwrap(); +// +// tokio::spawn(async move { +// let mut retries = 0; +// callback_state +// .save_response(ID_VALUE.into(), Bytes::from_static(b"test-output")) +// .await +// .unwrap(); +// loop { +// let cbs = create_default_callbacks(ID_VALUE); +// match callback_state.insert_callback_requests(cbs).await { +// Ok(_) => break, +// Err(e) => { +// retries += 1; +// if retries > 10 { +// panic!("Failed to insert callback requests: {:?}", e); +// } +// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; +// } +// } +// } +// }); +// +// let res = Request::builder() +// .method("POST") +// .uri("/sync") +// .header("Content-Type", "text/plain") +// .header(ID_HEADER, ID_VALUE) +// .body(Body::from("Test Message")) +// .unwrap(); +// +// let response = app.clone().oneshot(res).await.unwrap(); +// let message = response_collector.await.unwrap(); +// assert_eq!(message.id, ID_VALUE); +// assert_eq!(response.status(), StatusCode::OK); +// +// let result = to_bytes(response.into_body(), 10 * 1024).await.unwrap(); +// assert_eq!(result, Bytes::from_static(b"test-output")); +// } +// +// #[tokio::test] +// async fn test_sync_publish_serve() { +// const ID_VALUE: &str = "foobar"; +// let settings = Arc::new(Settings::default()); +// +// let mem_store = InMemoryStore::new(); +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// +// let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); +// +// let (messages_tx, mut messages_rx) = mpsc::channel(10); +// +// let response_collector = tokio::spawn(async move { +// let message = messages_rx.recv().await.unwrap(); +// let MessageWrapper { +// confirm_save, +// message, +// } = message; +// confirm_save.send(()).unwrap(); +// message +// }); +// +// let app_state = AppState { +// message: messages_tx, +// settings, +// callback_state: callback_state.clone(), +// }; +// +// let app = jetstream_proxy(app_state).await.unwrap(); +// +// // pipeline is in -> cat -> out, so we will have 3 callback requests +// +// // spawn a tokio task which will insert the callback requests to the callback state +// // if it fails, sleep for 10ms and retry +// tokio::spawn(async move { +// let mut retries = 0; +// loop { +// let cbs = create_default_callbacks(ID_VALUE); +// match callback_state.insert_callback_requests(cbs).await { +// Ok(_) => { +// // save a test message, we should get this message when serve is invoked +// // with foobar id +// callback_state +// .save_response("foobar".to_string(), Bytes::from("Test Message 1")) +// .await +// .unwrap(); +// callback_state +// .save_response( +// "foobar".to_string(), +// Bytes::from("Another Test Message 2"), +// ) +// .await +// .unwrap(); +// break; +// } +// Err(e) => { +// retries += 1; +// if retries > 10 { +// panic!("Failed to insert callback requests: {:?}", e); +// } +// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; +// } +// } +// } +// }); +// +// let req = Request::builder() +// .method("POST") +// .uri("/sync") +// .header("Content-Type", "text/plain") +// .header(DEFAULT_ID_HEADER, ID_VALUE) +// .body(Body::from("Test Message")) +// .unwrap(); +// +// let response = app.clone().oneshot(req).await.unwrap(); +// let message = response_collector.await.unwrap(); +// assert_eq!(message.id, ID_VALUE); +// +// assert_eq!(response.status(), StatusCode::OK); +// +// let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); +// assert_eq!(content_len.as_bytes(), b"36"); +// +// let hval_response_len = response.headers().get(NUMAFLOW_RESP_ARRAY_LEN).unwrap(); +// assert_eq!(hval_response_len.as_bytes(), b"2"); +// +// let hval_response_array_len = response.headers().get(NUMAFLOW_RESP_ARRAY_IDX_LEN).unwrap(); +// assert_eq!(hval_response_array_len.as_bytes(), b"14,22"); +// +// let result = to_bytes(response.into_body(), usize::MAX).await.unwrap(); +// assert_eq!(result, "Test Message 1Another Test Message 2".as_bytes()); +// +// // Get result for the request id using /fetch endpoint +// let req = Request::builder() +// .method("GET") +// .uri(format!("/fetch?id={ID_VALUE}")) +// .body(axum::body::Body::empty()) +// .unwrap(); +// let response = app.clone().oneshot(req).await.unwrap(); +// assert_eq!(response.status(), StatusCode::OK); +// let serve_resp = to_bytes(response.into_body(), usize::MAX).await.unwrap(); +// assert_eq!( +// serve_resp, +// Bytes::from_static(b"Test Message 1Another Test Message 2") +// ); +// +// // Request for an id that doesn't exist in the store +// let req = Request::builder() +// .method("GET") +// .uri("/fetch?id=unknown") +// .body(axum::body::Body::empty()) +// .unwrap(); +// let response = app.oneshot(req).await.unwrap(); +// assert_eq!(response.status(), StatusCode::NOT_FOUND); +// } +// } diff --git a/rust/serving/src/app/message_path.rs b/rust/serving/src/app/message_path.rs index c93abaa582..6b9c687dd9 100644 --- a/rust/serving/src/app/message_path.rs +++ b/rust/serving/src/app/message_path.rs @@ -4,11 +4,15 @@ use axum::{ }; use super::callback::state::State as CallbackState; -use super::callback::store::Store; +use crate::app::callback::cbstore::CallbackStore; +use crate::app::callback::datumstore::DatumStore; use crate::app::response::ApiError; -pub fn get_message_path( - callback_store: CallbackState, +pub fn get_message_path< + T: Send + Sync + Clone + DatumStore + 'static, + C: Send + Sync + Clone + CallbackStore + 'static, +>( + callback_store: CallbackState, ) -> Router { Router::new() .route("/status", routing::get(message_path)) @@ -20,11 +24,14 @@ struct MessagePath { id: String, } -async fn message_path( - State(mut store): State>, +async fn message_path< + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, +>( + State(mut state): State>, Query(MessagePath { id }): Query, ) -> Result { - match store.retrieve_subgraph_from_storage(&id).await { + match state.retrieve_subgraph_from_storage(&id).await { Ok(subgraph) => Ok(subgraph), Err(e) => { tracing::error!(error=?e); @@ -35,37 +42,37 @@ async fn message_path( } } -#[cfg(test)] -mod tests { - use axum::body::Body; - use axum::extract::Request; - use axum::http::header::CONTENT_TYPE; - use axum::http::StatusCode; - use tower::ServiceExt; - - use super::*; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::tracker::MessageGraph; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_message_path_not_present() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = get_message_path(state); - - let res = Request::builder() - .method("GET") - .uri("/status?id=test_id") - .header(CONTENT_TYPE, "application/json") - .body(Body::empty()) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } -} +// #[cfg(test)] +// mod tests { +// use axum::body::Body; +// use axum::extract::Request; +// use axum::http::header::CONTENT_TYPE; +// use axum::http::StatusCode; +// use tower::ServiceExt; +// +// use super::*; +// use crate::app::callback::datumstore::memstore::InMemoryStore; +// use crate::app::tracker::MessageGraph; +// use crate::pipeline::PipelineDCG; +// +// const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; +// +// #[tokio::test] +// async fn test_message_path_not_present() { +// let store = InMemoryStore::new(); +// let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); +// let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); +// let state = CallbackState::new(msg_graph, store).await.unwrap(); +// let app = get_message_path(state); +// +// let res = Request::builder() +// .method("GET") +// .uri("/status?id=test_id") +// .header(CONTENT_TYPE, "application/json") +// .body(Body::empty()) +// .unwrap(); +// +// let resp = app.oneshot(res).await.unwrap(); +// assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); +// } +// } diff --git a/rust/serving/src/callback.rs b/rust/serving/src/callback.rs index 7046352766..b59eea3f11 100644 --- a/rust/serving/src/callback.rs +++ b/rust/serving/src/callback.rs @@ -1,13 +1,18 @@ use std::{ sync::Arc, - time::{Duration, SystemTime, UNIX_EPOCH}, + time::{SystemTime, UNIX_EPOCH}, }; -use reqwest::Client; +use async_nats::jetstream::kv::Store; +use async_nats::jetstream::Context; +use backoff::retry::Retry; +use backoff::strategy::fixed; +use bytes::Bytes; use serde::{Deserialize, Serialize}; use tokio::{sync::Semaphore, task::JoinHandle}; +use tracing::{error, warn}; -use crate::config::DEFAULT_ID_HEADER; +use crate::Error; /// As message passes through each component (map, transformer, sink, etc.). it emits a beacon via callback /// to inform that message has been processed by this component. @@ -31,26 +36,28 @@ pub(crate) struct Response { #[derive(Clone)] pub struct CallbackHandler { + semaphore: Arc, + store: Store, /// the client to callback to the request originating pod/container - client: Client, vertex_name: String, - semaphore: Arc, } impl CallbackHandler { - pub fn new(vertex_name: String, concurrency_limit: usize) -> Self { - let client = Client::builder() - .danger_accept_invalid_certs(true) - .timeout(Duration::from_secs(1)) - .build() - .expect("Creating callback client for Serving source"); - + pub async fn new( + vertex_name: String, + js_context: Context, + store_name: &'static str, + concurrency_limit: usize, + ) -> Self { + let store = js_context + .get_key_value(store_name) + .await + .expect("Failed to get kv store"); let semaphore = Arc::new(Semaphore::new(concurrency_limit)); - Self { - client, - vertex_name, semaphore, + vertex_name, + store, } } @@ -58,7 +65,6 @@ impl CallbackHandler { pub async fn callback( &self, id: String, - callback_url: String, previous_vertex: String, responses: Vec>>, ) -> crate::Result> { @@ -81,79 +87,30 @@ impl CallbackHandler { }; let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let client = self.client.clone(); + let store = self.store.clone(); let handle = tokio::spawn(async move { + let interval = fixed::Interval::from_millis(1000).take(2); let _permit = permit; - // Retry in case of failure in making request. - // When there is a failure, we retry after wait_secs. This value is doubled after each retry attempt. - // Then longest wait time will be 64 seconds. - let mut wait_secs = 1; - // TODO: let's do only 2 retries and write directly to the DB. - const TOTAL_ATTEMPTS: usize = 7; - for i in 1..=TOTAL_ATTEMPTS { - let resp = client - .post(&callback_url) - .header(DEFAULT_ID_HEADER, id.clone()) - .json(&[&callback_payload]) - .send() - .await; - let resp = match resp { - Ok(resp) => resp, - Err(e) => { - if i < TOTAL_ATTEMPTS { - tracing::warn!( - ?e, - "Sending callback request failed. Will retry after a delay" - ); - // TODO: this sleep is a LOT, should in < 10ms. we do not want to retry - // for the pod/container to comeback but rather retrying to transient errors. - // Also see whether `reqwest` will do an internal retry. - tokio::time::sleep(Duration::from_secs(wait_secs)).await; - wait_secs *= 2; - } else { - tracing::error!(?e, "Sending callback request failed"); + let value = serde_json::to_string(&callback_payload).expect("Failed to serialize"); + let result = Retry::retry( + interval, + || async { + match store.put(id.clone(), Bytes::from(value.clone())).await { + Ok(resp) => Ok(resp), + Err(e) => { + warn!(?e, "Failed to write callback to store, retrying.."); + Err(Error::Other(format!( + "Failed to write callback to store: {}", + e + ))) } - continue; } - }; - - if resp.status().is_success() { - break; - } - - if resp.status().is_client_error() { - // TODO: When the source serving pod restarts, the callbacks will fail with 4xx status - // since the request ID won't be available in it's in-memory tracker. - // No point in retrying such cases - // 4xx can also happen if payload is wrong (due to bugs in the code). We should differentiate - // between what can be retried and not. - let status_code = resp.status(); - let response_body = resp.text().await; - tracing::error!( - ?status_code, - ?response_body, - "Received client error while making callback. Callback will not be retried" - ); - break; - } - - let status_code = resp.status(); - let response_body = resp.text().await; - if i < TOTAL_ATTEMPTS { - tracing::warn!( - ?status_code, - ?response_body, - "Received non-OK status for callback request. Will retry after a delay" - ); - tokio::time::sleep(Duration::from_secs(wait_secs)).await; - wait_secs *= 2; - } else { - tracing::error!( - ?status_code, - ?response_body, - "Received non-OK status for callback request" - ); - } + }, + |_: &Error| true, + ) + .await; + if let Err(e) = result { + error!(?e, "Failed to write callback to store"); } }); @@ -161,195 +118,187 @@ impl CallbackHandler { } } -#[cfg(test)] -mod tests { - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; - use std::time::Duration; - - use axum::http::StatusCode; - use axum::routing::{get, post}; - use axum::{Json, Router}; - use axum_server::tls_rustls::RustlsConfig; - use tokio::sync::mpsc; - - use crate::app::callback::state::State as CallbackState; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::start_main_server; - use crate::app::tracker::MessageGraph; - use crate::callback::CallbackHandler; - use crate::config::generate_certs; - use crate::pipeline::PipelineDCG; - use crate::test_utils::get_port; - use crate::{AppState, Settings}; - - type Result = std::result::Result>; - - #[tokio::test] - async fn test_successful_callback() -> Result<()> { - // Set up the CryptoProvider (controls core cryptography used by rustls) for the process - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .map_err(|e| format!("Failed to create tls config {:?}", e))?; - - let port = get_port(); - let settings = Settings { - app_listen_port: port, - ..Default::default() - }; - // We start the 'Serving' https server with an in-memory store - // When the server receives callback request, the in-memory store will be populated. - // This is verified at the end of the test. - let store = InMemoryStore::new(); - let message_graph = MessageGraph::from_pipeline(&PipelineDCG::default())?; - let (tx, _) = mpsc::channel(10); - - let mut app_state = AppState { - message: tx, - settings: Arc::new(settings), - callback_state: CallbackState::new(message_graph, store.clone()).await?, - }; - - // We use this value as the request id of the callback request - const ID_VALUE: &str = "1234"; - - // Register the request id in the store. This normally happens when the Serving source - // receives a request from the client. The callbacks for this request must only happen after this. - let _callback_notify_rx = app_state - .callback_state - .register(Some(ID_VALUE.into())) - .await; - - let server_handle = tokio::spawn(start_main_server(app_state, tls_config)); - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .danger_accept_invalid_certs(true) - .build()?; - - // Wait for the server to be ready - let mut server_ready = false; - for _ in 0..10 { - let resp = client - .get(format!("https://localhost:{port}/livez")) - .send() - .await?; - if resp.status().is_success() { - server_ready = true; - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - assert!(server_ready, "Server is not ready"); - - let callback_handler = CallbackHandler::new("test".into(), 10); - - // On the server, this fails with SubGraphInvalidInput("Invalid callback: 1234, vertex: in") - // We get 200 OK response from the server, since we already registered this request ID in the store. - callback_handler - .callback( - ID_VALUE.into(), - format!("https://localhost:{port}/v1/process/callback"), - "in".into(), - vec![], - ) - .await?; - let mut data = None; - for _ in 0..10 { - tokio::time::sleep(Duration::from_millis(2)).await; - data = { - let guard = store.data.lock().unwrap(); - guard.get(ID_VALUE).cloned() - }; - if data.is_some() { - break; - } - } - assert!(data.is_some(), "Callback data not found in store"); - server_handle.abort(); - Ok(()) - } - - #[tokio::test] - // Starts a custom server that handles requests to `/v1/process/callback`. - // The request handler will return INTERNAL_ERROR for the first 2 requests. This should result in - // retry on the client side. Then the handler responds with BAD_REQUEST, which should cause the client - // to abort. - async fn test_callback_retry() -> Result<()> { - // Set up the CryptoProvider (controls core cryptography used by rustls) for the process - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .map_err(|e| format!("Failed to create tls config {:?}", e))?; - - let port = get_port(); - let server_addr = format!("127.0.0.1:{port}"); - let callback_url = format!("https://{server_addr}/v1/process/callback"); - - let request_count = Arc::new(AtomicUsize::new(0)); - let router = Router::new() - .route("/livez", get(|| async { StatusCode::OK })) - .route( - "/v1/process/callback", - post({ - let req_count = Arc::clone(&request_count); - |payload: Json| async move { - tracing::info!(?payload, "Get request"); - if req_count.fetch_add(1, Ordering::Relaxed) < 2 { - StatusCode::INTERNAL_SERVER_ERROR - } else { - StatusCode::BAD_REQUEST - } - } - }), - ); - - let sock_addr = server_addr.as_str().parse().unwrap(); - let server = tokio::spawn(async move { - axum_server::bind_rustls(sock_addr, tls_config) - .serve(router.into_make_service()) - .await - .unwrap(); - }); - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .danger_accept_invalid_certs(true) - .build()?; - - // Wait for the server to be ready - let mut server_ready = false; - let health_url = format!("https://{server_addr}/livez"); - for _ in 0..10 { - let Ok(resp) = client.get(&health_url).send().await else { - tokio::time::sleep(Duration::from_millis(5)).await; - continue; - }; - if resp.status().is_success() { - server_ready = true; - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - assert!(server_ready, "Server is not ready"); - - let callback_handler = CallbackHandler::new("test".into(), 10); - - // On the server, this fails with SubGraphInvalidInput("Invalid callback: 1234, vertex: in") - // We get 200 OK response from the server, since we already registered this request ID in the store. - let callback_task = callback_handler - .callback("1234".into(), callback_url, "in".into(), vec![]) - .await?; - assert!(callback_task.await.is_ok()); - server.abort(); - assert_eq!(request_count.load(Ordering::Relaxed), 3); - Ok(()) - } -} +// #[cfg(test)] +// mod tests { +// use std::sync::atomic::{AtomicUsize, Ordering}; +// use std::sync::Arc; +// use std::time::Duration; +// +// use axum::http::StatusCode; +// use axum::routing::{get, post}; +// use axum::{Json, Router}; +// use axum_server::tls_rustls::RustlsConfig; +// use tokio::sync::mpsc; +// +// use crate::app::callback::datumstore::memstore::InMemoryStore; +// use crate::app::callback::state::State as CallbackState; +// use crate::app::start_main_server; +// use crate::app::tracker::MessageGraph; +// use crate::callback::CallbackHandler; +// use crate::config::generate_certs; +// use crate::pipeline::PipelineDCG; +// use crate::test_utils::get_port; +// use crate::{AppState, Settings}; +// +// type Result = std::result::Result>; +// +// #[tokio::test] +// async fn test_successful_callback() -> Result<()> { +// // Set up the CryptoProvider (controls core cryptography used by rustls) for the process +// let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); +// +// let (cert, key) = generate_certs()?; +// +// let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) +// .await +// .map_err(|e| format!("Failed to create tls config {:?}", e))?; +// +// let port = get_port(); +// let settings = Settings { +// app_listen_port: port, +// ..Default::default() +// }; +// // We start the 'Serving' https server with an in-memory store +// // When the server receives callback request, the in-memory store will be populated. +// // This is verified at the end of the test. +// let store = InMemoryStore::new(); +// let message_graph = MessageGraph::from_pipeline(&PipelineDCG::default())?; +// let (tx, _) = mpsc::channel(10); +// +// let mut app_state = AppState { +// message: tx, +// settings: Arc::new(settings), +// callback_state: CallbackState::new(message_graph, store.clone()).await?, +// }; +// +// // We use this value as the request id of the callback request +// const ID_VALUE: &str = "1234"; +// +// // Register the request id in the datumstore. This normally happens when the Serving source +// // receives a request from the client. The callbacks for this request must only happen after this. +// let _callback_notify_rx = app_state.callback_state.register(ID_VALUE.into()).await; +// +// let server_handle = tokio::spawn(start_main_server(app_state, tls_config)); +// +// let client = reqwest::Client::builder() +// .timeout(Duration::from_secs(2)) +// .danger_accept_invalid_certs(true) +// .build()?; +// +// // Wait for the server to be ready +// let mut server_ready = false; +// for _ in 0..10 { +// let resp = client +// .get(format!("https://localhost:{port}/livez")) +// .send() +// .await?; +// if resp.status().is_success() { +// server_ready = true; +// break; +// } +// tokio::time::sleep(Duration::from_millis(5)).await; +// } +// assert!(server_ready, "Server is not ready"); +// +// let callback_handler = CallbackHandler::new("test".into(), 10); +// +// // On the server, this fails with SubGraphInvalidInput("Invalid callback: 1234, vertex: in") +// // We get 200 OK response from the server, since we already registered this request ID in the store. +// callback_handler +// .callback(ID_VALUE.into(), "in".into(), vec![]) +// .await?; +// let mut data = None; +// for _ in 0..10 { +// tokio::time::sleep(Duration::from_millis(2)).await; +// data = { +// let guard = store.data.lock().unwrap(); +// guard.get(ID_VALUE).cloned() +// }; +// if data.is_some() { +// break; +// } +// } +// assert!(data.is_some(), "Callback data not found in store"); +// server_handle.abort(); +// Ok(()) +// } +// +// #[tokio::test] +// // Starts a custom server that handles requests to `/v1/process/callback`. +// // The request handler will return INTERNAL_ERROR for the first 2 requests. This should result in +// // retry on the client side. Then the handler responds with BAD_REQUEST, which should cause the client +// // to abort. +// async fn test_callback_retry() -> Result<()> { +// // Set up the CryptoProvider (controls core cryptography used by rustls) for the process +// let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); +// +// let (cert, key) = generate_certs()?; +// +// let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) +// .await +// .map_err(|e| format!("Failed to create tls config {:?}", e))?; +// +// let port = get_port(); +// let server_addr = format!("127.0.0.1:{port}"); +// let callback_url = format!("https://{server_addr}/v1/process/callback"); +// +// let request_count = Arc::new(AtomicUsize::new(0)); +// let router = Router::new() +// .route("/livez", get(|| async { StatusCode::OK })) +// .route( +// "/v1/process/callback", +// post({ +// let req_count = Arc::clone(&request_count); +// |payload: Json| async move { +// tracing::info!(?payload, "Get request"); +// if req_count.fetch_add(1, Ordering::Relaxed) < 2 { +// StatusCode::INTERNAL_SERVER_ERROR +// } else { +// StatusCode::BAD_REQUEST +// } +// } +// }), +// ); +// +// let sock_addr = server_addr.as_str().parse().unwrap(); +// let server = tokio::spawn(async move { +// axum_server::bind_rustls(sock_addr, tls_config) +// .serve(router.into_make_service()) +// .await +// .unwrap(); +// }); +// +// let client = reqwest::Client::builder() +// .timeout(Duration::from_secs(2)) +// .danger_accept_invalid_certs(true) +// .build()?; +// +// // Wait for the server to be ready +// let mut server_ready = false; +// let health_url = format!("https://{server_addr}/livez"); +// for _ in 0..10 { +// let Ok(resp) = client.get(&health_url).send().await else { +// tokio::time::sleep(Duration::from_millis(5)).await; +// continue; +// }; +// if resp.status().is_success() { +// server_ready = true; +// break; +// } +// tokio::time::sleep(Duration::from_millis(5)).await; +// } +// assert!(server_ready, "Server is not ready"); +// +// let callback_handler = CallbackHandler::new("test".into(), 10); +// +// // On the server, this fails with SubGraphInvalidInput("Invalid callback: 1234, vertex: in") +// // We get 200 OK response from the server, since we already registered this request ID in the store. +// let callback_task = callback_handler +// .callback("1234".into(), "in".into(), vec![]) +// .await?; +// assert!(callback_task.await.is_ok()); +// server.abort(); +// assert_eq!(request_count.load(Ordering::Relaxed), 3); +// Ok(()) +// } +// } diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 6c3076441b..fae62d70d3 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; use std::fmt::Debug; +use std::time::Duration; use base64::prelude::BASE64_STANDARD; use base64::Engine; use numaflow_models::models::{MonoVertex, Vertex}; use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; use serde::{Deserialize, Serialize}; -use std::time::Duration; use crate::{ pipeline::PipelineDCG, @@ -61,6 +61,7 @@ pub struct Settings { pub upstream_addr: String, pub drain_timeout_secs: u64, pub redis: RedisConfig, + pub js_store: String, /// The IP address of the numaserve pod. This will be used to construct the value for X-Numaflow-Callback-Url header pub host_ip: String, pub api_auth_token: Option, @@ -76,6 +77,7 @@ impl Default for Settings { upstream_addr: "localhost:8888".to_owned(), drain_timeout_secs: 600, redis: RedisConfig::default(), + js_store: "kv".to_owned(), host_ip: "127.0.0.1".to_owned(), api_auth_token: None, pipeline_spec: Default::default(), diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 0cb8c2d92a..d7bf7ea4c5 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -1,7 +1,6 @@ use std::net::SocketAddr; use std::sync::Arc; -use app::callback::store::Store; use axum_server::tls_rustls::RustlsConfig; use tokio::sync::mpsc; use tracing::info; @@ -25,20 +24,24 @@ pub mod source; use source::MessageWrapper; pub use source::{Message, ServingSource}; +use crate::app::callback::cbstore::CallbackStore; +use crate::app::callback::datumstore::DatumStore; + pub mod callback; #[derive(Clone)] -pub(crate) struct AppState { +pub(crate) struct AppState { pub(crate) message: mpsc::Sender, pub(crate) settings: Arc, - pub(crate) callback_state: CallbackState, + pub(crate) callback_state: CallbackState, } -pub(crate) async fn serve( - app: AppState, +pub(crate) async fn serve( + app: AppState, ) -> std::result::Result<(), Box> where - T: Clone + Send + Sync + Store + 'static, + T: Clone + Send + Sync + DatumStore + 'static, + C: Clone + Send + Sync + CallbackStore + 'static, { let (cert, key) = generate_certs()?; diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index 6f1c5629c0..1d4ddbaf5c 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -2,12 +2,14 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use async_nats::jetstream::Context; use bytes::Bytes; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; +use crate::app::callback::cbstore::jetstreamstore::JSCallbackStore; +use crate::app::callback::datumstore::redisstore::RedisConnection; use crate::app::callback::state::State as CallbackState; -use crate::app::callback::store::redisstore::RedisConnection; use crate::app::tracker::MessageGraph; use crate::config::{DEFAULT_CALLBACK_URL_HEADER_KEY, DEFAULT_ID_HEADER}; use crate::Settings; @@ -56,6 +58,7 @@ struct ServingSourceActor { impl ServingSourceActor { async fn start( + js_context: Context, settings: Arc, handler_rx: mpsc::Receiver, request_channel_buffer_size: usize, @@ -72,7 +75,8 @@ impl ServingSourceActor { e )) })?; - let callback_state = CallbackState::new(msg_graph, redis_store).await?; + let callback_store = JSCallbackStore::new(js_context, &settings.js_store).await?; + let callback_state = CallbackState::new(msg_graph, redis_store, callback_store).await?; let callback_url = format!( "https://{}:{}/v1/process/callback", @@ -191,13 +195,21 @@ pub struct ServingSource { impl ServingSource { pub async fn new( + context: Context, settings: Arc, batch_size: usize, timeout: Duration, vertex_replica_id: u16, ) -> Result { let (actor_tx, actor_rx) = mpsc::channel(2 * batch_size); - ServingSourceActor::start(settings, actor_rx, 2 * batch_size, vertex_replica_id).await?; + ServingSourceActor::start( + context, + settings, + actor_rx, + 2 * batch_size, + vertex_replica_id, + ) + .await?; Ok(Self { batch_size, timeout, @@ -241,18 +253,34 @@ impl ServingSource { mod tests { use std::{sync::Arc, time::Duration}; + use async_nats::jetstream; + use super::ServingSource; use crate::Settings; type Result = std::result::Result>; #[tokio::test] async fn test_serving_source() -> Result<()> { + let js_url = "localhost:4222"; + let client = async_nats::connect(js_url).await.unwrap(); + let context = jetstream::new(client); + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - let settings = Arc::new(Settings::default()); - let serving_source = - ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1), 0).await?; + let mut settings = Arc::new(Settings { + js_store: "test_serving_source".to_string(), + ..Default::default() + }); + + let serving_source = ServingSource::new( + context, + Arc::clone(&settings), + 10, + Duration::from_millis(1), + 0, + ) + .await?; let client = reqwest::Client::builder() .timeout(Duration::from_secs(2))