ratelimit.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package discordgo
  2. import (
  3. "net/http"
  4. "strconv"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. )
  10. // customRateLimit holds information for defining a custom rate limit
  11. type customRateLimit struct {
  12. suffix string
  13. requests int
  14. reset time.Duration
  15. }
  16. // RateLimiter holds all ratelimit buckets
  17. type RateLimiter struct {
  18. sync.Mutex
  19. global *int64
  20. buckets map[string]*Bucket
  21. globalRateLimit time.Duration
  22. customRateLimits []*customRateLimit
  23. }
  24. // NewRatelimiter returns a new RateLimiter
  25. func NewRatelimiter() *RateLimiter {
  26. return &RateLimiter{
  27. buckets: make(map[string]*Bucket),
  28. global: new(int64),
  29. customRateLimits: []*customRateLimit{
  30. &customRateLimit{
  31. suffix: "//reactions//",
  32. requests: 1,
  33. reset: 200 * time.Millisecond,
  34. },
  35. },
  36. }
  37. }
  38. // GetBucket retrieves or creates a bucket
  39. func (r *RateLimiter) GetBucket(key string) *Bucket {
  40. r.Lock()
  41. defer r.Unlock()
  42. if bucket, ok := r.buckets[key]; ok {
  43. return bucket
  44. }
  45. b := &Bucket{
  46. Remaining: 1,
  47. Key: key,
  48. global: r.global,
  49. }
  50. // Check if there is a custom ratelimit set for this bucket ID.
  51. for _, rl := range r.customRateLimits {
  52. if strings.HasSuffix(b.Key, rl.suffix) {
  53. b.customRateLimit = rl
  54. break
  55. }
  56. }
  57. r.buckets[key] = b
  58. return b
  59. }
  60. // GetWaitTime returns the duration you should wait for a Bucket
  61. func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration {
  62. // If we ran out of calls and the reset time is still ahead of us
  63. // then we need to take it easy and relax a little
  64. if b.Remaining < minRemaining && b.reset.After(time.Now()) {
  65. return b.reset.Sub(time.Now())
  66. }
  67. // Check for global ratelimits
  68. sleepTo := time.Unix(0, atomic.LoadInt64(r.global))
  69. if now := time.Now(); now.Before(sleepTo) {
  70. return sleepTo.Sub(now)
  71. }
  72. return 0
  73. }
  74. // LockBucket Locks until a request can be made
  75. func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
  76. return r.LockBucketObject(r.GetBucket(bucketID))
  77. }
  78. // LockBucketObject Locks an already resolved bucket until a request can be made
  79. func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket {
  80. b.Lock()
  81. if wait := r.GetWaitTime(b, 1); wait > 0 {
  82. time.Sleep(wait)
  83. }
  84. b.Remaining--
  85. return b
  86. }
  87. // Bucket represents a ratelimit bucket, each bucket gets ratelimited individually (-global ratelimits)
  88. type Bucket struct {
  89. sync.Mutex
  90. Key string
  91. Remaining int
  92. limit int
  93. reset time.Time
  94. global *int64
  95. lastReset time.Time
  96. customRateLimit *customRateLimit
  97. Userdata interface{}
  98. }
  99. // Release unlocks the bucket and reads the headers to update the buckets ratelimit info
  100. // and locks up the whole thing in case if there's a global ratelimit.
  101. func (b *Bucket) Release(headers http.Header) error {
  102. defer b.Unlock()
  103. // Check if the bucket uses a custom ratelimiter
  104. if rl := b.customRateLimit; rl != nil {
  105. if time.Now().Sub(b.lastReset) >= rl.reset {
  106. b.Remaining = rl.requests - 1
  107. b.lastReset = time.Now()
  108. }
  109. if b.Remaining < 1 {
  110. b.reset = time.Now().Add(rl.reset)
  111. }
  112. return nil
  113. }
  114. if headers == nil {
  115. return nil
  116. }
  117. remaining := headers.Get("X-RateLimit-Remaining")
  118. reset := headers.Get("X-RateLimit-Reset")
  119. global := headers.Get("X-RateLimit-Global")
  120. retryAfter := headers.Get("Retry-After")
  121. // Update global and per bucket reset time if the proper headers are available
  122. // If global is set, then it will block all buckets until after Retry-After
  123. // If Retry-After without global is provided it will use that for the new reset
  124. // time since it's more accurate than X-RateLimit-Reset.
  125. // If Retry-After after is not proided, it will update the reset time from X-RateLimit-Reset
  126. if retryAfter != "" {
  127. parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
  128. if err != nil {
  129. return err
  130. }
  131. resetAt := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
  132. // Lock either this single bucket or all buckets
  133. if global != "" {
  134. atomic.StoreInt64(b.global, resetAt.UnixNano())
  135. } else {
  136. b.reset = resetAt
  137. }
  138. } else if reset != "" {
  139. // Calculate the reset time by using the date header returned from discord
  140. discordTime, err := http.ParseTime(headers.Get("Date"))
  141. if err != nil {
  142. return err
  143. }
  144. unix, err := strconv.ParseInt(reset, 10, 64)
  145. if err != nil {
  146. return err
  147. }
  148. // Calculate the time until reset and add it to the current local time
  149. // some extra time is added because without it i still encountered 429's.
  150. // The added amount is the lowest amount that gave no 429's
  151. // in 1k requests
  152. delta := time.Unix(unix, 0).Sub(discordTime) + time.Millisecond*250
  153. b.reset = time.Now().Add(delta)
  154. }
  155. // Udpate remaining if header is present
  156. if remaining != "" {
  157. parsedRemaining, err := strconv.ParseInt(remaining, 10, 32)
  158. if err != nil {
  159. return err
  160. }
  161. b.Remaining = int(parsedRemaining)
  162. }
  163. return nil
  164. }