Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
### Example user template template
### Example user template

# IntelliJ project files
.idea
*.iml
out
gen
### Go template
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Dependency directories (remove the comment below to include it)
# vendor/

# Go workspace file
go.work
go.work.sum

# env file
.env

3 changes: 2 additions & 1 deletion github_ratelimit/github_primary_ratelimit/category.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func GetAllCategories() []ResourceCategory {
}

func parseRequestCategory(request *http.Request) ResourceCategory {
return parseCategory(request.Method, request.URL.RawPath)
// RawPath is only populated when there is encoding in the path
return parseCategory(request.Method, request.URL.Path)
}

func parseCategory(method string, path string) ResourceCategory {
Expand Down
5 changes: 3 additions & 2 deletions github_ratelimit/github_primary_ratelimit/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import "context"
// It is used internally and generated from the options.
// It holds the state of the rate limiter in order to enable state sharing.
type Config struct {
state *RateLimitState
bypassLimit bool
state *RateLimitState
bypassLimit bool
resetOnSuccess bool

// callbacks
onLimitReached OnLimitReached
Expand Down
9 changes: 9 additions & 0 deletions github_ratelimit/github_primary_ratelimit/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ func WithUnknownCategoryCallback(callback OnUnknownCategory) Option {
}
}

// WithResetOnSuccessfulCall forces the reset of the current Category when a successful response is received.
// Useful for probing with github_ratelimit.WithOverrideConfig for individual calls to reset an entire category.
// Additionally useful if this rate limiter sits on top of a collection of tokens that can be rotated
func WithResetOnSuccessfulCall() Option {
return func(c *Config) {
c.resetOnSuccess = true
}
}

// WithSharedState is used to set the rate limiter state from an external source.
// Specifically, it is used to share the state between multiple rate limiters.
// e.g.,
Expand Down
31 changes: 21 additions & 10 deletions github_ratelimit/github_primary_ratelimit/ratelimit_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,18 @@ func (s *RateLimitState) Update(config *Config, update UpdateContainer, callback

newResetTime := update.GetResetTime()
if newResetTime == nil {
// nothing to update on a successful request
if config.resetOnSuccess {
// reset the category block with nil
s.storeNewResetTime(config, category, callbackContext, nil)
}
return nil
}
callbackContext.ResetTime = newResetTime.AsTime()

sharedResetTime, exists := s.resetTimeMap[category]
if !exists {
// XXX: there is no point in adding it as a new category to the map,
// because we will not detect it anyway. so just trigger and continue.
config.TriggerUnknownCategory(callbackContext)
sharedResetTime := s.storeNewResetTime(config, category, callbackContext, newResetTime)
if sharedResetTime == nil {
return nil
}

// XXX: should hold a ref to the timer to free resources early on-demand.
// please open an issue if you actually need it.
sharedResetTime.Store(newResetTime)
timer := newResetTime.StartTimer()
go func(timer *time.Timer, callbackContext CallbackContext) {
<-timer.C
Expand All @@ -123,3 +119,18 @@ func (s *RateLimitState) Update(config *Config, update UpdateContainer, callback

return newResetTime
}

func (s *RateLimitState) storeNewResetTime(config *Config, category ResourceCategory, callbackContext *CallbackContext, newResetTime *SecondsSinceEpoch) *AtomicTime {
sharedResetTime, exists := s.resetTimeMap[category]
if !exists {
// XXX: there is no point in adding it as a new category to the map,
// because we will not detect it anyway. so just trigger and continue.
config.TriggerUnknownCategory(callbackContext)
return nil
}

// XXX: should hold a ref to the timer to free resources early on-demand.
// please open an issue if you actually need it.
sharedResetTime.Store(newResetTime)
return sharedResetTime
}
99 changes: 99 additions & 0 deletions github_ratelimit/primary_ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,102 @@ func TestConfigOverride(t *testing.T) {
t.Fatal("default callback was called instead of override")
}
}

type ConfigurableMockRoundTripper struct {
resp *http.Response
err error
}

func (m *ConfigurableMockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return m.resp, m.err
}

func TestPrimaryRateLimit_ResetOnSuccess(t *testing.T) {
mockTransport := &ConfigurableMockRoundTripper{}
limiter := github_primary_ratelimit.New(mockTransport)

// 1. Simulate Rate Limit
limitReset := time.Now().Add(1 * time.Hour)
limitResetEpoch := github_primary_ratelimit.SecondsSinceEpoch(limitReset.Unix())
targetCategory := "search"
targetCategoryResource := github_primary_ratelimit.ResourceCategorySearch

header := http.Header{}
header.Set("x-ratelimit-remaining", "0")
header.Set("x-ratelimit-reset", fmt.Sprintf("%d", int64(limitResetEpoch)))
header.Set("x-ratelimit-resource", targetCategory)

mockTransport.resp = &http.Response{
StatusCode: http.StatusForbidden,
Header: header,
}

req, _ := http.NewRequest("GET", "https://api.github.com/search?q=ratelimit", nil)
_, err := limiter.RoundTrip(req)
if err == nil {
t.Fatal("expected error on first call (rate limit trigger)")
}

// 2. Verify Blocked
// Set 200 OK to ensure the error comes from the limiter state, not the backend
mockTransport.resp = &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
}
req2, _ := http.NewRequest("GET", "https://api.github.com/search?q=ratelimit2", nil)
_, err = limiter.RoundTrip(req2)
if err == nil {
t.Fatal("expected error on second call (blocked)")
}
var rateLimitErr *github_primary_ratelimit.RateLimitReachedError
if !errors.As(err, &rateLimitErr) {
t.Fatalf("expected RateLimitReachedError, got %T: %v", err, err)
}
if rateLimitErr.Category != targetCategoryResource {
t.Fatalf("expected blocked category %v, got %v", targetCategoryResource, rateLimitErr.Category)
}

// 3. Probe with ResetOnSuccessfulCall
headerSuccess := http.Header{}
headerSuccess.Set("x-ratelimit-resource", targetCategory)
mockTransport.resp = &http.Response{
StatusCode: http.StatusOK,
Header: headerSuccess,
}

var unknownCategoryDetected atomic.Bool
callback := func(ctx *github_primary_ratelimit.CallbackContext) {
unknownCategoryDetected.Store(true)
t.Logf("Unknown Category Callback: %v", ctx.Category)
}

req3, _ := http.NewRequest("GET", "https://api.github.com/search?q=ratelimit3", nil)
ctx := github_primary_ratelimit.WithOverrideConfig(req3.Context(),
github_primary_ratelimit.WithBypassLimit(),
github_primary_ratelimit.WithResetOnSuccessfulCall(),
github_primary_ratelimit.WithUnknownCategoryCallback(callback),
)
req3 = req3.WithContext(ctx)

resp3, err := limiter.RoundTrip(req3)
if err != nil {
t.Fatalf("unexpected error on probe call: %v", err)
}
if resp3.StatusCode != http.StatusOK {
t.Fatalf("expected 200 OK on probe call, got %v", resp3.StatusCode)
}

if unknownCategoryDetected.Load() {
t.Fatal("unknown category detected during probe, reset likely failed")
}

// 4. Verify Unblocked
req4, _ := http.NewRequest("GET", "https://api.github.com/search?q=ratelimit4", nil)
resp4, err := limiter.RoundTrip(req4)
if err != nil {
t.Fatalf("unexpected error on verification call (should be unblocked): %v", err)
}
if resp4.StatusCode != http.StatusOK {
t.Fatalf("expected 200 OK on verification call, got %v", resp4.StatusCode)
}
}