Skip to content

Commit 9d0d5ae

Browse files
feat(server): add rate limiting and improve handler architecture (#17)
- Add RateLimitDeferrer to handle GitHub API rate limit backoff with exponential retry and jitter - Create fetcher_test.go with comprehensive test coverage for ConfigFetcher - Refactor ConfigFetcher constructor for cleaner initialization - Improve base handler with rate limit context propagation - Enhance eval context with rate limit aware evaluation - Add rate limit test suite with mock GitHub client Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
1 parent 33b5375 commit 9d0d5ae

10 files changed

Lines changed: 704 additions & 11 deletions

File tree

server/handler/base.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ type Base struct {
4444
AppName string
4545
AppID int64
4646

47-
Debouncer *StatusDebouncer
47+
Debouncer *StatusDebouncer
48+
RateLimitDeferrer *RateLimitDeferrer
4849
}
4950

5051
// PostCheckRun creates a GitHub check run with consistent logging.
@@ -152,8 +153,8 @@ func (b *Base) newEvalContext(ctx context.Context, installationID int64, loc pul
152153
}
153154

154155
func (b *Base) Evaluate(ctx context.Context, installationID int64, trigger common.Trigger, loc pull.Locator) error {
156+
key := DebounceKey(loc.Owner, loc.Repo, loc.Number)
155157
if b.Debouncer != nil {
156-
key := DebounceKey(loc.Owner, loc.Repo, loc.Number)
157158
trailingFn := func(eventCtx context.Context, accumulated common.Trigger, coalesced int) {
158159
tctx, span := StartDebounceTrailingSpan(eventCtx)
159160
defer span.End()
@@ -168,7 +169,7 @@ func (b *Base) Evaluate(ctx context.Context, installationID int64, trigger commo
168169

169170
logger := zerolog.Ctx(tctx)
170171
logger.Debug().Msgf("Running trailing evaluation for %s/%s#%d (accumulated trigger: %s, coalesced: %d)", loc.Owner, loc.Repo, loc.Number, accumulated, coalesced)
171-
if err := b.doEvaluate(tctx, installationID, accumulated, loc); err != nil {
172+
if err := b.evaluateOnce(tctx, installationID, accumulated, loc, key); err != nil {
172173
RecordError(span, &err)
173174
logger.Error().Err(err).Msgf("Trailing evaluation failed for %s/%s#%d", loc.Owner, loc.Repo, loc.Number)
174175
}
@@ -178,7 +179,7 @@ func (b *Base) Evaluate(ctx context.Context, installationID int64, trigger commo
178179
return nil
179180
}
180181
}
181-
return b.doEvaluate(ctx, installationID, trigger, loc)
182+
return b.evaluateOnce(ctx, installationID, trigger, loc, key)
182183
}
183184

184185
func (b *Base) doEvaluate(ctx context.Context, installationID int64, trigger common.Trigger, loc pull.Locator) error {
@@ -188,3 +189,28 @@ func (b *Base) doEvaluate(ctx context.Context, installationID int64, trigger com
188189
}
189190
return evalCtx.Evaluate(ctx, trigger)
190191
}
192+
193+
func (b *Base) evaluateOnce(ctx context.Context, installationID int64, trigger common.Trigger, loc pull.Locator, key string) error {
194+
if b.RateLimitDeferrer != nil && b.RateLimitDeferrer.DeferIfActive(ctx, installationID, key, trigger, func(deferredCtx context.Context, deferredTrigger common.Trigger) {
195+
if err := b.evaluateOnce(deferredCtx, installationID, deferredTrigger, loc, key); err != nil {
196+
zerolog.Ctx(deferredCtx).Error().Err(err).Msgf("Deferred evaluation failed for %s/%s#%d", loc.Owner, loc.Repo, loc.Number)
197+
}
198+
}) {
199+
zerolog.Ctx(ctx).Warn().Msgf("Deferring evaluation for %s/%s#%d because installation %d is rate limited", loc.Owner, loc.Repo, loc.Number, installationID)
200+
return nil
201+
}
202+
203+
if err := b.doEvaluate(ctx, installationID, trigger, loc); err != nil {
204+
if resetAt, ok := rateLimitResetTime(err); ok && b.RateLimitDeferrer != nil {
205+
b.RateLimitDeferrer.DeferUntil(ctx, installationID, key, trigger, resetAt, func(deferredCtx context.Context, deferredTrigger common.Trigger) {
206+
if evalErr := b.evaluateOnce(deferredCtx, installationID, deferredTrigger, loc, key); evalErr != nil {
207+
zerolog.Ctx(deferredCtx).Error().Err(evalErr).Msgf("Deferred evaluation failed for %s/%s#%d", loc.Owner, loc.Repo, loc.Number)
208+
}
209+
})
210+
zerolog.Ctx(ctx).Warn().Err(err).Time("github_rate_limit_reset", resetAt).Msgf("Rate limited evaluating %s/%s#%d; deferred retry scheduled", loc.Owner, loc.Repo, loc.Number)
211+
return nil
212+
}
213+
return err
214+
}
215+
return nil
216+
}

server/handler/base_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"io"
2121
"net/http"
2222
"net/url"
23+
"strconv"
2324
"strings"
2425
"sync"
2526
"testing"
@@ -188,6 +189,80 @@ type recordingTransport struct {
188189
configBody string
189190
}
190191

192+
func TestEvaluateDefersAndRetriesRateLimitedEvaluation(t *testing.T) {
193+
transport := &rateLimitedConfigTransport{
194+
resetAt: time.Now().Add(time.Second),
195+
}
196+
httpClient := &http.Client{Transport: transport}
197+
198+
client := github.NewClient(httpClient)
199+
baseURL, err := url.Parse("http://github.localhost/")
200+
require.NoError(t, err)
201+
client.BaseURL = baseURL
202+
203+
v4client := githubv4.NewClient(httpClient)
204+
205+
deferrer := NewRateLimitDeferrer(0)
206+
deferrer.jitter = func() time.Duration { return 0 }
207+
208+
h := Base{
209+
ClientCreator: staticClientCreator{
210+
client: client,
211+
v4client: v4client,
212+
},
213+
ConfigFetcher: NewConfigFetcher(appconfig.NewLoader([]string{".policy.yml"})),
214+
BaseConfig: &baseapp.HTTPConfig{
215+
PublicURL: "https://policy-bot.example.com",
216+
},
217+
PullOpts: &PullEvaluationOptions{
218+
StatusCheckContext: "policy-bot",
219+
},
220+
RateLimitDeferrer: deferrer,
221+
}
222+
223+
pr := testPullRequest()
224+
err = h.Evaluate(context.Background(), 123, 0, pull.Locator{
225+
Owner: pr.GetBase().GetRepo().GetOwner().GetLogin(),
226+
Repo: pr.GetBase().GetRepo().GetName(),
227+
Number: pr.GetNumber(),
228+
Value: pr,
229+
})
230+
require.NoError(t, err)
231+
assert.Equal(t, 1, transport.configRequestCount())
232+
233+
require.Eventually(t, func() bool {
234+
return transport.configRequestCount() >= 2
235+
}, 2*time.Second, 25*time.Millisecond)
236+
}
237+
238+
type rateLimitedConfigTransport struct {
239+
mu sync.Mutex
240+
configRequests int
241+
resetAt time.Time
242+
}
243+
244+
func (rt *rateLimitedConfigTransport) RoundTrip(req *http.Request) (*http.Response, error) {
245+
rt.mu.Lock()
246+
rt.configRequests++
247+
count := rt.configRequests
248+
rt.mu.Unlock()
249+
250+
if count == 1 {
251+
res := jsonResponse(req, http.StatusForbidden, `{"message":"API rate limit exceeded"}`)
252+
res.Header.Set("X-Ratelimit-Limit", "9600")
253+
res.Header.Set("X-Ratelimit-Remaining", "0")
254+
res.Header.Set("X-Ratelimit-Reset", strconv.FormatInt(rt.resetAt.Unix(), 10))
255+
return res, nil
256+
}
257+
return jsonResponse(req, http.StatusNotFound, `{"message":"Not Found"}`), nil
258+
}
259+
260+
func (rt *rateLimitedConfigTransport) configRequestCount() int {
261+
rt.mu.Lock()
262+
defer rt.mu.Unlock()
263+
return rt.configRequests
264+
}
265+
191266
func (rt *recordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
192267
if req.URL.Path == "/repos/testorg/testrepo/contents/.policy.yml" {
193268
time.Sleep(rt.configDelay)

server/handler/debounce.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
"go.opentelemetry.io/otel/trace"
2626
)
2727

28-
const DefaultDebounceWindow = 5 * time.Second
28+
const DefaultDebounceWindow = 30 * time.Second
2929

3030
// StatusDebouncer coalesces rapid evaluation requests for the same pull
3131
// request. When many check runs or status events complete in quick succession

server/handler/eval_context.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func parseFetchedConfig(ctx context.Context, fc FetchedConfig, evalOpts *PullEva
106106
msg := fmt.Sprintf("Error loading policy from %s", fc.Source)
107107
logger.Warn().Err(fc.LoadError).Msg(msg)
108108

109-
if postStatus != nil {
109+
if postStatus != nil && !isTransientClientError(fc.LoadError) {
110110
postStatus(ctx, "error", msg)
111111
}
112112
return nil, errors.Wrapf(fc.LoadError, "failed to load policy: %s: %s", fc.Source, fc.Path)
@@ -319,6 +319,9 @@ func isTransientClientError(err error) bool {
319319
if errors.As(err, &ghErr) {
320320
return ghErr.Response != nil && ghErr.Response.StatusCode >= 400
321321
}
322+
if isRateLimitError(err) {
323+
return true
324+
}
322325
var urlErr *url.Error
323326
if errors.As(err, &urlErr) {
324327
return true

server/handler/eval_context_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ package handler
1616

1717
import (
1818
"context"
19+
"errors"
1920
"testing"
21+
"time"
2022

2123
"github.com/palantir/policy-bot/policy"
2224
"github.com/palantir/policy-bot/policy/common"
@@ -243,3 +245,39 @@ func TestEvalContext_EvaluatePolicy_PendingAsFailure_NilConfig(t *testing.T) {
243245
assert.Equal(t, "completed", ec.Status.GetStatus(), "server option should be used when policy config is nil")
244246
assert.Equal(t, "failure", ec.Status.GetConclusion(), "server option should be used when policy config is nil")
245247
}
248+
249+
func TestParseFetchedConfigDoesNotPostStatusForTransientLoadError(t *testing.T) {
250+
var posted int
251+
fc := FetchedConfig{
252+
Source: "testowner/testrepo@main",
253+
Path: ".policy.yml",
254+
LoadError: newRateLimitError(time.Now().Add(time.Minute)),
255+
}
256+
257+
_, err := parseFetchedConfig(context.Background(), fc, nil, common.TriggerStatus, func(context.Context, string, string) {
258+
posted++
259+
})
260+
261+
require.Error(t, err)
262+
assert.Equal(t, 0, posted)
263+
}
264+
265+
func TestParseFetchedConfigPostsStatusForNonTransientLoadError(t *testing.T) {
266+
var posted int
267+
fc := FetchedConfig{
268+
Source: "testowner/testrepo@main",
269+
Path: ".policy.yml",
270+
LoadError: errors.New("invalid remote reference"),
271+
}
272+
273+
_, err := parseFetchedConfig(context.Background(), fc, nil, common.TriggerStatus, func(context.Context, string, string) {
274+
posted++
275+
})
276+
277+
require.Error(t, err)
278+
assert.Equal(t, 1, posted)
279+
}
280+
281+
func TestIsTransientClientErrorRecognizesRateLimits(t *testing.T) {
282+
assert.True(t, isTransientClientError(newRateLimitError(time.Now().Add(time.Minute))))
283+
}

server/handler/fetcher.go

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"errors"
2020
"net/http"
2121
"os"
22+
"sync"
2223
"time"
2324

2425
"github.com/google/go-github/v81/github"
@@ -37,10 +38,94 @@ type FetchedConfig struct {
3738
}
3839

3940
type ConfigFetcher struct {
40-
Loader *appconfig.Loader
41+
Loader *appconfig.Loader
42+
CacheTTL time.Duration
43+
Clock func() time.Time
44+
45+
mu sync.Mutex
46+
cache map[string]configCacheEntry
47+
inflight map[string]*configInflight
48+
}
49+
50+
const DefaultConfigCacheTTL = 30 * time.Second
51+
52+
type configCacheEntry struct {
53+
config FetchedConfig
54+
expiresAt time.Time
55+
}
56+
57+
type configInflight struct {
58+
done chan struct{}
59+
config FetchedConfig
60+
}
61+
62+
func NewConfigFetcher(loader *appconfig.Loader) *ConfigFetcher {
63+
return &ConfigFetcher{
64+
Loader: loader,
65+
CacheTTL: DefaultConfigCacheTTL,
66+
}
4167
}
4268

4369
func (cf *ConfigFetcher) ConfigForRepositoryBranch(ctx context.Context, client *github.Client, owner, repository, branch string) FetchedConfig {
70+
if cf.CacheTTL <= 0 {
71+
return cf.loadConfigForRepositoryBranch(ctx, client, owner, repository, branch)
72+
}
73+
74+
key := configCacheKey(owner, repository, branch)
75+
now := cf.now()
76+
77+
cf.mu.Lock()
78+
if cf.cache != nil {
79+
if entry, ok := cf.cache[key]; ok {
80+
if now.Before(entry.expiresAt) {
81+
cf.mu.Unlock()
82+
return cloneFetchedConfig(entry.config)
83+
}
84+
delete(cf.cache, key)
85+
}
86+
}
87+
if cf.inflight != nil {
88+
if in := cf.inflight[key]; in != nil {
89+
cf.mu.Unlock()
90+
select {
91+
case <-ctx.Done():
92+
return FetchedConfig{
93+
Source: owner + "/" + repository + "@" + branch,
94+
LoadError: ctx.Err(),
95+
}
96+
case <-in.done:
97+
return cloneFetchedConfig(in.config)
98+
}
99+
}
100+
}
101+
if cf.inflight == nil {
102+
cf.inflight = make(map[string]*configInflight)
103+
}
104+
in := &configInflight{done: make(chan struct{})}
105+
cf.inflight[key] = in
106+
cf.mu.Unlock()
107+
108+
fc := cf.loadConfigForRepositoryBranch(ctx, client, owner, repository, branch)
109+
110+
cf.mu.Lock()
111+
in.config = cloneFetchedConfig(fc)
112+
delete(cf.inflight, key)
113+
if shouldCacheFetchedConfig(fc) {
114+
if cf.cache == nil {
115+
cf.cache = make(map[string]configCacheEntry)
116+
}
117+
cf.cache[key] = configCacheEntry{
118+
config: cloneFetchedConfig(fc),
119+
expiresAt: cf.now().Add(cf.CacheTTL),
120+
}
121+
}
122+
close(in.done)
123+
cf.mu.Unlock()
124+
125+
return fc
126+
}
127+
128+
func (cf *ConfigFetcher) loadConfigForRepositoryBranch(ctx context.Context, client *github.Client, owner, repository, branch string) FetchedConfig {
44129
retries := 0
45130
delay := 1 * time.Second
46131
for {
@@ -86,6 +171,39 @@ func (cf *ConfigFetcher) ConfigForRepositoryBranch(ctx context.Context, client *
86171
}
87172
}
88173

174+
func (cf *ConfigFetcher) now() time.Time {
175+
if cf.Clock != nil {
176+
return cf.Clock()
177+
}
178+
return time.Now()
179+
}
180+
181+
func configCacheKey(owner, repository, branch string) string {
182+
return owner + "/" + repository + "@" + branch
183+
}
184+
185+
func shouldCacheFetchedConfig(fc FetchedConfig) bool {
186+
return fc.LoadError == nil && fc.ParseError == nil
187+
}
188+
189+
func cloneFetchedConfig(fc FetchedConfig) FetchedConfig {
190+
clone := fc
191+
if fc.Config == nil {
192+
return clone
193+
}
194+
195+
b, err := yaml.Marshal(fc.Config)
196+
if err != nil {
197+
return clone
198+
}
199+
var pc policy.Config
200+
if err := yaml.UnmarshalStrict(b, &pc); err != nil {
201+
return clone
202+
}
203+
clone.Config = &pc
204+
return clone
205+
}
206+
89207
func isServerError(err error) bool {
90208
var ghErr *github.ErrorResponse
91209
if errors.As(err, &ghErr) {

0 commit comments

Comments
 (0)