From 221d260ea19ec3dea711ec37d60d3c5bef440774 Mon Sep 17 00:00:00 2001 From: Adam Hamrick Date: Fri, 24 Jan 2025 13:52:24 -0500 Subject: [PATCH] [TT-1842] [TT-1608] Parrot Server (#1595) Add Parrot Server --- .github/workflows/docker-test.yaml | 2 +- .github/workflows/framework-golden-tests.yml | 8 +- .github/workflows/framework.yml | 4 +- .github/workflows/generate-go-docs.yaml | 2 +- .github/workflows/k8s-e2e.yaml | 4 +- .github/workflows/lint.yaml | 4 +- .github/workflows/parrot-release.yml | 44 ++ .github/workflows/rc-breaking-changes.yaml | 4 +- .github/workflows/release-go-module.yml | 2 +- .github/workflows/test.yaml | 4 +- .gitignore | 7 +- parrot/.dockerignore | 13 + parrot/.goreleaser.yaml | 59 ++ parrot/Dockerfile | 3 + parrot/Makefile | 37 + parrot/README.md | 35 + parrot/cmd/main.go | 83 ++ parrot/errors.go | 43 ++ parrot/examples_test.go | 315 ++++++++ parrot/go.mod | 24 + parrot/go.sum | 43 ++ parrot/parrot.go | 756 +++++++++++++++++++ parrot/parrot_benchmark_test.go | 145 ++++ parrot/parrot_test.go | 540 +++++++++++++ parrot/recorder.go | 175 +++++ parrot/recorder_test.go | 203 +++++ 26 files changed, 2543 insertions(+), 16 deletions(-) create mode 100644 .github/workflows/parrot-release.yml create mode 100644 parrot/.dockerignore create mode 100644 parrot/.goreleaser.yaml create mode 100644 parrot/Dockerfile create mode 100644 parrot/Makefile create mode 100644 parrot/README.md create mode 100644 parrot/cmd/main.go create mode 100644 parrot/errors.go create mode 100644 parrot/examples_test.go create mode 100644 parrot/go.mod create mode 100644 parrot/go.sum create mode 100644 parrot/parrot.go create mode 100644 parrot/parrot_benchmark_test.go create mode 100644 parrot/parrot_test.go create mode 100644 parrot/recorder.go create mode 100644 parrot/recorder_test.go diff --git a/.github/workflows/docker-test.yaml b/.github/workflows/docker-test.yaml index 3a8b782cc..38a7ee0be 100644 --- a/.github/workflows/docker-test.yaml +++ b/.github/workflows/docker-test.yaml @@ -46,7 +46,7 @@ jobs: go test -timeout 20m -json -parallel 2 -cover -covermode=atomic -coverprofile=unit-test-coverage.out $(go list ./... | grep /docker/test_env) -run '${{ matrix.test.tests }}' 2>&1 | tee /tmp/gotest.log | ../gotestloghelper -ci - name: Publish Artifacts if: failure() - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 with: name: test-logs path: ./lib/logs diff --git a/.github/workflows/framework-golden-tests.yml b/.github/workflows/framework-golden-tests.yml index fe851cffd..61a759c8b 100644 --- a/.github/workflows/framework-golden-tests.yml +++ b/.github/workflows/framework-golden-tests.yml @@ -22,7 +22,7 @@ jobs: config: smoke.toml count: 1 timeout: 10m - - name: TestSmoke + - name: TestSmokeLimitedResources config: smoke_limited_resources.toml count: 1 timeout: 10m @@ -81,9 +81,9 @@ jobs: - 'framework/**' - '.github/workflows/framework-golden-tests.yml' - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: 1.22.8 + go-version: 1.23 - name: Cache Go modules uses: actions/cache@v3 with: @@ -104,7 +104,7 @@ jobs: go test -timeout ${{ matrix.test.timeout }} -v -count ${{ matrix.test.count }} -run ${{ matrix.test.name }} - name: Upload Logs if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: container-logs-${{ matrix.test.name }} path: framework/examples/myproject/logs diff --git a/.github/workflows/framework.yml b/.github/workflows/framework.yml index cf5227020..97ea9ea85 100644 --- a/.github/workflows/framework.yml +++ b/.github/workflows/framework.yml @@ -36,9 +36,9 @@ jobs: src: - 'framework/**' - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: 1.22.8 + go-version: 1.23 - name: Cache Go modules uses: actions/cache@v4 with: diff --git a/.github/workflows/generate-go-docs.yaml b/.github/workflows/generate-go-docs.yaml index 59cca7239..a93a29e40 100644 --- a/.github/workflows/generate-go-docs.yaml +++ b/.github/workflows/generate-go-docs.yaml @@ -123,7 +123,7 @@ jobs: rm filtered_folders.json - name: Upload costs as artifact - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 with: name: generation-costs path: ./costs diff --git a/.github/workflows/k8s-e2e.yaml b/.github/workflows/k8s-e2e.yaml index 480726d81..269541a28 100644 --- a/.github/workflows/k8s-e2e.yaml +++ b/.github/workflows/k8s-e2e.yaml @@ -98,7 +98,7 @@ jobs: QA_AWS_ROLE_TO_ASSUME: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} run_setup: false - name: Upload test log - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 if: failure() with: name: test-log @@ -146,7 +146,7 @@ jobs: QA_KUBECONFIG: ${{ secrets.QA_KUBECONFIG }} run_setup: false - name: Upload test log - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 if: failure() with: name: remote-runner-test-log diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index c282001da..beab894e8 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -76,6 +76,8 @@ jobs: path: ./tools/asciitable/ - name: workflowresultparser path: ./tools/workflowresultparser/ + - name: parrot + path: ./parrot/ steps: - name: Check out Code uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 @@ -98,7 +100,7 @@ jobs: run: test -f ${{ matrix.project.path }}golangci-lint-report.xml || true - name: Store lint report artifact if: always() - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 with: name: golangci-lint-report-${{ matrix.project.name }} path: ${{ matrix.project.path }}golangci-lint-report.xml diff --git a/.github/workflows/parrot-release.yml b/.github/workflows/parrot-release.yml new file mode 100644 index 000000000..8a689e19d --- /dev/null +++ b/.github/workflows/parrot-release.yml @@ -0,0 +1,44 @@ +name: Parrotserver Release + +on: + push: + tags: + - parrot/v* + +jobs: + release: + name: Build and Release + runs-on: ubuntu-latest + environment: integration + steps: + - name: Checkout repo + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + with: + fetch-depth: 0 + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: ${{ secrets.QA_AWS_REGION }} + role-to-assume: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} + role-duration-seconds: 600 + - name: Login to Amazon ECR + uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 + with: + mask-password: 'true' + env: + AWS_REGION: ${{ secrets.QA_AWS_REGION }} + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: stable + - name: Goreleaser Release + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser-pro + version: "~> v2" + args: release --clean -f ./parrot/.goreleaser.yml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GORELEASER_KEY: ${{ secrets.GORELEASER_KEY }} + IMAGE_PREFIX: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com/parrot + IMAGE_TAG: ${{ github.ref_name}} diff --git a/.github/workflows/rc-breaking-changes.yaml b/.github/workflows/rc-breaking-changes.yaml index 428d53e20..7673da002 100644 --- a/.github/workflows/rc-breaking-changes.yaml +++ b/.github/workflows/rc-breaking-changes.yaml @@ -16,10 +16,10 @@ jobs: with: fetch-depth: 0 fetch-tags: true - - name: Set up Go 1.23.3 + - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23.3' + go-version: 1.23 - name: Install gorelease tool run: | go install golang.org/x/exp/cmd/gorelease@latest diff --git a/.github/workflows/release-go-module.yml b/.github/workflows/release-go-module.yml index 925e882ed..5906c0041 100644 --- a/.github/workflows/release-go-module.yml +++ b/.github/workflows/release-go-module.yml @@ -69,7 +69,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23.3' + go-version: 1.23 - name: Install gorelease tool run: | go install golang.org/x/exp/cmd/gorelease@latest diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d595356b5..bc08cad0b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -26,6 +26,8 @@ jobs: path: ./tools/flakeguard/ - name: workflowresultparser path: ./tools/workflowresultparser/ + - name: parrot + path: ./parrot/ runs-on: ubuntu-latest name: ${{ matrix.project.name }} unit tests steps: @@ -52,7 +54,7 @@ jobs: make test_unit" - name: Publish Artifacts if: failure() - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v4 with: name: test-logs path: /tmp/gotest.log diff --git a/.gitignore b/.gitignore index 4073b8808..2f70527c0 100644 --- a/.gitignore +++ b/.gitignore @@ -77,4 +77,9 @@ __debug* .tool-versions import_keys_test.go -tag.py \ No newline at end of file +tag.py + +parrot/*.json +parrot/*.log +# Executable +parrot/parrot \ No newline at end of file diff --git a/parrot/.dockerignore b/parrot/.dockerignore new file mode 100644 index 000000000..fdca67dcb --- /dev/null +++ b/parrot/.dockerignore @@ -0,0 +1,13 @@ +Dockerfile +*.md +*.log +.gitignore +.golangci-lint.yml +.goreleaser.yml +.pre-commit-config.yaml +*_test.go +LICENSE +.vscode/ +dist/ +.github/ +save.json \ No newline at end of file diff --git a/parrot/.goreleaser.yaml b/parrot/.goreleaser.yaml new file mode 100644 index 000000000..14867fff3 --- /dev/null +++ b/parrot/.goreleaser.yaml @@ -0,0 +1,59 @@ +# yaml-language-server: $schema=https://goreleaser.com/static/schema-pro.json +version: 2 +project_name: parrot + +monorepo: + tag_prefix: parrot/ + dir: parrot + +env: + - IMG_PRE={{ if index .Env "IMAGE_PREFIX" }}{{ .Env.IMAGE_PREFIX }}{{ else }}local{{ end }} + - TAG={{ if index .Env "IMAGE_TAG" }}{{ .Env.IMAGE_TAG }}{{ else }}latest{{ end }} + +# Build settings for binaries +builds: + - id: parrot + main: ./cmd/main.go + goos: + - linux + - darwin + goarch: + - amd64 + - arm64 + ldflags: + - '-s -w' + +archives: + - formats: ['binary'] + +dockers: + - id: linux-amd64-parrot + goos: linux + goarch: amd64 + image_templates: + - '{{ .Env.IMG_PRE }}/parrot:{{ .Tag }}' + - '{{ .Env.IMG_PRE }}/parrot:latest' + build_flag_templates: + - --platform=linux/amd64 + - --pull + - --label=org.opencontainers.image.created={{.Date}} + - --label=org.opencontainers.image.title={{.ProjectName}} + - --label=org.opencontainers.image.revision={{.FullCommit}} + - --label=org.opencontainers.image.version={{.Version}} + - id: linux-arm64-parrot + goos: linux + goarch: arm64 + image_templates: + - '{{ .Env.IMG_PRE }}/parrot:{{ .Tag }}-arm64' + - '{{ .Env.IMG_PRE }}/parrot:latest-arm64' + build_flag_templates: + - --platform=linux/arm64 + - --pull + - --label=org.opencontainers.image.created={{.Date}} + - --label=org.opencontainers.image.title={{.ProjectName}} + - --label=org.opencontainers.image.revision={{.FullCommit}} + - --label=org.opencontainers.image.version={{.Version}} + +before: + hooks: + - cd parrot && go mod tidy \ No newline at end of file diff --git a/parrot/Dockerfile b/parrot/Dockerfile new file mode 100644 index 000000000..0c4dfee03 --- /dev/null +++ b/parrot/Dockerfile @@ -0,0 +1,3 @@ +FROM scratch +COPY parrotserver /parrotserver +ENTRYPOINT [ "parrotserver" ] \ No newline at end of file diff --git a/parrot/Makefile b/parrot/Makefile new file mode 100644 index 000000000..478087dec --- /dev/null +++ b/parrot/Makefile @@ -0,0 +1,37 @@ +# Default test log level (can be overridden) +PARROT_TEST_LOG_LEVEL ?= "" + +# Pass TEST_LOG_LEVEL as a flag to go test +TEST_ARGS ?= -testLogLevel=$(PARROT_TEST_LOG_LEVEL) + +.PHONY: lint +lint: + golangci-lint --color=always run ./... --fix -v + +.PHONY: test +test: + go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + set -euo pipefail + go test $(TEST_ARGS) -json -cover -coverprofile cover.out -v ./... 2>&1 | tee /tmp/gotest.log | gotestfmt + +.PHONY: test_race +test_race: + go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + set -euo pipefail + go test $(TEST_ARGS) -json -cover -count=1 -race -coverprofile cover.out -v ./... 2>&1 | tee /tmp/gotest.log | gotestfmt + +.PHONY: test_unit +test_unit: + go test $(TEST_ARGS) -coverprofile cover.out ./... + +.PHONY: bench +bench: + go test $(TEST_ARGS) -bench=. -run=^$$ ./... + +.PHONY: build +build: + go build -o ./parrot ./cmd + +.PHONY: goreleaser +goreleaser: + cd .. && goreleaser build --snapshot --clean -f ./parrot/.goreleaser.yaml \ No newline at end of file diff --git a/parrot/README.md b/parrot/README.md new file mode 100644 index 000000000..cb9fc8011 --- /dev/null +++ b/parrot/README.md @@ -0,0 +1,35 @@ +# Parrot Server + +A simple, high-performing mockserver that can dynamically build new routes with customized responses, parroting back whatever you tell it to. + +## Features + +* Simplistic and fast design +* Run within your Go code, through a small binary, or in a minimal Docker container +* Easily record all incoming requests to the server to programmatically react to + +## Use + +See our runnable examples in [examples_test.go](./examples_test.go) to see how to use Parrot programmatically. + +## Run + +```sh +go run ./cmd +go run ./cmd -h # See all config options +``` + +## Test + +```sh +make test +make test PARROT_TEST_LOG_LEVEL=trace # Set log level for tests +make test_race # Test with -race flag enabled +make bench # Benchmark +``` + +## Build + +```sh +make goreleaser # Uses goreleaser to build binaries and docker containers +``` diff --git a/parrot/cmd/main.go b/parrot/cmd/main.go new file mode 100644 index 000000000..245fcb69f --- /dev/null +++ b/parrot/cmd/main.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/smartcontractkit/chainlink-testing-framework/parrot" + "github.com/spf13/cobra" +) + +func main() { + var ( + port int + debug bool + trace bool + silent bool + json bool + recorders []string + ) + + rootCmd := &cobra.Command{ + Use: "parrot", + Short: "A server that can register and parrot back dynamic requests", + RunE: func(cmd *cobra.Command, args []string) error { + options := []parrot.ServerOption{parrot.WithPort(port)} + logLevel := zerolog.InfoLevel + if debug { + logLevel = zerolog.DebugLevel + } + if trace { + logLevel = zerolog.TraceLevel + } + if silent { + logLevel = zerolog.Disabled + } + options = append(options, parrot.WithLogLevel(logLevel)) + if json { + options = append(options, parrot.WithJSONLogs()) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + p, err := parrot.Wake(options...) + if err != nil { + return err + } + + for _, r := range recorders { + err = p.Record(r) + if err != nil { + return err + } + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + err = p.Shutdown(ctx) + if err != nil { + log.Error().Err(err).Msg("Error putting parrot to sleep") + } + return nil + }, + } + + rootCmd.Flags().IntVarP(&port, "port", "p", 0, "Port to run the parrot on") + rootCmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug output") + rootCmd.Flags().BoolVarP(&trace, "trace", "t", false, "Enable trace and debug output") + rootCmd.Flags().BoolVarP(&silent, "silent", "s", false, "Disable all output") + rootCmd.Flags().BoolVarP(&json, "json", "j", false, "Output logs in JSON format") + rootCmd.Flags().StringSliceVarP(&recorders, "recorders", "r", nil, "Existing recorders to use") + + if err := rootCmd.Execute(); err != nil { + log.Error().Err(err).Msg("error executing command") + os.Exit(1) + } +} diff --git a/parrot/errors.go b/parrot/errors.go new file mode 100644 index 000000000..50032b700 --- /dev/null +++ b/parrot/errors.go @@ -0,0 +1,43 @@ +package parrot + +import ( + "errors" + "fmt" +) + +var ( + ErrNilRoute = errors.New("route is nil") + ErrNoMethod = errors.New("no method specified") + ErrInvalidPath = errors.New("invalid path") + ErrNoResponse = errors.New("route must have a handler or some response") + ErrOnlyOneResponse = errors.New("route can only have one response type") + ErrResponseMarshal = errors.New("unable to marshal response body to JSON") + ErrRouteNotFound = errors.New("route not found") + + ErrNoRecorderURL = errors.New("no recorder URL specified") + ErrInvalidRecorderURL = errors.New("invalid recorder URL") + ErrRecorderNotFound = errors.New("recorder not found") + + ErrServerShutdown = errors.New("parrot is already asleep") +) + +// Custom error type to help add more detail to base errors +type dynamicError struct { + Base error // Base error for comparison + Extra string // Dynamic context (e.g., method name) +} + +func (e *dynamicError) Error() string { + return fmt.Sprintf("%s: %s", e.Base.Error(), e.Extra) +} + +func (e *dynamicError) Unwrap() error { + return e.Base +} + +func newDynamicError(base error, detail string) error { + return &dynamicError{ + Base: base, + Extra: detail, + } +} diff --git a/parrot/examples_test.go b/parrot/examples_test.go new file mode 100644 index 000000000..ab89b5059 --- /dev/null +++ b/parrot/examples_test.go @@ -0,0 +1,315 @@ +package parrot_test + +import ( + "context" + "fmt" + "net/http" + "os" + "time" + + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog" + "github.com/smartcontractkit/chainlink-testing-framework/parrot" +) + +func ExampleServer_Register_internal() { + // Create a new parrot instance with no logging and a custom save file + saveFile := "register_example.json" + p, err := parrot.Wake(parrot.WithLogLevel(zerolog.NoLevel), parrot.WithSaveFile(saveFile)) + if err != nil { + panic(err) + } + defer func() { // Cleanup the parrot instance + err = p.Shutdown(context.Background()) // Gracefully shutdown the parrot instance + if err != nil { + panic(err) + } + p.WaitShutdown() // Wait for the parrot instance to shutdown. Usually unnecessary, but we want to clean up the save file + os.Remove(saveFile) // Cleanup the save file for the example + }() + + // Create a new route /test that will return a 200 status code with a text/plain response body of "Squawk" + route := &parrot.Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + + // Register the route with the parrot instance + err = p.Register(route) + if err != nil { + panic(err) + } + + // Call the route + resp, err := p.Call(http.MethodGet, "/test") + if err != nil { + panic(err) + } + fmt.Println(resp.StatusCode()) + fmt.Println(string(resp.Body())) + + // Get all routes from the parrot instance + routes := p.Routes() + fmt.Println(len(routes)) + + // Delete the route + err = p.Delete(route.ID()) + if err != nil { + panic(err) + } + + // Get all routes from the parrot instance + routes = p.Routes() + fmt.Println(len(routes)) + // Output: + // 200 + // Squawk + // 1 + // 0 +} + +func ExampleServer_Register_external() { + var ( + saveFile = "route_example.json" + port = 9090 + ) + defer os.Remove(saveFile) // Cleanup the save file for the example + + go func() { // Run the parrot server as a separate instance, like in a Docker container + _, err := parrot.Wake(parrot.WithPort(port), parrot.WithLogLevel(zerolog.NoLevel), parrot.WithSaveFile(saveFile)) + if err != nil { + panic(err) + } + }() + + // Code that calls the parrot server from another service + // Use resty to make HTTP calls to the parrot server + client := resty.New() + client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server + + waitForParrotServer(client, time.Second) // Wait for the parrot server to start + + // Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk" + route := &parrot.Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + resp, err := client.R().SetBody(route).Post("/routes") + if err != nil { + panic(err) + } + defer resp.RawResponse.Body.Close() + fmt.Println(resp.StatusCode()) + + // Get all routes from the parrot server + routes := make([]*parrot.Route, 0) + resp, err = client.R().SetResult(&routes).Get("/routes") + if err != nil { + panic(err) + } + defer resp.RawResponse.Body.Close() + fmt.Println(resp.StatusCode()) + fmt.Println(len(routes)) + + // Delete the route + resp, err = client.R().SetBody(route).Delete("/routes") + if err != nil { + panic(err) + } + defer resp.RawResponse.Body.Close() + fmt.Println(resp.StatusCode()) + + // Get all routes from the parrot server + routes = make([]*parrot.Route, 0) + resp, err = client.R().SetResult(&routes).Get("/routes") + if err != nil { + panic(err) + } + defer resp.RawResponse.Body.Close() + fmt.Println(len(routes)) + + // Output: + // 201 + // 200 + // 1 + // 204 + // 0 +} + +func ExampleRecorder_internal() { + saveFile := "recorder_example.json" + p, err := parrot.Wake(parrot.WithLogLevel(zerolog.NoLevel), parrot.WithSaveFile(saveFile)) + if err != nil { + panic(err) + } + defer func() { // Cleanup the parrot instance + err = p.Shutdown(context.Background()) // Gracefully shutdown the parrot instance + if err != nil { + panic(err) + } + p.WaitShutdown() // Wait for the parrot instance to shutdown. Usually unnecessary, but we want to clean up the save file + os.Remove(saveFile) // Cleanup the save file for the example + }() + + // Create a new recorder + recorder, err := parrot.NewRecorder() + if err != nil { + panic(err) + } + + // Register the recorder with the parrot instance + err = p.Record(recorder.URL()) + if err != nil { + panic(err) + } + defer recorder.Close() + + // Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk" + route := &parrot.Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + err = p.Register(route) + if err != nil { + panic(err) + } + + // Call the route + go func() { + _, err := p.Call(http.MethodGet, "/test") + if err != nil { + panic(err) + } + }() + + // Record the route call + for { + select { + case recordedRouteCall := <-recorder.Record(): + if recordedRouteCall.RouteID == route.ID() { + fmt.Println(recordedRouteCall.RouteID) + fmt.Println(recordedRouteCall.Request.Method) + fmt.Println(recordedRouteCall.Response.StatusCode) + fmt.Println(string(recordedRouteCall.Response.Body)) + return + } + case err := <-recorder.Err(): + panic(err) + } + } + // Output: + // GET:/test + // GET + // 200 + // Squawk +} + +// Example of how to use parrot recording when calling it from an external service +func ExampleRecorder_external() { + var ( + saveFile = "recorder_example.json" + port = 9091 + ) + defer os.Remove(saveFile) // Cleanup the save file for the example + + go func() { // Run the parrot server as a separate instance, like in a Docker container + _, err := parrot.Wake(parrot.WithPort(port), parrot.WithLogLevel(zerolog.NoLevel), parrot.WithSaveFile(saveFile)) + if err != nil { + panic(err) + } + }() + + client := resty.New() + client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server + + waitForParrotServer(client, time.Second) // Wait for the parrot server to start + + // Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk" + route := &parrot.Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + + // Register the route with the parrot instance + resp, err := client.R().SetBody(route).Post("/routes") + if err != nil { + panic(err) + } + + // Use the host of the machine your recorder is running on + // This should not be localhost if you are running the parrot server on a different machine + // It should be the public IP address of the machine running your code, so that the parrot can call back to it + host := "localhost" + + // Create a new recorder with our host + recorder, err := parrot.NewRecorder(parrot.WithHost(host)) + if err != nil { + panic(err) + } + + // Register the recorder with the parrot instance + resp, err = client.R().SetBody(recorder).Post("/record") + if err != nil { + panic(err) + } + if resp.StatusCode() != http.StatusCreated { + panic(fmt.Sprintf("failed to register recorder, got %d status code", resp.StatusCode())) + } + + go func() { // Some other service calls the /test route + _, err := client.R().Get("/test") + if err != nil { + panic(err) + } + }() + + // You can now listen to the recorder for all route calls + for { + select { + case recordedRouteCall := <-recorder.Record(): + if recordedRouteCall.RouteID == route.ID() { + fmt.Println(recordedRouteCall.RouteID) + fmt.Println(recordedRouteCall.Request.Method) + fmt.Println(recordedRouteCall.Response.StatusCode) + fmt.Println(string(recordedRouteCall.Response.Body)) + return + } + case err := <-recorder.Err(): + panic(err) + } + } + // Output: + // GET:/test + // GET + // 200 + // Squawk +} + +// waitForParrotServer checks the parrot server health endpoint until it returns a 200 status code or the timeout is reached +func waitForParrotServer(client *resty.Client, timeoutDur time.Duration) { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + timeout := time.NewTimer(timeoutDur) + for { // Wait for the parrot server to start + select { + case <-ticker.C: + resp, err := client.R().Get("/health") + if err != nil { + continue + } + if resp.StatusCode() == http.StatusOK { + return + } + case <-timeout.C: + panic("timeout waiting for parrot server to start") + } + } +} diff --git a/parrot/go.mod b/parrot/go.mod new file mode 100644 index 000000000..18fc13a2e --- /dev/null +++ b/parrot/go.mod @@ -0,0 +1,24 @@ +module github.com/smartcontractkit/chainlink-testing-framework/parrot + +go 1.23.4 + +require ( + github.com/go-resty/resty/v2 v2.16.3 + github.com/google/uuid v1.6.0 + github.com/rs/zerolog v1.33.0 + github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/xid v1.5.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sys v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/parrot/go.sum b/parrot/go.sum new file mode 100644 index 000000000..da5bd8c5f --- /dev/null +++ b/parrot/go.sum @@ -0,0 +1,43 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-resty/resty/v2 v2.16.3 h1:zacNT7lt4b8M/io2Ahj6yPypL7bqx9n1iprfQuodV+E= +github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/parrot/parrot.go b/parrot/parrot.go new file mode 100644 index 000000000..7ba4e30ac --- /dev/null +++ b/parrot/parrot.go @@ -0,0 +1,756 @@ +package parrot + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-resty/resty/v2" + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" +) + +const ( + healthRoute = "/health" + routesRoute = "/routes" + recordRoute = "/record" +) + +// Route holds information about the mock route configuration +type Route struct { + // Method is the HTTP method to match + Method string `json:"Method"` + // Path is the URL path to match + Path string `json:"Path"` + // Handler is the dynamic handler function to use when called + // Can only be set upon creation of the server + Handler http.HandlerFunc `json:"-"` + // RawResponseBody is the static, raw string response to return when called + RawResponseBody string `json:"raw_response_body"` + // ResponseBody will be marshalled to JSON and returned when called + ResponseBody any `json:"response_body"` + // ResponseStatusCode is the HTTP status code to return when called + ResponseStatusCode int `json:"response_status_code"` +} + +// ID returns the unique identifier for the route +func (r *Route) ID() string { + return r.Method + ":" + r.Path +} + +// Server is a mock HTTP server that can register and respond to dynamic routes +type Server struct { + port int + host string + address string + + client *resty.Client + shutDown bool + shutDownChan chan struct{} + shutDownOnce sync.Once + saveFileName string + useCustomLogger bool + logFileName string + logFile *os.File + logLevel zerolog.Level + jsonLogs bool + log zerolog.Logger + + server *http.Server + routes map[string]*Route // Store routes based on "Method:Path" keys + routesMu sync.RWMutex + + recorderHooks map[string]struct{} // Store recorders based on URL keys to avoid duplicates + recordersMu sync.RWMutex +} + +// SaveFile is the structure of the file to save and load parrot data from +type SaveFile struct { + Routes []*Route `json:"routes"` + Recorders []string `json:"recorders"` +} + +// ServerOption defines functional options for configuring the ParrotServer +type ServerOption func(*Server) error + +// WithPort sets the port for the ParrotServer to run on +func WithPort(port int) ServerOption { + return func(s *Server) error { + if port < 0 || port > 65535 { + return fmt.Errorf("invalid port: %d", port) + } + s.port = port + return nil + } +} + +// WithLogLevel sets the visible log level of the default logger +func WithLogLevel(level zerolog.Level) ServerOption { + return func(s *Server) error { + s.logLevel = level + return nil + } +} + +// WithLogger sets the logger for the ParrotServer +func WithLogger(l zerolog.Logger) ServerOption { + return func(s *Server) error { + s.log = l + s.useCustomLogger = true + return nil + } +} + +// WithJSONLogs sets the logger to output JSON logs +func WithJSONLogs() ServerOption { + return func(s *Server) error { + s.jsonLogs = true + return nil + } +} + +// WithSaveFile sets the file to save the routes to +func WithSaveFile(saveFile string) ServerOption { + return func(s *Server) error { + if saveFile == "" { + return fmt.Errorf("invalid save file name: %s", saveFile) + } + s.saveFileName = saveFile + return nil + } +} + +// WithLogFile sets the file to save the logs to +func WithLogFile(logFile string) ServerOption { + return func(s *Server) error { + if logFile == "" { + return fmt.Errorf("invalid log file name: %s", logFile) + } + s.logFileName = logFile + return nil + } +} + +// WithRoutes sets the initial routes for the Parrot +func WithRoutes(routes []*Route) ServerOption { + return func(s *Server) error { + for _, route := range routes { + if err := s.Register(route); err != nil { + return fmt.Errorf("failed to register route: %w", err) + } + } + return nil + } +} + +// Wake creates a new Parrot server with dynamic route handling +func Wake(options ...ServerOption) (*Server, error) { + p := &Server{ + port: 0, + saveFileName: "parrot_save.json", + logLevel: zerolog.InfoLevel, + logFileName: "parrot.log", + + client: resty.New(), + shutDownChan: make(chan struct{}), + + routes: make(map[string]*Route), + routesMu: sync.RWMutex{}, + + recorderHooks: make(map[string]struct{}), + recordersMu: sync.RWMutex{}, + } + + for _, option := range options { + if err := option(p); err != nil { + return nil, err + } + } + + var err error + p.logFile, err = os.Create(p.logFileName) + if err != nil { + return nil, fmt.Errorf("failed to create log file: %w", err) + } + + if !p.useCustomLogger { // Build default logger + var writers []io.Writer + + if p.jsonLogs { + writers = append(writers, os.Stderr) + } else { + consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"} + writers = append(writers, consoleOut) + } + + if p.logFile != nil { + writers = append(writers, p.logFile) + } + + multiWriter := zerolog.MultiLevelWriter(writers...) + p.log = zerolog.New(multiWriter).Level(p.logLevel).With().Timestamp().Logger() + } + + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port)) + if err != nil { + return nil, fmt.Errorf("failed to start listener: %w", err) + } + host, port, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + return nil, fmt.Errorf("failed to split host and port: %w", err) + } + p.host = host + p.address = listener.Addr().String() + p.port, err = strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("failed to parse port: %w", err) + } + + mux := http.NewServeMux() + // TODO: Add a route to enable registering recorders + mux.HandleFunc(routesRoute, p.routeHandler) + mux.HandleFunc(recordRoute, p.recordHandler) + mux.HandleFunc(healthRoute, p.healthHandler) + mux.HandleFunc("/", p.dynamicHandler) + + p.server = &http.Server{ + ReadHeaderTimeout: 5 * time.Second, + Addr: listener.Addr().String(), + Handler: p.loggingMiddleware(mux), + } + + if err = p.load(); err != nil { + return nil, fmt.Errorf("failed to load data from '%s': %w", p.saveFileName, err) + } + + go p.run(listener) + + return p, nil +} + +// run starts the parrot server +func (p *Server) run(listener net.Listener) { + defer func() { + p.shutDown = true + if err := p.save(); err != nil { + p.log.Error().Err(err).Msg("Failed to save routes") + } + if err := p.logFile.Close(); err != nil { + p.log.Error().Err(err).Msg("Failed to close log file") + } + p.shutDownOnce.Do(func() { + close(p.shutDownChan) + }) + }() + + p.log.Info().Str("Address", p.address).Msg("Parrot awake and ready to squawk") + p.log.Debug().Str("Save File", p.saveFileName).Str("Log File", p.logFileName).Msg("Configuration") + if err := p.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + p.log.Fatal().Err(err).Msg("Error while running server") + } +} + +// Shutdown gracefully shuts down the parrot server +func (p *Server) Shutdown(ctx context.Context) error { + if p.shutDown { + return ErrServerShutdown + } + + p.log.Info().Msg("Putting cloth over the parrot's cage...") + return p.server.Shutdown(ctx) +} + +// WaitShutdown blocks until the parrot server has shut down +func (p *Server) WaitShutdown() { + <-p.shutDownChan +} + +// Address returns the address the parrot is running on +func (p *Server) Address() string { + return p.address +} + +// Register adds a new route to the parrot +func (p *Server) Register(route *Route) error { + if p.shutDown { + return ErrServerShutdown + } + if route == nil { + return ErrNilRoute + } + if !isValidPath(route.Path) { + return newDynamicError(ErrInvalidPath, fmt.Sprintf("'%s'", route.Path)) + } + if route.Method == "" { + return ErrNoMethod + } + if route.Handler == nil && route.ResponseBody == nil && route.RawResponseBody == "" { + return ErrNoResponse + } + if route.Handler != nil && (route.ResponseBody != nil || route.RawResponseBody != "") { + return newDynamicError(ErrOnlyOneResponse, "handler and another response type provided") + } + if route.ResponseBody != nil && route.RawResponseBody != "" { + return ErrOnlyOneResponse + } + if route.ResponseBody != nil { + if _, err := json.Marshal(route.ResponseBody); err != nil { + return newDynamicError(ErrResponseMarshal, err.Error()) + } + } + + p.routesMu.Lock() + defer p.routesMu.Unlock() + p.routes[route.ID()] = route + p.log.Info(). + Str("Route ID", route.ID()). + Str("Path", route.Path). + Str("Method", route.Method). + Msg("Route registered") + + return nil +} + +// Record registers a new recorder with the parrot. All incoming requests to the parrot will be sent to the recorder. +func (p *Server) Record(recorderURL string) error { + if p.shutDown { + return ErrServerShutdown + } + + p.recordersMu.Lock() + defer p.recordersMu.Unlock() + if recorderURL == "" { + return ErrNoRecorderURL + } + _, err := url.ParseRequestURI(recorderURL) + if err != nil { + return ErrInvalidRecorderURL + } + p.recorderHooks[recorderURL] = struct{}{} + return nil +} + +// Recorders returns the URLs of all registered recorders +func (p *Server) Recorders() []string { + if p.shutDown { + return nil + } + + p.recordersMu.RLock() + defer p.recordersMu.RUnlock() + recorders := make([]string, 0, len(p.recorderHooks)) + for recorder := range p.recorderHooks { + recorders = append(recorders, recorder) + } + return recorders +} + +// Delete removes a route from the parrot +func (p *Server) Delete(routeID string) error { + if p.shutDown { + return ErrServerShutdown + } + + p.routesMu.RLock() + _, exists := p.routes[routeID] + p.routesMu.RUnlock() + + if !exists { + return newDynamicError(ErrRouteNotFound, routeID) + } + p.routesMu.Lock() + defer p.routesMu.Unlock() + delete(p.routes, routeID) + return nil +} + +// Call makes a request to the parrot server +func (p *Server) Call(method, path string) (*resty.Response, error) { + if p.shutDown { + return nil, ErrServerShutdown + } + return p.client.R().Execute(method, "http://"+filepath.Join(p.Address(), path)) +} + +func (p *Server) Routes() []*Route { + p.routesMu.RLock() + defer p.routesMu.RUnlock() + + routes := make([]*Route, 0, len(p.routes)) + for _, route := range p.routes { + routes = append(routes, route) + } + return routes +} + +// routeHandler handles registering, unregistering, and querying routes +func (p *Server) routeHandler(w http.ResponseWriter, r *http.Request) { + routesLogger := zerolog.Ctx(r.Context()) + if r.Method == http.MethodDelete { + var route *Route + if err := json.NewDecoder(r.Body).Decode(&route); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + routesLogger.Debug().Err(err).Msg("Failed to decode request body") + return + } + defer r.Body.Close() + + err := p.Delete(route.ID()) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + routesLogger.Debug().Err(err).Msg("Failed to unregister route") + return + } + + w.WriteHeader(http.StatusNoContent) + routesLogger.Info(). + Str("Route ID", route.ID()). + Msg("Route deleted") + return + } + + if r.Method == http.MethodPost { + var route *Route + if err := json.NewDecoder(r.Body).Decode(&route); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + routesLogger.Debug().Err(err).Msg("Failed to decode request body") + return + } + defer r.Body.Close() + + err := p.Register(route) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + routesLogger.Debug().Err(err).Msg("Failed to register route") + return + } + + w.WriteHeader(http.StatusCreated) + return + } + + if r.Method == http.MethodGet { + routes := p.Routes() + jsonRoutes, err := json.Marshal(routes) + if err != nil { + http.Error(w, "Failed to marshal routes", http.StatusInternalServerError) + routesLogger.Debug().Err(err).Msg("Failed to marshal routes") + return + } + + w.Header().Set("Content-Type", "application/json") + if _, err = w.Write(jsonRoutes); err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + routesLogger.Debug().Err(err).Msg("Failed to write response") + return + } + + routesLogger.Debug().Int("Count", len(routes)).Msg("Returned routes") + return + } + + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + routesLogger.Debug().Msg("Invalid method") +} + +// dynamicHandler handles all incoming requests and responds based on the registered routes. +func (p *Server) dynamicHandler(w http.ResponseWriter, r *http.Request) { + p.routesMu.RLock() + route, exists := p.routes[r.Method+":"+r.URL.Path] + p.routesMu.RUnlock() + + dynamicLogger := zerolog.Ctx(r.Context()) + if !exists { + http.NotFound(w, r) + dynamicLogger.Debug().Msg("Route not found") + return + } + + routeCallID := uuid.New().String()[0:8] + dynamicLogger.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("Route Call ID", routeCallID).Str("Route ID", route.ID()) + }) + + requestBody, err := io.ReadAll(r.Body) + if err != nil { + dynamicLogger.Debug(). + Err(err). + Msg("Failed to read request body") + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + + routeCall := &RouteCall{ + RouteCallID: routeCallID, + RouteID: r.Method + ":" + r.URL.Path, + Request: &RouteCallRequest{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Body: requestBody, + RemoteAddr: r.RemoteAddr, + }, + } + recordingWriter := newResponseWriterRecorder(w) + + defer func() { + res := recordingWriter.Result() + resBody, err := io.ReadAll(res.Body) + if err != nil { + dynamicLogger.Debug().Err(err).Msg("Failed to read response body") + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + return + } + + routeCall.Response = &RouteCallResponse{ + StatusCode: res.StatusCode, + Header: res.Header, + Body: resBody, + } + p.sendToRecorders(routeCall) + }() + + // Let the custom handler take over if it exists + if route.Handler != nil { + dynamicLogger.Debug().Msg("Calling route handler") + route.Handler(recordingWriter, r) + return + } + + recordingWriter.WriteHeader(route.ResponseStatusCode) + + if route.RawResponseBody != "" { + if _, err := recordingWriter.Write([]byte(route.RawResponseBody)); err != nil { + dynamicLogger.Debug().Err(err).Msg("Failed to write response") + http.Error(recordingWriter, "Failed to write response", http.StatusInternalServerError) + return + } + dynamicLogger.Debug(). + Str("Response", route.RawResponseBody). + Msg("Returned raw response") + recordingWriter.WriteHeader(route.ResponseStatusCode) + return + } + + if route.ResponseBody != nil { + rawJSON, err := json.Marshal(route.ResponseBody) + if err != nil { + dynamicLogger.Debug().Err(err).Msg("Failed to marshal JSON response") + http.Error(recordingWriter, "Failed to marshal response into json", http.StatusInternalServerError) + return + } + if _, err = recordingWriter.Write(rawJSON); err != nil { + dynamicLogger.Debug().Err(err). + RawJSON("Response", rawJSON). + Msg("Failed to write response") + http.Error(recordingWriter, "Failed to write JSON response", http.StatusInternalServerError) + return + } + dynamicLogger.Debug(). + RawJSON("Response", rawJSON). + Msg("Returned JSON response") + recordingWriter.WriteHeader(route.ResponseStatusCode) + return + } + + dynamicLogger.Error().Msg("Route has no response") + http.Error(recordingWriter, "Route has no response", http.StatusInternalServerError) +} + +// recordHandler handles registering recorders with the parrot +func (p *Server) recordHandler(w http.ResponseWriter, r *http.Request) { + recordingLogger := zerolog.Ctx(r.Context()) + if r.Method != http.MethodPost { + http.Error(w, "Invalid method", http.StatusMethodNotAllowed) + recordingLogger.Debug().Msg("Invalid method") + return + } + + var recorder *Recorder + if err := json.NewDecoder(r.Body).Decode(&recorder); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + recordingLogger.Debug().Err(err).Msg("Failed to decode request body") + return + } + defer r.Body.Close() + + if recorder == nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + recordingLogger.Debug().Msg("No recorder provided") + return + } + + if recorder.URL() == "" { + http.Error(w, "Recorder URL required", http.StatusBadRequest) + recordingLogger.Debug().Msg("No recorder URL provided") + return + } + + if err := p.Record(recorder.URL()); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + recordingLogger.Debug().Err(err).Msg("Failed to register recorder") + return + } + + w.WriteHeader(http.StatusCreated) + recordingLogger.Debug().Str("URL", recorder.URL()).Msg("Recorder added") +} + +func (p *Server) healthHandler(w http.ResponseWriter, _ *http.Request) { + if p.shutDown { + http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) + return + } + + w.WriteHeader(http.StatusOK) +} + +// load loads all registered routes from a file. +func (p *Server) load() error { + if _, err := os.Stat(p.saveFileName); os.IsNotExist(err) { + p.log.Trace().Str("file", p.saveFileName).Msg("No data to load") + return nil + } + + p.log.Debug().Str("File", p.saveFileName).Msg("Loading data") + + fileData, err := os.ReadFile(p.saveFileName) + if err != nil { + return fmt.Errorf("failed to read routes from file: %w", err) + } + if len(fileData) == 0 { + p.log.Trace().Str("File", p.saveFileName).Msg("No data to load") + return nil + } + + var saveData SaveFile + err = json.Unmarshal(fileData, &saveData) + if err != nil { + return fmt.Errorf("failed to unmarshal save file: %w", err) + } + + for _, route := range saveData.Routes { + if err = p.Register(route); err != nil { + return fmt.Errorf("failed to register route: %w", err) + } + } + + for _, recorder := range saveData.Recorders { + if err = p.Record(recorder); err != nil { + return fmt.Errorf("failed to register recorder: %w", err) + } + } + + p.log.Info().Str("file", p.saveFileName).Int("number", len(p.routes)).Msg("Loaded routes") + return nil +} + +// save saves all registered routes to a file. +func (p *Server) save() error { + saveFile := &SaveFile{ + Routes: p.Routes(), + Recorders: p.Recorders(), + } + if len(saveFile.Routes) == 0 && len(saveFile.Recorders) == 0 { + p.log.Trace().Str("File", p.saveFileName).Msg("No data to save") + return nil + } + + jsonData, err := json.Marshal(saveFile) + if err != nil { + return fmt.Errorf("failed to marshal save file: %w", err) + } + + if err = os.WriteFile(p.saveFileName, jsonData, 0644); err != nil { //nolint:gosec + return fmt.Errorf("failed to write to save file: %w", err) + } + + p.log.Debug().Str("File", p.saveFileName).Msg("Saved data") + return nil +} + +// sendToRecorders sends the route call to all registered recorders +func (p *Server) sendToRecorders(routeCall *RouteCall) { + p.recordersMu.RLock() + defer p.recordersMu.RUnlock() + if len(p.recorderHooks) == 0 { + return + } + + client := resty.New() + p.log.Trace().Int("Recorder Count", len(p.recorderHooks)).Str("Route ID", routeCall.RouteID).Msg("Sending route call to recorders") + + for hook := range p.recorderHooks { + go func(hook string) { + resp, err := client.R().SetBody(routeCall).Post(hook) + if err != nil { + p.log.Error().Err(err).Str("Recorder Hook", hook).Msg("Failed to send route call to recorder") + return + } + defer resp.RawResponse.Body.Close() + if resp.IsError() { + p.log.Error(). + Str("Recorder Hook", hook). + Int("Code", resp.StatusCode()). + Str("Response", resp.String()). + Msg("Failed to send route call to recorder") + return + } + p.log.Trace().Str("Route ID", routeCall.RouteID).Str("Recorder Hook", hook).Msg("Route call sent to recorder") + }(hook) + } +} + +func (p *Server) loggingMiddleware(next http.Handler) http.Handler { + h := hlog.NewHandler(p.log) + + accessHandler := hlog.AccessHandler( + func(r *http.Request, status, size int, duration time.Duration) { + hlog.FromRequest(r).Trace(). + Str("Method", r.Method). + Stringer("URL", r.URL). + Int("Status Code", status). + Int("Response Size Bytes", size). + Str("Duration", duration.String()). + Str("Remote Addr", r.RemoteAddr). + Msg("Handled request") + }, + ) + + return h(accessHandler(next)) +} + +var pathRegex = regexp.MustCompile(`^\/[a-zA-Z0-9\-._~%!$&'()*+,;=:@\/]*$`) + +func isValidPath(path string) bool { + switch path { + case "", "/", "//", healthRoute, recordRoute, routesRoute, "/.", "/..": + return false + } + if !strings.HasPrefix(path, "/") { + return false + } + if strings.HasPrefix(path, recordRoute) { + return false + } + if strings.HasPrefix(path, healthRoute) { + return false + } + if strings.HasPrefix(path, routesRoute) { + return false + } + return pathRegex.MatchString(path) +} diff --git a/parrot/parrot_benchmark_test.go b/parrot/parrot_benchmark_test.go new file mode 100644 index 000000000..68d941b13 --- /dev/null +++ b/parrot/parrot_benchmark_test.go @@ -0,0 +1,145 @@ +package parrot + +import ( + "context" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func BenchmarkRegisterRoute(b *testing.B) { + saveFile := b.Name() + ".json" + p, err := Wake(WithLogLevel(testLogLevel), WithSaveFile(saveFile)) + require.NoError(b, err) + + defer func() { // Cleanup + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + err := p.Shutdown(ctx) + cancel() + require.NoError(b, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(saveFile) + }() + + route := &Route{ + Method: "GET", + Path: "/bench", + RawResponseBody: "Benchmark Response", + ResponseStatusCode: http.StatusOK, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := p.Register(route) + require.NoError(b, err) + } + b.StopTimer() +} + +func BenchmarkRouteResponse(b *testing.B) { + saveFile := b.Name() + ".json" + p, err := Wake(WithLogLevel(testLogLevel), WithSaveFile(saveFile)) + require.NoError(b, err) + + defer func() { // Cleanup + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + err := p.Shutdown(ctx) + cancel() + require.NoError(b, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(saveFile) + }() + + route := &Route{ + Method: "GET", + Path: "/bench", + RawResponseBody: "Benchmark Response", + ResponseStatusCode: http.StatusOK, + } + err = p.Register(route) + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := p.Call(route.Method, route.Path) + require.NoError(b, err) + } + b.StopTimer() +} + +func BenchmarkSave(b *testing.B) { + var ( + routes = []*Route{} + saveFile = "bench_save_routes.json" + ) + + for i := 0; i < 1000; i++ { + routes = append(routes, &Route{ + Method: "GET", + Path: fmt.Sprintf("/bench%d", i), + RawResponseBody: fmt.Sprintf("Squawk %d", i), + ResponseStatusCode: http.StatusOK, + }) + } + p, err := Wake(WithRoutes(routes), WithLogLevel(testLogLevel), WithSaveFile(saveFile)) + require.NoError(b, err) + defer func() { // Cleanup + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + err = p.Shutdown(ctx) + cancel() + require.NoError(b, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(saveFile) + }() + + b.ResetTimer() // Start measuring time + for i := 0; i < b.N; i++ { + err := p.save() + require.NoError(b, err) + } + b.StopTimer() +} + +func BenchmarkLoad(b *testing.B) { + var ( + routes = []*Route{} + saveFile = "bench_load_routes.json" + ) + b.Cleanup(func() { + os.Remove(saveFile) + }) + + for i := 0; i < 1000; i++ { + routes = append(routes, &Route{ + Method: "GET", + Path: fmt.Sprintf("/bench%d", i), + RawResponseBody: fmt.Sprintf("Squawk %d", i), + ResponseStatusCode: http.StatusOK, + }) + } + p, err := Wake(WithRoutes(routes), WithLogLevel(zerolog.Disabled), WithSaveFile(saveFile)) + require.NoError(b, err, "error waking parrot") + defer func() { // Cleanup + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + err = p.Shutdown(ctx) + cancel() + require.NoError(b, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(saveFile) + }() + + err = p.save() + require.NoError(b, err, "error saving routes") + + b.ResetTimer() // Start measuring time + for i := 0; i < b.N; i++ { + err := p.load() + require.NoError(b, err) + } + b.StopTimer() +} diff --git a/parrot/parrot_test.go b/parrot/parrot_test.go new file mode 100644 index 000000000..8dd19bde9 --- /dev/null +++ b/parrot/parrot_test.go @@ -0,0 +1,540 @@ +package parrot + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testLogLevel = zerolog.NoLevel + +func TestMain(m *testing.M) { + testLogLevelFlag := "" + flag.StringVar(&testLogLevelFlag, "testLogLevel", "", "a zerolog log level to use for tests") + flag.Parse() + var err error + testLogLevel, err = zerolog.ParseLevel(testLogLevelFlag) + if err != nil { + fmt.Println("error parsing test log level:", err) + os.Exit(1) + } + + os.Exit(m.Run()) +} + +func TestRegisterRoutes(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + testCases := []struct { + name string + route *Route + }{ + { + name: "get route", + route: &Route{ + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "json route", + route: &Route{ + Method: http.MethodGet, + Path: "/json", + ResponseBody: map[string]any{"message": "Squawk"}, + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "post route", + route: &Route{ + Method: http.MethodPost, + Path: "/post", + RawResponseBody: "Squawk", + ResponseStatusCode: 201, + }, + }, + { + name: "put route", + route: &Route{ + Method: http.MethodPut, + Path: "/put", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "delete route", + route: &Route{ + Method: http.MethodDelete, + Path: "/delete", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "patch route", + route: &Route{ + Method: http.MethodPatch, + Path: "/patch", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "error route", + route: &Route{ + Method: http.MethodGet, + Path: "/error", + RawResponseBody: "Squawk", + ResponseStatusCode: 500, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := p.Register(tc.route) + require.NoError(t, err, "error registering route") + + resp, err := p.Call(tc.route.Method, tc.route.Path) + require.NoError(t, err, "error calling parrot") + + assert.Equal(t, tc.route.ResponseStatusCode, resp.StatusCode()) + if tc.route.ResponseBody != nil { + jsonBody, err := json.Marshal(tc.route.ResponseBody) + require.NoError(t, err) + assert.JSONEq(t, string(jsonBody), string(resp.Body())) + } else { + assert.Equal(t, tc.route.RawResponseBody, string(resp.Body())) + } + }) + } +} + +func TestGetRoutes(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + routes := []*Route{ + { + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + { + Method: http.MethodPost, + Path: "/goodbye", + RawResponseBody: "Squeak", + ResponseStatusCode: 201, + }, + } + + for _, route := range routes { + err := p.Register(route) + require.NoError(t, err, "error registering route") + } + + registeredRoutes := p.Routes() + require.Len(t, registeredRoutes, len(routes)) +} + +func TestIsValidPath(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + paths []string + valid bool + }{ + { + name: "valid paths", + paths: []string{"/hello"}, + valid: true, + }, + { + name: "no protected paths", + paths: []string{healthRoute, routesRoute, recordRoute, fmt.Sprintf("%s/%s", routesRoute, "route-id"), fmt.Sprintf("%s/%s", healthRoute, "recorder-id"), fmt.Sprintf("%s/%s", recordRoute, "recorder-id")}, + valid: false, + }, + { + name: "invalid paths", + paths: []string{"", "/", " ", " /", "/ ", " / ", "invalid path"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + for _, path := range tc.paths { + valid := isValidPath(path) + assert.Equal(t, tc.valid, valid) + } + }) + } +} + +func TestPreRegisterRoutes(t *testing.T) { + t.Parallel() + + routes := []*Route{ + { + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + { + Method: http.MethodPost, + Path: "/goodbye", + RawResponseBody: "Squeak", + ResponseStatusCode: 201, + }, + } + + saveFile := t.Name() + ".json" + p, err := Wake(WithSaveFile(saveFile), WithRoutes(routes), WithLogLevel(testLogLevel)) + require.NoError(t, err, "error waking parrot") + + t.Cleanup(func() { + err := p.Shutdown(context.Background()) + assert.NoError(t, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(saveFile) + }) + + registeredRoutes := p.Routes() + require.Len(t, registeredRoutes, len(routes)) +} + +func TestCustomLogFile(t *testing.T) { + t.Parallel() + + logFile := t.Name() + ".log" + saveFile := t.Name() + ".json" + p, err := Wake(WithLogFile(logFile), WithSaveFile(saveFile), WithLogLevel(zerolog.InfoLevel)) + require.NoError(t, err, "error waking parrot") + + t.Cleanup(func() { + err := p.Shutdown(context.Background()) + assert.NoError(t, err, "error shutting down parrot") + p.WaitShutdown() + os.Remove(logFile) + os.Remove(saveFile) + }) + + // Call a route to generate some logs + route := &Route{ + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + err = p.Register(route) + require.NoError(t, err, "error registering route") + + _, err = p.Call(route.Method, route.Path) + require.NoError(t, err, "error calling parrot") + + require.FileExists(t, logFile, "expected log file to exist") + logData, err := os.ReadFile(logFile) + require.NoError(t, err, "error reading log file") + require.Contains(t, string(logData), "GET:/hello", "expected log file to contain route call") +} + +func TestBadRegisterRoute(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + testCases := []struct { + name string + err error + route *Route + }{ + { + name: "nil route", + err: ErrNilRoute, + route: nil, + }, + { + name: "no method", + err: ErrNoMethod, + route: &Route{ + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "no path", + err: ErrInvalidPath, + route: &Route{ + Method: http.MethodGet, + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "base path", + err: ErrInvalidPath, + route: &Route{ + Method: http.MethodGet, + Path: "/", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "invalid path", + err: ErrInvalidPath, + route: &Route{ + Method: http.MethodGet, + Path: "invalid path", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "no response", + err: ErrNoResponse, + route: &Route{ + Method: http.MethodGet, + Path: "/hello", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "invalid url", + err: ErrInvalidPath, + route: &Route{ + Method: http.MethodGet, + Path: "http://example.com", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "multiple responses", + err: ErrOnlyOneResponse, + route: &Route{ + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseBody: map[string]any{"message": "Squawk"}, + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "too many responses", + err: ErrOnlyOneResponse, + route: &Route{ + Method: http.MethodGet, + Path: "/hello", + ResponseBody: map[string]any{"message": "Squawk"}, + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Squawk")) + }, + ResponseStatusCode: http.StatusOK, + }, + }, + { + name: "bad JSON", + err: ErrResponseMarshal, + route: &Route{ + Method: http.MethodGet, + Path: "/json", + ResponseBody: map[string]any{"message": make(chan int)}, + ResponseStatusCode: http.StatusOK, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := p.Register(tc.route) + require.Error(t, err, "expected error registering route") + assert.ErrorIs(t, err, tc.err) + }) + } +} + +func TestBadRecorder(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + err := p.Record("") + require.ErrorIs(t, err, ErrNoRecorderURL, "expected error recording parrot") + + err = p.Record("invalid url") + require.ErrorIs(t, err, ErrInvalidRecorderURL, "expected error recording parrot") +} + +func TestUnregisteredRoute(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + resp, err := p.Call(http.MethodGet, "/unregistered") + require.NoError(t, err, "error calling parrot") + require.NotNil(t, resp, "response should not be nil") + + assert.Equal(t, http.StatusNotFound, resp.StatusCode()) +} + +func TestDelete(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + route := &Route{ + Method: http.MethodPost, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + + err := p.Register(route) + require.NoError(t, err, "error registering route") + + resp, err := p.Call(route.Method, route.Path) + require.NoError(t, err, "error calling parrot") + + assert.Equal(t, resp.StatusCode(), route.ResponseStatusCode) + assert.Equal(t, route.RawResponseBody, string(resp.Body())) + + err = p.Delete(route.ID()) + require.NoError(t, err, "error unregistering route") + + resp, err = p.Call(route.Method, route.Path) + require.NoError(t, err, "error calling parrot") + assert.Equal(t, http.StatusNotFound, resp.StatusCode()) + + // Try to delete the route again + err = p.Delete(route.ID()) + require.ErrorIs(t, err, ErrRouteNotFound, "expected error deleting route") +} + +func TestSaveLoad(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + routes := []*Route{ + { + Method: "GET", + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }, + { + Method: "Post", + Path: "/goodbye", + RawResponseBody: "Squeak", + ResponseStatusCode: 201, + }, + } + + recorders := []string{ // Dummy recorder URLs + "http://localhost:8080", + "http://localhost:8081", + } + + for _, route := range routes { + err := p.Register(route) + require.NoError(t, err, "error registering route") + } + + for _, recorder := range recorders { + err := p.Record(recorder) + require.NoError(t, err, "error recording parrot") + } + + err := p.save() + require.NoError(t, err) + + require.FileExists(t, t.Name()+".json") + err = p.load() + require.NoError(t, err) + + for _, route := range routes { + resp, err := p.Call(route.Method, route.Path) + require.NoError(t, err, "error calling parrot") + + assert.Equal(t, route.ResponseStatusCode, resp.StatusCode(), "unexpected status code for route %s", route.ID()) + assert.Equal(t, route.RawResponseBody, string(resp.Body())) + } + + registeredRecorders := p.Recorders() + require.Len(t, registeredRecorders, len(recorders), "unexpected number of recorders") +} + +func TestShutDown(t *testing.T) { + fileName := t.Name() + ".json" + p, err := Wake(WithSaveFile(fileName), WithLogLevel(testLogLevel)) + require.NoError(t, err, "error waking parrot") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err = p.Shutdown(ctx) + require.NoError(t, err, "error shutting down parrot") + + p.WaitShutdown() // Wait for shutdown to complete + + _, err = p.Call(http.MethodGet, "/hello") + require.ErrorIs(t, err, ErrServerShutdown, "expected error calling parrot after shutdown") + + err = p.Record("http://localhost:8080") + require.ErrorIs(t, err, ErrServerShutdown, "expected error recording parrot after shutdown") + + err = p.Register(&Route{ + Method: http.MethodGet, + Path: "/hello", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + }) + require.ErrorIs(t, err, ErrServerShutdown, "expected error registering route after shutdown") + + err = p.Delete("route-id") + require.ErrorIs(t, err, ErrServerShutdown, "expected error deleting route after shutdown") + + err = p.Shutdown(context.Background()) + require.ErrorIs(t, err, ErrServerShutdown, "expected error shutting down parrot after shutdown") +} + +func newParrot(t *testing.T) *Server { + t.Helper() + + fileName := t.Name() + ".json" + p, err := Wake(WithSaveFile(fileName), WithLogLevel(testLogLevel)) + require.NoError(t, err, "error waking parrot") + t.Cleanup(func() { + err := p.Shutdown(context.Background()) + assert.NoError(t, err, "error shutting down parrot") + p.WaitShutdown() // Wait for shutdown to complete + os.Remove(fileName) + }) + return p +} diff --git a/parrot/recorder.go b/parrot/recorder.go new file mode 100644 index 000000000..8a2bb28b3 --- /dev/null +++ b/parrot/recorder.go @@ -0,0 +1,175 @@ +package parrot + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "time" +) + +// Recorder records route calls +type Recorder struct { + Host string `json:"host"` + Port string `json:"port"` + server *http.Server + + recordChan chan *RouteCall + errChan chan error +} + +// RouteCall records when a route is called, the request and response +type RouteCall struct { + // RouteCallID is a unique identifier for the route call for help with debugging + RouteCallID string `json:"route_call_id"` + // RouteID is the identifier of the route that was called + RouteID string `json:"route_id"` + // Request is the request made to the route + Request *RouteCallRequest `json:"request"` + // Response is the response from the route + Response *RouteCallResponse `json:"response"` +} + +// RouteCallRequest records the request made to a route +type RouteCallRequest struct { + Method string `json:"method"` + URL *url.URL `json:"url"` + RemoteAddr string `json:"caller"` + Header http.Header `json:"header"` + Body []byte `json:"body"` +} + +// RouteCallResponse records the response from a route +type RouteCallResponse struct { + StatusCode int `json:"status_code"` + Header http.Header `json:"header"` + Body []byte `json:"body"` +} + +// RecorderOption is a function that modifies a recorder +type RecorderOption func(*Recorder) + +// WithHost sets the host of the recorder +func WithHost(host string) RecorderOption { + return func(r *Recorder) { + r.Host = host + } +} + +// NewRecorder creates a new recorder that listens for incoming requests to the parrot server +func NewRecorder(opts ...RecorderOption) (*Recorder, error) { + r := &Recorder{ + recordChan: make(chan *RouteCall), + errChan: make(chan error), + } + + listener, err := net.Listen("tcp", ":0") // nolint:gosec + if err != nil { + return nil, fmt.Errorf("failed to start listener: %w", err) + } + r.Host, r.Port, err = net.SplitHostPort(listener.Addr().String()) + if err != nil { + return nil, fmt.Errorf("failed to split host and port: %w", err) + } + + mux := http.NewServeMux() + mux.Handle("/", r.defaultRecordHandler()) + r.server = &http.Server{ + ReadHeaderTimeout: 5 * time.Second, + Addr: listener.Addr().String(), + Handler: mux, + } + + for _, opt := range opts { + opt(r) + } + + go func() { + if err := r.server.Serve(listener); err != nil { + if err != http.ErrServerClosed { + fmt.Println("Error serving recorder:", err) + } + } + }() + return r, nil +} + +// URL returns the URL of the recorder to send requests to +// WARNING: This URL automatically binds to the first available port on the host machine +// and the host will be 0.0.0.0 or localhost. If you're calling this from a different machine +// you will need to replace the host with the IP address of the machine running the recorder. +func (r *Recorder) URL() string { + return fmt.Sprintf("http://%s:%s", r.Host, r.Port) +} + +// Record receives recorded calls +func (r *Recorder) Record() chan *RouteCall { + return r.recordChan +} + +// Close shuts down the recorder +func (r *Recorder) Close() error { + return r.server.Close() +} + +// Err receives errors from the recorder +func (r *Recorder) Err() chan error { + return r.errChan +} + +func (r *Recorder) defaultRecordHandler() http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + var recordedCall *RouteCall + if err := json.NewDecoder(req.Body).Decode(&recordedCall); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + defer req.Body.Close() + + r.recordChan <- recordedCall + } +} + +// httpResponseRecorder is a wrapper around http.ResponseWriter that records the response +// for later inspection while still writing to the original writer. +// WARNING: If you mutate after calling Header(), the changes will not be reflected in the recorded response. +type responseWriterRecorder struct { + originalWriter http.ResponseWriter + record *httptest.ResponseRecorder +} + +func newResponseWriterRecorder(w http.ResponseWriter) *responseWriterRecorder { + return &responseWriterRecorder{ + originalWriter: w, + record: httptest.NewRecorder(), + } +} + +// SetWriter sets a new writer to record and write to, flushing any previous record +func (rr *responseWriterRecorder) SetWriter(w http.ResponseWriter) { + rr.originalWriter = w + rr.record = httptest.NewRecorder() +} + +func (rr *responseWriterRecorder) WriteHeader(code int) { + rr.originalWriter.WriteHeader(code) + rr.record.WriteHeader(code) +} + +func (rr *responseWriterRecorder) Write(data []byte) (int, error) { + _, _ = rr.record.Write(data) // ignore error as we still want to write to the original writer + return rr.originalWriter.Write(data) +} + +func (rr *responseWriterRecorder) Header() http.Header { + for k, v := range rr.originalWriter.Header() { + rr.record.Header()[k] = v + } + return rr.originalWriter.Header() +} + +func (rr *responseWriterRecorder) Result() *http.Response { + return rr.record.Result() +} diff --git a/parrot/recorder_test.go b/parrot/recorder_test.go new file mode 100644 index 000000000..91684e5cd --- /dev/null +++ b/parrot/recorder_test.go @@ -0,0 +1,203 @@ +package parrot + +import ( + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResponseWriterRecorder(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + responseFunc http.HandlerFunc + expectedRespCode int + expectedRespBody string + expectedRespHeader http.Header + }{ + { + name: "good response", + responseFunc: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("Squawk")) + require.NoError(t, err, "error writing response") + }, + expectedRespCode: http.StatusOK, + expectedRespBody: "Squawk", + expectedRespHeader: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + }, + { + name: "error response", + responseFunc: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Squawk", http.StatusInternalServerError) + }, + expectedRespCode: http.StatusInternalServerError, + expectedRespBody: "Squawk\n", // http.Error adds a newline + expectedRespHeader: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + writerRecorder := newResponseWriterRecorder(recorder) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + handler := http.HandlerFunc(tc.responseFunc) + handler.ServeHTTP(writerRecorder, req) + + actualResp := recorder.Result() + recordedResp := writerRecorder.Result() + t.Cleanup(func() { + _ = actualResp.Body.Close() + _ = recordedResp.Body.Close() + }) + + actualBody, err := io.ReadAll(actualResp.Body) + require.NoError(t, err, "error reading actual response body") + recordedBody, err := io.ReadAll(recordedResp.Body) + require.NoError(t, err, "error reading recorded response body") + + assert.Equal(t, tc.expectedRespCode, actualResp.StatusCode, "actual response has unexpected status code") + assert.Equal(t, tc.expectedRespCode, recordedResp.StatusCode, "recorded response has unexpected status code") + assert.Equal(t, tc.expectedRespBody, string(actualBody), "actual response has unexpected body") + assert.Equal(t, tc.expectedRespBody, string(recordedBody), "recorded response has unexpected body") + }) + } +} + +func TestRecorder(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + recorder, err := NewRecorder() + require.NoError(t, err, "error creating recorder") + t.Cleanup(func() { + require.NoError(t, recorder.Close()) + }) + + err = p.Record(recorder.URL()) + require.NoError(t, err, "error recording parrot") + t.Cleanup(func() { + require.NoError(t, recorder.Close()) + }) + + route := &Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + err = p.Register(route) + require.NoError(t, err, "error registering route") + + var ( + responseCount = 5 + recordedCalls = 0 + ) + + go func() { + for i := 0; i < responseCount; i++ { + _, err := p.Call(http.MethodGet, "/test") + require.NoError(t, err, "error calling parrot") + } + }() + + for { + select { + case recordedRouteCall := <-recorder.Record(): + assert.Equal(t, route.ID(), recordedRouteCall.RouteID, "recorded response has unexpected route ID") + + assert.Equal(t, http.StatusOK, recordedRouteCall.Response.StatusCode, "recorded response has unexpected status code") + assert.Equal(t, "Squawk", string(recordedRouteCall.Response.Body), "recorded response has unexpected body") + + assert.Equal(t, "/test", recordedRouteCall.Request.URL.Path, "recorded request has unexpected path") + assert.Equal(t, http.MethodGet, recordedRouteCall.Request.Method, "recorded request has unexpected method") + recordedCalls++ + if recordedCalls == responseCount { + return + } + case err := <-recorder.Err(): + require.NoError(t, err, "error recording route call") + } + } +} + +func TestMultipleRecorders(t *testing.T) { + t.Parallel() + + p := newParrot(t) + + var ( + numRecorders = 10 + numCalls = 5 + ) + recorders := make([]*Recorder, numRecorders) + for i := 0; i < numRecorders; i++ { + recorder, err := NewRecorder() + require.NoError(t, err, "error creating recorder") + recorders[i] = recorder + } + t.Cleanup(func() { + for _, recorder := range recorders { + require.NoError(t, recorder.Close()) + } + }) + + for _, recorder := range recorders { + err := p.Record(recorder.URL()) + require.NoError(t, err, "error recording parrot") + } + + route := &Route{ + Method: http.MethodGet, + Path: "/test", + RawResponseBody: "Squawk", + ResponseStatusCode: http.StatusOK, + } + err := p.Register(route) + require.NoError(t, err, "error registering route") + + var wg sync.WaitGroup + wg.Add(numCalls) + for i := 0; i < numCalls; i++ { + go func() { + defer wg.Done() + _, err := p.Call(http.MethodGet, "/test") + require.NoError(t, err, "error calling parrot") + }() + } + wg.Wait() + + for _, recorder := range recorders { + for i := 0; i < numCalls; i++ { + select { + case recordedRouteCall := <-recorder.Record(): + assert.Equal(t, route.ID(), recordedRouteCall.RouteID, "recorded response has unexpected route ID for recorder %d", i) + assert.Equal(t, http.StatusOK, recordedRouteCall.Response.StatusCode, "recorded response has unexpected status code for recorder %d", i) + assert.Equal(t, "Squawk", string(recordedRouteCall.Response.Body), "recorded response has unexpected body for recorder %d", i) + assert.Equal(t, "/test", recordedRouteCall.Request.URL.Path, "recorded request has unexpected path for recorder %d", i) + assert.Equal(t, http.MethodGet, recordedRouteCall.Request.Method, "recorded request has unexpected method for recorder %d", i) + case err := <-recorder.Err(): + require.NoError(t, err, "error recording route call") + case <-time.After(time.Second): + require.Fail(t, "timed out waiting for recorder %d", i) + } + } + } +}