|
@@ -4,13 +4,14 @@ import (
|
|
|
"net/http"
|
|
|
"strconv"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
|
|
|
type RateLimiter struct {
|
|
|
sync.Mutex
|
|
|
- global *Bucket
|
|
|
+ global *int64
|
|
|
buckets map[string]*Bucket
|
|
|
globalRateLimit time.Duration
|
|
|
}
|
|
@@ -20,7 +21,7 @@ func NewRatelimiter() *RateLimiter {
|
|
|
|
|
|
return &RateLimiter{
|
|
|
buckets: make(map[string]*Bucket),
|
|
|
- global: &Bucket{Key: "global"},
|
|
|
+ global: new(int64),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -58,8 +59,10 @@ func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
|
|
|
}
|
|
|
|
|
|
|
|
|
- r.global.Lock()
|
|
|
- r.global.Unlock()
|
|
|
+ sleepTo := time.Unix(0, atomic.LoadInt64(r.global))
|
|
|
+ if now := time.Now(); now.Before(sleepTo) {
|
|
|
+ time.Sleep(sleepTo.Sub(now))
|
|
|
+ }
|
|
|
|
|
|
b.remaining--
|
|
|
return b
|
|
@@ -72,7 +75,7 @@ type Bucket struct {
|
|
|
remaining int
|
|
|
limit int
|
|
|
reset time.Time
|
|
|
- global *Bucket
|
|
|
+ global *int64
|
|
|
}
|
|
|
|
|
|
|
|
@@ -89,41 +92,25 @@ func (b *Bucket) Release(headers http.Header) error {
|
|
|
global := headers.Get("X-RateLimit-Global")
|
|
|
retryAfter := headers.Get("Retry-After")
|
|
|
|
|
|
-
|
|
|
- if global != "" {
|
|
|
- parsedAfter, err := strconv.Atoi(retryAfter)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- go func() {
|
|
|
-
|
|
|
-
|
|
|
- sleepTo := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
-
|
|
|
- b.global.Lock()
|
|
|
-
|
|
|
- sleepDuration := sleepTo.Sub(time.Now())
|
|
|
- if sleepDuration > 0 {
|
|
|
- time.Sleep(sleepDuration)
|
|
|
- }
|
|
|
-
|
|
|
- b.global.Unlock()
|
|
|
- }()
|
|
|
-
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
if retryAfter != "" {
|
|
|
parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- b.reset = time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
|
|
|
+ resetAt := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
|
|
|
+
|
|
|
+
|
|
|
+ if global != "" {
|
|
|
+ atomic.StoreInt64(b.global, resetAt.UnixNano())
|
|
|
+ } else {
|
|
|
+ b.reset = resetAt
|
|
|
+ }
|
|
|
} else if reset != "" {
|
|
|
|
|
|
discordTime, err := http.ParseTime(headers.Get("Date"))
|