ratelimit.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package discordgo
import (
	"math"
	"net/http"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
)
  9. // RateLimiter holds all ratelimit buckets
  10. type RateLimiter struct {
  11. sync.Mutex
  12. global *int64
  13. buckets map[string]*Bucket
  14. globalRateLimit time.Duration
  15. }
  16. // NewRatelimiter returns a new RateLimiter
  17. func NewRatelimiter() *RateLimiter {
  18. return &RateLimiter{
  19. buckets: make(map[string]*Bucket),
  20. global: new(int64),
  21. }
  22. }
  23. // getBucket retrieves or creates a bucket
  24. func (r *RateLimiter) getBucket(key string) *Bucket {
  25. r.Lock()
  26. defer r.Unlock()
  27. if bucket, ok := r.buckets[key]; ok {
  28. return bucket
  29. }
  30. b := &Bucket{
  31. remaining: 1,
  32. Key: key,
  33. global: r.global,
  34. }
  35. r.buckets[key] = b
  36. return b
  37. }
  38. // LockBucket Locks until a request can be made
  39. func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
  40. b := r.getBucket(bucketID)
  41. b.Lock()
  42. // If we ran out of calls and the reset time is still ahead of us
  43. // then we need to take it easy and relax a little
  44. if b.remaining < 1 && b.reset.After(time.Now()) {
  45. time.Sleep(b.reset.Sub(time.Now()))
  46. }
  47. // Check for global ratelimits
  48. sleepTo := time.Unix(0, atomic.LoadInt64(r.global))
  49. if now := time.Now(); now.Before(sleepTo) {
  50. time.Sleep(sleepTo.Sub(now))
  51. }
  52. b.remaining--
  53. return b
  54. }
  55. // Bucket represents a ratelimit bucket, each bucket gets ratelimited individually (-global ratelimits)
  56. type Bucket struct {
  57. sync.Mutex
  58. Key string
  59. remaining int
  60. limit int
  61. reset time.Time
  62. global *int64
  63. }
  64. // Release unlocks the bucket and reads the headers to update the buckets ratelimit info
  65. // and locks up the whole thing in case if there's a global ratelimit.
  66. func (b *Bucket) Release(headers http.Header) error {
  67. defer b.Unlock()
  68. if headers == nil {
  69. return nil
  70. }
  71. remaining := headers.Get("X-RateLimit-Remaining")
  72. reset := headers.Get("X-RateLimit-Reset")
  73. global := headers.Get("X-RateLimit-Global")
  74. retryAfter := headers.Get("Retry-After")
  75. // Update global and per bucket reset time if the proper headers are available
  76. // If global is set, then it will block all buckets until after Retry-After
  77. // If Retry-After without global is provided it will use that for the new reset
  78. // time since it's more accurate than X-RateLimit-Reset.
  79. // If Retry-After after is not proided, it will update the reset time from X-RateLimit-Reset
  80. if retryAfter != "" {
  81. parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
  82. if err != nil {
  83. return err
  84. }
  85. resetAt := time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
  86. // Lock either this single bucket or all buckets
  87. if global != "" {
  88. atomic.StoreInt64(b.global, resetAt.UnixNano())
  89. } else {
  90. b.reset = resetAt
  91. }
  92. } else if reset != "" {
  93. // Calculate the reset time by using the date header returned from discord
  94. discordTime, err := http.ParseTime(headers.Get("Date"))
  95. if err != nil {
  96. return err
  97. }
  98. unix, err := strconv.ParseInt(reset, 10, 64)
  99. if err != nil {
  100. return err
  101. }
  102. // Calculate the time until reset and add it to the current local time
  103. // some extra time is added because without it i still encountered 429's.
  104. // The added amount is the lowest amount that gave no 429's
  105. // in 1k requests
  106. delta := time.Unix(unix, 0).Sub(discordTime) + time.Millisecond*250
  107. b.reset = time.Now().Add(delta)
  108. }
  109. // Udpate remaining if header is present
  110. if remaining != "" {
  111. parsedRemaining, err := strconv.ParseInt(remaining, 10, 32)
  112. if err != nil {
  113. return err
  114. }
  115. b.remaining = int(parsedRemaining)
  116. }
  117. return nil
  118. }