123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- package discordgo
- import (
- "net/http"
- "strconv"
- "sync"
- "time"
- )
- type RateLimiter struct {
- sync.Mutex
- global *Bucket
- buckets map[string]*Bucket
- globalRateLimit time.Duration
- }
- func NewRatelimiter() *RateLimiter {
- return &RateLimiter{
- buckets: make(map[string]*Bucket),
- global: &Bucket{Key: "global"},
- }
- }
- func (r *RateLimiter) getBucket(key string) *Bucket {
- r.Lock()
- defer r.Unlock()
- if bucket, ok := r.buckets[key]; ok {
- return bucket
- }
- b := &Bucket{
- remaining: 1,
- Key: key,
- global: r.global,
- }
- r.buckets[key] = b
- return b
- }
- func (r *RateLimiter) LockBucket(bucketID string) *Bucket {
- b := r.getBucket(bucketID)
- b.Lock()
-
-
- if b.remaining < 1 && b.reset.After(time.Now()) {
- time.Sleep(b.reset.Sub(time.Now()))
- }
-
- r.global.Lock()
- r.global.Unlock()
- b.remaining--
- return b
- }
// Bucket tracks the rate limit state of a single key.
//
// The embedded mutex is unlocked by Release, so a Bucket must be locked
// before Release is called on it.
type Bucket struct {
	sync.Mutex

	// Key is the identifier the bucket was created for.
	Key string

	// remaining is the number of requests left, as last reported by the
	// X-RateLimit-Remaining header.
	remaining int

	// limit is not written by Release.
	// NOTE(review): confirm whether it is used elsewhere.
	limit int

	// reset is the local time at which the bucket is expected to reset.
	reset time.Time

	// global is the shared bucket whose mutex is held while a global rate
	// limit is in effect. It may be nil for buckets constructed directly.
	global *Bucket
}

// Release unlocks the bucket and updates its state from the response headers.
//
// A nil headers value only unlocks. When X-RateLimit-Global is set, the
// Retry-After value (interpreted as milliseconds) is used to hold the global
// bucket's mutex in the background and the bucket's own state is left
// untouched. Otherwise the reset time is taken from Retry-After
// (milliseconds) or, failing that, from X-RateLimit-Reset relative to the Date
// header, and remaining is taken from X-RateLimit-Remaining.
//
// X-RateLimit-Reset may be a whole or fractional number of Unix seconds.
//
// An error is returned when a header that is consulted cannot be parsed; the
// bucket is unlocked in every case.
func (b *Bucket) Release(headers http.Header) error {
	defer b.Unlock()
	if headers == nil {
		return nil
	}

	remaining := headers.Get("X-RateLimit-Remaining")
	reset := headers.Get("X-RateLimit-Reset")
	global := headers.Get("X-RateLimit-Global")
	retryAfter := headers.Get("Retry-After")

	if global != "" {
		parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
		if err != nil {
			return err
		}
		b.holdGlobal(time.Duration(parsedAfter) * time.Millisecond)
		return nil
	}

	switch {
	case retryAfter != "":
		parsedAfter, err := strconv.ParseInt(retryAfter, 10, 64)
		if err != nil {
			return err
		}
		b.reset = time.Now().Add(time.Duration(parsedAfter) * time.Millisecond)
	case reset != "":
		// Measure the time until reset against the server's own clock (the
		// Date header) so local clock skew does not shift the result.
		discordTime, err := http.ParseTime(headers.Get("Date"))
		if err != nil {
			return err
		}
		resetAt, err := parseResetTime(reset)
		if err != nil {
			return err
		}
		// The extra 250ms is presumably a safety margin against releasing
		// slightly too early. NOTE(review): confirm the rationale.
		delta := resetAt.Sub(discordTime) + 250*time.Millisecond
		b.reset = time.Now().Add(delta)
	}

	if remaining != "" {
		parsedRemaining, err := strconv.ParseInt(remaining, 10, 32)
		if err != nil {
			return err
		}
		b.remaining = int(parsedRemaining)
	}
	return nil
}

// holdGlobal keeps the global bucket's mutex locked until wait has elapsed,
// measured from the moment of the call. The work runs in a goroutine that
// ends once the wait is over. It does nothing when the bucket has no global
// bucket.
func (b *Bucket) holdGlobal(wait time.Duration) {
	shared := b.global
	if shared == nil {
		return
	}

	sleepTo := time.Now().Add(wait)
	go func() {
		shared.Lock()
		defer shared.Unlock()
		// Acquiring the lock may itself have taken time, so sleep only for
		// whatever is left of the wait.
		if left := time.Until(sleepTo); left > 0 {
			time.Sleep(left)
		}
	}()
}

// parseResetTime converts an X-RateLimit-Reset value, given as whole or
// fractional Unix seconds, into a time.Time.
func parseResetTime(reset string) (time.Time, error) {
	// Whole seconds are parsed exactly, without a float round trip.
	if whole, err := strconv.ParseInt(reset, 10, 64); err == nil {
		return time.Unix(whole, 0), nil
	}

	secs, err := strconv.ParseFloat(reset, 64)
	if err != nil {
		return time.Time{}, err
	}

	// NaN fails both comparisons, so this rejects NaN, the infinities and any
	// magnitude that cannot be converted to int64 seconds.
	const bound = 1 << 62
	if !(secs > -bound && secs < bound) {
		return time.Time{}, &strconv.NumError{Func: "ParseFloat", Num: reset, Err: strconv.ErrRange}
	}

	whole := int64(secs)
	frac := secs - float64(whole)
	return time.Unix(whole, int64(frac*float64(time.Second))), nil
}
|