Skip to content

Commit

Permalink
Eliminate races
Browse files Browse the repository at this point in the history
  • Loading branch information
kalverra committed Jan 30, 2025
1 parent b4cbf64 commit faf9a06
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 72 deletions.
28 changes: 24 additions & 4 deletions parrot/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func ExampleServer_Register_internal() {
ResponseStatusCode: http.StatusOK,
}

waitForParrotServerInternal(p, time.Second) // Wait for the parrot server to start

// Register the route with the parrot instance
err = p.Register(route)
if err != nil {
Expand Down Expand Up @@ -86,7 +88,7 @@ func ExampleServer_Register_external() {
client := resty.New()
client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server

waitForParrotServer(client, time.Second) // Wait for the parrot server to start
waitForParrotServerExternal(client, time.Second) // Wait for the parrot server to start

// Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk"
route := &parrot.Route{
Expand Down Expand Up @@ -158,6 +160,8 @@ func ExampleRecorder_internal() {
panic(err)
}

waitForParrotServerInternal(p, time.Second) // Wait for the parrot server to start

// Register the recorder with the parrot instance
err = p.Record(recorder.URL())
if err != nil {
Expand Down Expand Up @@ -225,7 +229,7 @@ func ExampleRecorder_external() {
client := resty.New()
client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server

waitForParrotServer(client, time.Second) // Wait for the parrot server to start
waitForParrotServerExternal(client, time.Second) // Wait for the parrot server to start

// Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk"
route := &parrot.Route{
Expand Down Expand Up @@ -290,8 +294,8 @@ func ExampleRecorder_external() {
// Squawk
}

// waitForParrotServer checks the parrot server health endpoint until it returns a 200 status code or the timeout is reached
func waitForParrotServer(client *resty.Client, timeoutDur time.Duration) {
// waitForParrotServerExternal checks the parrot server health endpoint until it returns a 200 status code or the timeout is reached
func waitForParrotServerExternal(client *resty.Client, timeoutDur time.Duration) {
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(timeoutDur)
Expand All @@ -310,3 +314,19 @@ func waitForParrotServer(client *resty.Client, timeoutDur time.Duration) {
}
}
}

func waitForParrotServerInternal(p *parrot.Server, timeoutDur time.Duration) {
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(timeoutDur)
for { // Wait for the parrot server to start
select {
case <-ticker.C:
if err := p.Healthy(); err == nil {
return
}
case <-timeout.C:
panic("timeout waiting for parrot server to start")
}
}
}
52 changes: 26 additions & 26 deletions parrot/parrot.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/go-chi/chi"
Expand Down Expand Up @@ -67,13 +68,12 @@ type Server struct {
recordersMu sync.RWMutex

// Save and shutdown
shutDown bool
shutDown atomic.Bool
shutDownChan chan struct{}
shutDownOnce sync.Once
saveFileName string

// Logging
useCustomLogger bool
logFileName string
logFile *os.File
logLevel zerolog.Level
Expand Down Expand Up @@ -113,33 +113,32 @@ func Wake(options ...ServerOption) (*Server, error) {
}
}

// Setup logger
var err error
p.logFile, err = os.Create(p.logFileName)
if err != nil {
return nil, fmt.Errorf("failed to create log file: %w", err)
}

if !p.useCustomLogger { // Build default logger
var writers []io.Writer
var writers []io.Writer

zerolog.TimeFieldFormat = "2006-01-02T15:04:05.000"
if !p.disableConsoleLogs {
if p.jsonLogs {
writers = append(writers, os.Stderr)
} else {
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"}
writers = append(writers, consoleOut)
}
}

if p.logFile != nil {
writers = append(writers, p.logFile)
if !p.disableConsoleLogs {
if p.jsonLogs {
writers = append(writers, os.Stderr)
} else {
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"}
writers = append(writers, consoleOut)
}
}

multiWriter := zerolog.MultiLevelWriter(writers...)
p.log = zerolog.New(multiWriter).Level(p.logLevel).With().Timestamp().Logger()
if p.logFile != nil {
writers = append(writers, p.logFile)
}

multiWriter := zerolog.MultiLevelWriter(writers...)
p.log = zerolog.New(multiWriter).Level(p.logLevel).With().Timestamp().Logger()

// Setup server
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
if err != nil {
return nil, fmt.Errorf("failed to start listener: %w", err)
Expand All @@ -155,6 +154,7 @@ func Wake(options ...ServerOption) (*Server, error) {
return nil, fmt.Errorf("failed to parse port: %w", err)
}

// Initialize router
p.router.Get(HealthRoute, p.healthHandlerGET)

p.router.Get(RoutesRoute, p.routesHandlerGET)
Expand Down Expand Up @@ -182,7 +182,7 @@ func Wake(options ...ServerOption) (*Server, error) {
// run starts the parrot server
func (p *Server) run(listener net.Listener) {
defer func() {
p.shutDown = true
p.shutDown.Store(true)
if err := p.save(); err != nil {
p.log.Error().Err(err).Msg("Failed to save routes")
}
Expand Down Expand Up @@ -241,7 +241,7 @@ func (p *Server) routeCallHandler(route *Route) http.HandlerFunc {

// Healthy checks if the parrot server is healthy
func (p *Server) Healthy() error {
if p.shutDown {
if p.shutDown.Load() {
return ErrServerShutdown
}

Expand Down Expand Up @@ -288,7 +288,7 @@ func (p *Server) healthHandlerGET(w http.ResponseWriter, r *http.Request) {

// Shutdown gracefully shuts down the parrot server
func (p *Server) Shutdown(ctx context.Context) error {
if p.shutDown {
if p.shutDown.Load() {
return ErrServerShutdown
}

Expand All @@ -308,7 +308,7 @@ func (p *Server) Address() string {

// Register adds a new route to the parrot
func (p *Server) Register(route *Route) error {
if p.shutDown {
if p.shutDown.Load() {
return ErrServerShutdown
}
if route == nil {
Expand Down Expand Up @@ -372,7 +372,7 @@ func (p *Server) routesHandlerPOST(w http.ResponseWriter, r *http.Request) {

// Record registers a new recorder with the parrot. All incoming requests to the parrot will be sent to the recorder.
func (p *Server) Record(recorderURL string) error {
if p.shutDown {
if p.shutDown.Load() {
return ErrServerShutdown
}

Expand Down Expand Up @@ -426,7 +426,7 @@ func (p *Server) recorderHandlerPOST(w http.ResponseWriter, r *http.Request) {

// Recorders returns the URLs of all registered recorders
func (p *Server) Recorders() []string {
if p.shutDown {
if p.shutDown.Load() {
return nil
}

Expand Down Expand Up @@ -493,14 +493,14 @@ func (p *Server) routesHandlerDELETE(w http.ResponseWriter, r *http.Request) {

// Call makes a request to the parrot server
func (p *Server) Call(method, path string) (*resty.Response, error) {
if p.shutDown {
if p.shutDown.Load() {
return nil, ErrServerShutdown
}
return p.client.R().Execute(method, "http://"+filepath.Join(p.Address(), path))
}

func (p *Server) Routes() []*Route {
if p.shutDown {
if p.shutDown.Load() {
return nil
}

Expand Down
9 changes: 0 additions & 9 deletions parrot/parrot_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@ func WithLogLevel(level zerolog.Level) ServerOption {
}
}

// WithLogger sets the logger for the ParrotServer
func WithLogger(l zerolog.Logger) ServerOption {
return func(s *Server) error {
s.log = l
s.useCustomLogger = true
return nil
}
}

// WithJSONLogs sets the logger to output JSON logs
func WithJSONLogs() ServerOption {
return func(s *Server) error {
Expand Down
63 changes: 30 additions & 33 deletions parrot/parrot_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package parrot

import (
"bytes"
"context"
"encoding/json"
"flag"
Expand Down Expand Up @@ -33,6 +32,36 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestHealthy(t *testing.T) {
t.Parallel()

p := newParrot(t)

healthCount := 0
targetCount := 3

ticker := time.NewTicker(time.Millisecond * 10)
timeout := time.NewTimer(time.Second)
t.Cleanup(func() {
ticker.Stop()
timeout.Stop()
})

for {
select {
case <-ticker.C:
if err := p.Healthy(); err == nil {
healthCount++
}
if healthCount >= targetCount {
return
}
case <-timeout.C:
require.GreaterOrEqual(t, targetCount, healthCount, "parrot never became healthy")
}
}
}

func TestRegisterRoutes(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -519,38 +548,6 @@ func TestShutDown(t *testing.T) {
require.ErrorIs(t, err, ErrServerShutdown, "expected error shutting down parrot after shutdown")
}

func TestCustomLogger(t *testing.T) {
t.Parallel()

logBuffer := new(bytes.Buffer)
testLogger := zerolog.New(logBuffer)

fileName := t.Name() + ".json"
p, err := Wake(WithSaveFile(fileName), WithLogLevel(zerolog.DebugLevel), WithLogger(testLogger))
require.NoError(t, err, "error waking parrot")
t.Cleanup(func() {
err := p.Shutdown(context.Background())
assert.NoError(t, err, "error shutting down parrot")
p.WaitShutdown() // Wait for shutdown to complete
os.Remove(fileName)
})

route := &Route{
Method: http.MethodGet,
Path: "/hello",
RawResponseBody: "Squawk",
ResponseStatusCode: http.StatusOK,
}

err = p.Register(route)
require.NoError(t, err, "error registering route")

_, err = p.Call(route.Method, route.Path)
require.NoError(t, err, "error calling parrot")

require.Contains(t, logBuffer.String(), route.ID(), "expected log buffer to contain route call")
}

func TestJSONLogger(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit faf9a06

Please sign in to comment.