Browse Source

Support millisecond precision in rate limits

Carson Hoffman 4 years ago
parent
commit
866ecccb2e
3 changed files with 30 additions and 8 deletions
  1. 9 6
      ratelimit.go
  2. 3 2
      ratelimit_test.go
  3. 18 0
      structs.go

+ 9 - 6
ratelimit.go

@@ -1,6 +1,7 @@
 package discordgo
 
 import (
+	"math"
 	"net/http"
 	"strconv"
 	"strings"
@@ -140,20 +141,21 @@ func (b *Bucket) Release(headers http.Header) error {
 	remaining := headers.Get("X-RateLimit-Remaining")
 	reset := headers.Get("X-RateLimit-Reset")
 	global := headers.Get("X-RateLimit-Global")
-	retryAfter := headers.Get("Retry-After")
+	resetAfter := headers.Get("X-RateLimit-Reset-After")
 
 	// Update global and per bucket reset time if the proper headers are available
 	// If global is set, then it will block all buckets until after Retry-After
 	// If Retry-After without global is provided it will use that for the new reset
 	// time since it's more accurate than X-RateLimit-Reset.
 	// If Retry-After after is not proided, it will update the reset time from X-RateLimit-Reset
-	if retryAfter != "" {
-		parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
+	if resetAfter != "" {
+		parsedAfter, err := strconv.ParseFloat(resetAfter, 64)
 		if err != nil {
 			return err
 		}
 
-		resetAt := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
+		whole, frac := math.Modf(parsedAfter)
+		resetAt := time.Now().Add(time.Duration(whole) * time.Second).Add(time.Duration(frac*1000) * time.Millisecond)
 
 		// Lock either this single bucket or all buckets
 		if global != "" {
@@ -168,7 +170,7 @@ func (b *Bucket) Release(headers http.Header) error {
 			return err
 		}
 
-		unix, err := strconv.ParseInt(reset, 10, 64)
+		unix, err := strconv.ParseFloat(reset, 64)
 		if err != nil {
 			return err
 		}
@@ -177,7 +179,8 @@ func (b *Bucket) Release(headers http.Header) error {
 		// some extra time is added because without it i still encountered 429's.
 		// The added amount is the lowest amount that gave no 429's
 		// in 1k requests
-		delta := time.Unix(unix, 0).Sub(discordTime) + time.Millisecond*250
+		whole, frac := math.Modf(unix)
+		delta := time.Unix(int64(whole), 0).Add(time.Duration(frac*1000)*time.Millisecond).Sub(discordTime) + time.Millisecond*250
 		b.reset = time.Now().Add(delta)
 	}
 

+ 3 - 2
ratelimit_test.go

@@ -1,6 +1,7 @@
 package discordgo
 
 import (
+	"fmt"
 	"net/http"
 	"strconv"
 	"testing"
@@ -18,7 +19,7 @@ func TestRatelimitReset(t *testing.T) {
 
 		headers.Set("X-RateLimit-Remaining", "0")
 		// Reset for approx 2 seconds from now
-		headers.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Second*2).Unix(), 10))
+		headers.Set("X-RateLimit-Reset", fmt.Sprint(float64(time.Now().Add(time.Second*2).UnixNano())/1e6))
 		headers.Set("Date", time.Now().Format(time.RFC850))
 
 		err := bucket.Release(headers)
@@ -105,7 +106,7 @@ func sendBenchReq(endpoint string, rl *RateLimiter) {
 	headers := http.Header(make(map[string][]string))
 
 	headers.Set("X-RateLimit-Remaining", "10")
-	headers.Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Unix(), 10))
+	headers.Set("X-RateLimit-Reset", fmt.Sprint(float64(time.Now().UnixNano())/1e6))
 	headers.Set("Date", time.Now().Format(time.RFC850))
 
 	bucket.Release(headers)

+ 18 - 0
structs.go

@@ -845,6 +845,24 @@ type TooManyRequests struct {
 	RetryAfter time.Duration `json:"retry_after"`
 }
 
+func (t *TooManyRequests) UnmarshalJSON(b []byte) error {
+	u := struct {
+		Bucket     string  `json:"bucket"`
+		Message    string  `json:"message"`
+		RetryAfter float64 `json:"retry_after"`
+	}{}
+	err := json.Unmarshal(b, &u)
+	if err != nil {
+		return err
+	}
+
+	t.Bucket = u.Bucket
+	t.Message = u.Message
+	whole, frac := math.Modf(u.RetryAfter)
+	t.RetryAfter = time.Duration(whole)*time.Second + time.Duration(frac*1000)*time.Millisecond
+	return nil
+}
+
 // A ReadState stores data on the read state of channels.
 type ReadState struct {
 	MentionCount  int    `json:"mention_count"`