ratelimit.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package discordgo
  2. import (
  3. "net/http"
  4. "strconv"
  5. "sync"
  6. "time"
  7. )
  8. // RateLimiter holds all ratelimit buckets
  9. type RateLimiter struct {
  10. sync.Mutex
  11. global *Bucket
  12. buckets map[string]*Bucket
  13. globalRateLimit time.Duration
  14. }
  15. // NewRatelimiter returns a new RateLimiter
  16. func NewRatelimiter() *RateLimiter {
  17. return &RateLimiter{
  18. buckets: make(map[string]*Bucket),
  19. global: &Bucket{Key: "global"},
  20. }
  21. }
  22. // getBucket retrieves or creates a bucket
  23. func (r *RateLimiter) getBucket(key string) *Bucket {
  24. r.Lock()
  25. defer r.Unlock()
  26. if bucket, ok := r.buckets[key]; ok {
  27. return bucket
  28. }
  29. b := &Bucket{
  30. remaining: 1,
  31. Key: key,
  32. global: r.global,
  33. }
  34. r.buckets[key] = b
  35. return b
  36. }
  37. // LockBucket Locks until a request can be made
  38. func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
  39. b := r.getBucket(bucketID)
  40. b.Lock()
  41. // If we ran out of calls and the reset time is still ahead of us
  42. // then we need to take it easy and relax a little
  43. if b.remaining < 1 && b.reset.After(time.Now()) {
  44. time.Sleep(b.reset.Sub(time.Now()))
  45. }
  46. // Check for global ratelimits
  47. r.global.Lock()
  48. r.global.Unlock()
  49. b.remaining--
  50. return b
  51. }
  // Bucket represents a ratelimit bucket. Each bucket is ratelimited
// individually. A global ratelimit additionally applies across buckets through
// the shared global bucket.
//
// The embedded Mutex is held for the duration of a request. It is locked by
// RateLimiter.LockBucket and unlocked by Release.
type Bucket struct {
	sync.Mutex
	Key string

	remaining int       // requests left before reset; set from X-RateLimit-Remaining
	limit     int       // not set by Release or LockBucket
	reset     time.Time // local time at which the bucket's allowance resets
	global    *Bucket   // shared global bucket; nil when not linked (e.g. the RateLimiter's own global bucket)
}

// Release unlocks the bucket and updates its ratelimit state from the headers
// of the response to the request made while the bucket was held. A nil
// headers value records nothing; the bucket is simply unlocked.
//
// Suppose the response reports a global ratelimit (X-RateLimit-Global is set)
// and the bucket is linked to a global bucket. Then a background goroutine
// holds that global bucket's lock until Retry-After milliseconds have elapsed.
// During that time every RateLimiter.LockBucket call blocks. Release itself
// does not block.
//
// Otherwise the bucket's reset time is taken from Retry-After (milliseconds,
// preferred) or, failing that, from X-RateLimit-Reset measured against the
// response's Date header. The remaining count is taken from
// X-RateLimit-Remaining.
//
// The bucket is always unlocked, even when an error is returned. A non-nil
// error means a header failed to parse; headers processed after it are not
// applied.
func (b *Bucket) Release(headers http.Header) error {
	defer b.Unlock()
	if headers == nil {
		return nil
	}

	remaining := headers.Get("X-RateLimit-Remaining")
	reset := headers.Get("X-RateLimit-Reset")
	global := headers.Get("X-RateLimit-Global")
	retryAfter := headers.Get("Retry-After")

	// A global ratelimit is enforced by holding the shared global bucket's
	// lock. Dereferencing a nil b.global inside the goroutine would panic and
	// take down the whole process. Unlinked buckets therefore fall through
	// and apply Retry-After to themselves.
	if global != "" && b.global != nil {
		parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
		if err != nil {
			return err
		}
		// Fix the deadline before queueing on the lock. If several requests
		// hit the global limit at once, the queued goroutines then don't each
		// sleep the full Retry-After (n * Retry-After in total).
		sleepTo := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
		gb := b.global // read while b is still locked
		// Hold the global lock from a new goroutine so Release never blocks.
		go func() {
			gb.Lock()
			defer gb.Unlock()
			if d := sleepTo.Sub(time.Now()); d > 0 {
				time.Sleep(d)
			}
		}()
		return nil
	}

	// Update the reset time when either header is present. Prefer Retry-After:
	// it is a relative delay, so it doesn't depend on clock sync with Discord.
	switch {
	case retryAfter != "":
		parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
		if err != nil {
			return err
		}
		b.reset = time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
	case reset != "":
		// Measure the time until reset on Discord's clock (the Date header),
		// then apply that delay to the local clock.
		discordTime, err := http.ParseTime(headers.Get("Date"))
		if err != nil {
			return err
		}
		unix, err := strconv.ParseInt(reset, 10, 64)
		if err != nil {
			return err
		}
		// The extra 250ms is the smallest margin that produced no 429s over
		// 1k requests in the original author's testing.
		delta := time.Unix(unix, 0).Sub(discordTime) + 250*time.Millisecond
		b.reset = time.Now().Add(delta)
	}

	// Update remaining if the header is present.
	if remaining != "" {
		parsedRemaining, err := strconv.ParseInt(remaining, 10, 32)
		if err != nil {
			return err
		}
		b.remaining = int(parsedRemaining)
	}
	return nil
}