|
@@ -4,13 +4,14 @@ import (
|
|
"net/http"
|
|
"net/http"
|
|
"strconv"
|
|
"strconv"
|
|
"sync"
|
|
"sync"
|
|
|
|
+ "sync/atomic"
|
|
"time"
|
|
"time"
|
|
)
|
|
)
|
|
|
|
|
|
// RateLimiter holds all ratelimit buckets
|
|
// RateLimiter holds all ratelimit buckets
|
|
type RateLimiter struct {
|
|
type RateLimiter struct {
|
|
sync.Mutex
|
|
sync.Mutex
|
|
- global *Bucket
|
|
|
|
|
|
+ global *int64
|
|
buckets map[string]*Bucket
|
|
buckets map[string]*Bucket
|
|
globalRateLimit time.Duration
|
|
globalRateLimit time.Duration
|
|
}
|
|
}
|
|
@@ -20,7 +21,7 @@ func NewRatelimiter() *RateLimiter {
|
|
|
|
|
|
return &RateLimiter{
|
|
return &RateLimiter{
|
|
buckets: make(map[string]*Bucket),
|
|
buckets: make(map[string]*Bucket),
|
|
- global: &Bucket{Key: "global"},
|
|
|
|
|
|
+ global: new(int64),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -58,8 +59,10 @@ func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
|
|
}
|
|
}
|
|
|
|
|
|
// Check for global ratelimits
|
|
// Check for global ratelimits
|
|
- r.global.Lock()
|
|
|
|
- r.global.Unlock()
|
|
|
|
|
|
+ sleepTo := time.Unix(0, atomic.LoadInt64(r.global))
|
|
|
|
+ if now := time.Now(); now.Before(sleepTo) {
|
|
|
|
+ time.Sleep(sleepTo.Sub(now))
|
|
|
|
+ }
|
|
|
|
|
|
b.remaining--
|
|
b.remaining--
|
|
return b
|
|
return b
|
|
@@ -72,7 +75,7 @@ type Bucket struct {
|
|
remaining int
|
|
remaining int
|
|
limit int
|
|
limit int
|
|
reset time.Time
|
|
reset time.Time
|
|
- global *Bucket
|
|
|
|
|
|
+ global *int64
|
|
}
|
|
}
|
|
|
|
|
|
// Release unlocks the bucket and reads the headers to update the buckets ratelimit info
|
|
// Release unlocks the bucket and reads the headers to update the buckets ratelimit info
|
|
@@ -89,41 +92,25 @@ func (b *Bucket) Release(headers http.Header) error {
|
|
global := headers.Get("X-RateLimit-Global")
|
|
global := headers.Get("X-RateLimit-Global")
|
|
retryAfter := headers.Get("Retry-After")
|
|
retryAfter := headers.Get("Retry-After")
|
|
|
|
|
|
- // If it's global just keep the main ratelimit mutex locked
|
|
|
|
- if global != "" {
|
|
|
|
- parsedAfter, err := strconv.Atoi(retryAfter)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- // Lock it in a new goroutine so that this isn't a blocking call
|
|
|
|
- go func() {
|
|
|
|
- // Make sure if several requests were waiting we don't sleep for n * retry-after
|
|
|
|
- // where n is the amount of requests that were going on
|
|
|
|
- sleepTo := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
|
-
|
|
|
|
- b.global.Lock()
|
|
|
|
-
|
|
|
|
- sleepDuration := sleepTo.Sub(time.Now())
|
|
|
|
- if sleepDuration > 0 {
|
|
|
|
- time.Sleep(sleepDuration)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- b.global.Unlock()
|
|
|
|
- }()
|
|
|
|
-
|
|
|
|
- return nil
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- // Update reset time if either retry after or reset headers are present
|
|
|
|
- // Prefer retryafter because it's more accurate with time sync and whatnot
|
|
|
|
|
|
+ // Update global and per bucket reset time if the proper headers are available
|
|
|
|
+ // If global is set, then it will block all buckets until after Retry-After
|
|
|
|
+ // If Retry-After without global is provided it will use that for the new reset
|
|
|
|
+ // time since it's more accurate than X-RateLimit-Reset.
|
|
|
|
+ // If Retry-After after is not proided, it will update the reset time from X-RateLimit-Reset
|
|
if retryAfter != "" {
|
|
if retryAfter != "" {
|
|
parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
|
|
parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
- b.reset = time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
|
|
|
|
|
|
|
+ resetAt := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
|
+
|
|
|
|
+ // Lock either this single bucket or all buckets
|
|
|
|
+ if global != "" {
|
|
|
|
+ atomic.StoreInt64(b.global, resetAt.UnixNano())
|
|
|
|
+ } else {
|
|
|
|
+ b.reset = resetAt
|
|
|
|
+ }
|
|
} else if reset != "" {
|
|
} else if reset != "" {
|
|
// Calculate the reset time by using the date header returned from discord
|
|
// Calculate the reset time by using the date header returned from discord
|
|
discordTime, err := http.ParseTime(headers.Get("Date"))
|
|
discordTime, err := http.ParseTime(headers.Get("Date"))
|