diff --git a/README.md b/README.md index 972b729f..edbe4534 100644 --- a/README.md +++ b/README.md @@ -592,6 +592,16 @@ c = c.Append(hlog.AccessHandler(func(r *http.Request, status, size int, duration Dur("duration", duration). Msg("") })) +c = c.Append(hlog.AccessHandlerWithData(func(data hlog.AccessHandlerData) { + hlog.FromRequest(data.Request).Info(). + Str("method", data.Request.Method). + Stringer("url", data.Request.URL). + Int("status", data.Status). + Int("sizeWritten", data.BytesWritten). + Int64("sizeRead", data.BytesRead). + Dur("duration", data.Duration). + Msg("") +})) c = c.Append(hlog.RemoteAddrHandler("ip")) c = c.Append(hlog.UserAgentHandler("user_agent")) c = c.Append(hlog.RefererHandler("referer")) diff --git a/hlog/hlog.go b/hlog/hlog.go index 06ca4adf..fc3857e7 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -302,6 +302,36 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat } } +type AccessHandlerData struct { + Request *http.Request + Duration time.Duration + Status int + BytesWritten int + BytesRead int64 +} + +// AccessHandlerWithData returns a handler that call f after each request. +func AccessHandlerWithData(f func(data AccessHandlerData)) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + ww := mutil.WrapWriter(w) + body := mutil.NewByteCountReadCloser(r.Body) + r.Body = body + defer func() { + f(AccessHandlerData{ + Request: r, + Duration: time.Since(start), + Status: ww.Status(), + BytesWritten: ww.BytesWritten(), + BytesRead: body.BytesRead(), + }) + }() + next.ServeHTTP(ww, r) + }) + } +} + // HostHandler adds the request's host as a field to the context's logger // using fieldKey as field key. If trimPort is set to true, then port is // removed from the host. diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index 0d6b31ea..63b73cb0 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -7,10 +7,12 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "github.com/rs/xid" @@ -432,3 +434,43 @@ func TestGetHost(t *testing.T) { }) } } + +func TestAccessHandlerWithData(t *testing.T) { + bodyValue := "hello, world!" + req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(bodyValue)) + + handler := AccessHandlerWithData(func(data AccessHandlerData) { + expectedBytes := int64(len(bodyValue)) + if data.BytesRead != expectedBytes { + t.Errorf("unexpected bytes read, got: %d, want: %d", data.BytesRead, expectedBytes) + } + if data.BytesWritten != int(expectedBytes) { + t.Errorf("unexpected bytes read, got: %d, want: %d", data.BytesWritten, expectedBytes) + } + if data.Status != http.StatusOK { + t.Errorf("unexpected status, got: %d, want: %d", data.Status, http.StatusOK) + } + if data.Request != req { + t.Error("unexpected request object") + } + }) + + rr := httptest.NewRecorder() + + handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, r.Body) + })).ServeHTTP(rr, req) + + if rr.Result().StatusCode != http.StatusOK { + t.Errorf("unexpected status, got: %d, want: %d", rr.Result().StatusCode, http.StatusOK) + } + + b, err := io.ReadAll(rr.Result().Body) + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + + if bodyValue != string(b) { + t.Errorf("unexpected response body, got: %s, want: %s", string(b), bodyValue) + } +} diff --git a/hlog/internal/mutil/body.go b/hlog/internal/mutil/body.go new file mode 100644 index 00000000..b40f6558 --- /dev/null +++ b/hlog/internal/mutil/body.go @@ -0,0 +1,42 @@ +package mutil + +import ( + "io" + "sync/atomic" +) + +type byteCountReadCloser struct { + rc io.ReadCloser + read *int64 +} + +var _ io.ReadCloser = (*byteCountReadCloser)(nil) +var _ io.WriterTo = (*byteCountReadCloser)(nil) + +func NewByteCountReadCloser(rc io.ReadCloser) *byteCountReadCloser { + read := int64(0) + return &byteCountReadCloser{ + rc: rc, + read: &read, + } +} + +func (b *byteCountReadCloser) Read(p []byte) (int, error) { + n, err := b.rc.Read(p) + atomic.AddInt64(b.read, int64(n)) + return n, err +} + +func (b *byteCountReadCloser) Close() error { + return b.rc.Close() +} + +func (b *byteCountReadCloser) WriteTo(w io.Writer) (int64, error) { + n, err := io.Copy(w, b.rc) + atomic.AddInt64(b.read, n) + return n, err +} + +func (b *byteCountReadCloser) BytesRead() int64 { + return atomic.LoadInt64(b.read) +}