ratelimit.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. package discordgo
  2. import (
  3. "math"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  // customRateLimit is a client-side limit for buckets whose key ends with
// suffix: each time its reset window restarts, Bucket.Release grants the
// bucket requests calls before it has to wait for the window again.
type customRateLimit struct {
	suffix   string        // matched with strings.HasSuffix against bucket keys
	requests int           // calls permitted per window
	reset    time.Duration // window length
}
  17. // RateLimiter holds all ratelimit buckets
  18. type RateLimiter struct {
  19. sync.Mutex
  20. global *int64
  21. buckets map[string]*Bucket
  22. globalRateLimit time.Duration
  23. customRateLimits []*customRateLimit
  24. }
  25. // NewRatelimiter returns a new RateLimiter
  26. func NewRatelimiter() *RateLimiter {
  27. return &RateLimiter{
  28. buckets: make(map[string]*Bucket),
  29. global: new(int64),
  30. customRateLimits: []*customRateLimit{
  31. &customRateLimit{
  32. suffix: "//reactions//",
  33. requests: 1,
  34. reset: 200 * time.Millisecond,
  35. },
  36. },
  37. }
  38. }
  39. // GetBucket retrieves or creates a bucket
  40. func (r *RateLimiter) GetBucket(key string) *Bucket {
  41. r.Lock()
  42. defer r.Unlock()
  43. if bucket, ok := r.buckets[key]; ok {
  44. return bucket
  45. }
  46. b := &Bucket{
  47. Remaining: 1,
  48. Key: key,
  49. global: r.global,
  50. }
  51. // Check if there is a custom ratelimit set for this bucket ID.
  52. for _, rl := range r.customRateLimits {
  53. if strings.HasSuffix(b.Key, rl.suffix) {
  54. b.customRateLimit = rl
  55. break
  56. }
  57. }
  58. r.buckets[key] = b
  59. return b
  60. }
  61. // GetWaitTime returns the duration you should wait for a Bucket
  62. func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration {
  63. // If we ran out of calls and the reset time is still ahead of us
  64. // then we need to take it easy and relax a little
  65. if b.Remaining < minRemaining && b.reset.After(time.Now()) {
  66. return b.reset.Sub(time.Now())
  67. }
  68. // Check for global ratelimits
  69. sleepTo := time.Unix(0, atomic.LoadInt64(r.global))
  70. if now := time.Now(); now.Before(sleepTo) {
  71. return sleepTo.Sub(now)
  72. }
  73. return 0
  74. }
  75. // LockBucket Locks until a request can be made
  76. func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
  77. return r.LockBucketObject(r.GetBucket(bucketID))
  78. }
  79. // LockBucketObject Locks an already resolved bucket until a request can be made
  80. func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket {
  81. b.Lock()
  82. if wait := r.GetWaitTime(b, 1); wait > 0 {
  83. time.Sleep(wait)
  84. }
  85. b.Remaining--
  86. return b
  87. }
  88. // Bucket represents a ratelimit bucket, each bucket gets ratelimited individually (-global ratelimits)
  89. type Bucket struct {
  90. sync.Mutex
  91. Key string
  92. Remaining int
  93. limit int
  94. reset time.Time
  95. global *int64
  96. lastReset time.Time
  97. customRateLimit *customRateLimit
  98. Userdata interface{}
  99. }
  100. // Release unlocks the bucket and reads the headers to update the buckets ratelimit info
  101. // and locks up the whole thing in case if there's a global ratelimit.
  102. func (b *Bucket) Release(headers http.Header) error {
  103. defer b.Unlock()
  104. // Check if the bucket uses a custom ratelimiter
  105. if rl := b.customRateLimit; rl != nil {
  106. if time.Now().Sub(b.lastReset) >= rl.reset {
  107. b.Remaining = rl.requests - 1
  108. b.lastReset = time.Now()
  109. }
  110. if b.Remaining < 1 {
  111. b.reset = time.Now().Add(rl.reset)
  112. }
  113. return nil
  114. }
  115. if headers == nil {
  116. return nil
  117. }
  118. remaining := headers.Get("X-RateLimit-Remaining")
  119. reset := headers.Get("X-RateLimit-Reset")
  120. global := headers.Get("X-RateLimit-Global")
  121. resetAfter := headers.Get("X-RateLimit-Reset-After")
  122. // Update global and per bucket reset time if the proper headers are available
  123. // If global is set, then it will block all buckets until after Retry-After
  124. // If Retry-After without global is provided it will use that for the new reset
  125. // time since it's more accurate than X-RateLimit-Reset.
  126. // If Retry-After after is not proided, it will update the reset time from X-RateLimit-Reset
  127. if resetAfter != "" {
  128. parsedAfter, err := strconv.ParseFloat(resetAfter, 64)
  129. if err != nil {
  130. return err
  131. }
  132. whole, frac := math.Modf(parsedAfter)
  133. resetAt := time.Now().Add(time.Duration(whole) * time.Second).Add(time.Duration(frac*1000) * time.Millisecond)
  134. // Lock either this single bucket or all buckets
  135. if global != "" {
  136. atomic.StoreInt64(b.global, resetAt.UnixNano())
  137. } else {
  138. b.reset = resetAt
  139. }
  140. } else if reset != "" {
  141. // Calculate the reset time by using the date header returned from discord
  142. discordTime, err := http.ParseTime(headers.Get("Date"))
  143. if err != nil {
  144. return err
  145. }
  146. unix, err := strconv.ParseFloat(reset, 64)
  147. if err != nil {
  148. return err
  149. }
  150. // Calculate the time until reset and add it to the current local time
  151. // some extra time is added because without it i still encountered 429's.
  152. // The added amount is the lowest amount that gave no 429's
  153. // in 1k requests
  154. whole, frac := math.Modf(unix)
  155. delta := time.Unix(int64(whole), 0).Add(time.Duration(frac*1000)*time.Millisecond).Sub(discordTime) + time.Millisecond*250
  156. b.reset = time.Now().Add(delta)
  157. }
  158. // Udpate remaining if header is present
  159. if remaining != "" {
  160. parsedRemaining, err := strconv.ParseInt(remaining, 10, 32)
  161. if err != nil {
  162. return err
  163. }
  164. b.Remaining = int(parsedRemaining)
  165. }
  166. return nil
  167. }