Procházet zdrojové kódy

Merge pull request #163 from iopred/develop

Add fast lookups for guilds and channels in state.
Bruce před 8 roky
rodič
revize
38103b6061
2 změnil soubory, kde provedl 63 přidání a 86 odebrání
  1. 60 86
      state.go
  2. 3 0
      structs.go

+ 60 - 86
state.go

@@ -24,6 +24,8 @@ func NewState() *State {
 			PrivateChannels: []*Channel{},
 			Guilds:          []*Guild{},
 		},
+		guildMap:   make(map[string]*Guild),
+		channelMap: make(map[string]*Channel),
 	}
 }
 
@@ -42,6 +44,8 @@ func (s *State) OnReady(r *Ready) error {
 		for _, c := range g.Channels {
 			c.GuildID = g.ID
 		}
+
+		s.guildMap[g.ID] = g
 	}
 
 	return nil
@@ -54,30 +58,34 @@ func (s *State) GuildAdd(guild *Guild) error {
 		return ErrNilState
 	}
 
-	s.Lock()
-	defer s.Unlock()
-
-	// Otherwise, update the channels to point to the right guild
+	// Update the channels to point to the right guild
 	for _, c := range guild.Channels {
 		c.GuildID = guild.ID
 	}
 
 	// If the guild exists, replace it.
-	for i, g := range s.Guilds {
-		if g.ID == guild.ID {
-			// If this guild already exists with data, don't stomp on props
-			if g.Unavailable != nil && !*g.Unavailable {
-				guild.Members = g.Members
-				guild.Presences = g.Presences
-				guild.Channels = g.Channels
-				guild.VoiceStates = g.VoiceStates
-			}
-			s.Guilds[i] = guild
-			return nil
+	if g, err := s.Guild(guild.ID); err == nil {
+		s.Lock()
+		defer s.Unlock()
+
+		// If this guild already exists with data, don't stomp on props.
+		if g.Unavailable != nil && !*g.Unavailable {
+			guild.Members = g.Members
+			guild.Presences = g.Presences
+			guild.Channels = g.Channels
+			guild.VoiceStates = g.VoiceStates
 		}
+
+		*g = *guild
+		return nil
 	}
 
+	s.Lock()
+	defer s.Unlock()
+
 	s.Guilds = append(s.Guilds, guild)
+	s.guildMap[guild.ID] = guild
+
 	return nil
 }
 
@@ -87,6 +95,11 @@ func (s *State) GuildRemove(guild *Guild) error {
 		return ErrNilState
 	}
 
+	_, err := s.Guild(guild.ID)
+	if err != nil {
+		return err
+	}
+
 	s.Lock()
 	defer s.Unlock()
 
@@ -97,7 +110,9 @@ func (s *State) GuildRemove(guild *Guild) error {
 		}
 	}
 
-	return errors.New("Guild not found.")
+	delete(s.guildMap, guild.ID)
+
+	return nil
 }
 
 // Guild gets a guild by ID.
@@ -112,10 +127,8 @@ func (s *State) Guild(guildID string) (*Guild, error) {
 	s.RLock()
 	defer s.RUnlock()
 
-	for _, g := range s.Guilds {
-		if g.ID == guildID {
-			return g, nil
-		}
+	if g, ok := s.guildMap[guildID]; ok {
+		return g, nil
 	}
 
 	return nil, errors.New("Guild not found.")
@@ -205,20 +218,22 @@ func (s *State) ChannelAdd(channel *Channel) error {
 		return ErrNilState
 	}
 
-	if channel.IsPrivate {
+	// If the channel exists, replace it.
+	if c, err := s.Channel(channel.ID); err == nil {
 		s.Lock()
 		defer s.Unlock()
 
-		// If the channel exists, replace it.
-		for i, c := range s.PrivateChannels {
-			if c.ID == channel.ID {
-				// Don't stomp on messages.
-				channel.Messages = c.Messages
-				s.PrivateChannels[i] = channel
-				return nil
-			}
-		}
+		channel.Messages = c.Messages
+		channel.PermissionOverwrites = c.PermissionOverwrites
 
+		*c = *channel
+		return nil
+	}
+
+	s.Lock()
+	defer s.Unlock()
+
+	if channel.IsPrivate {
 		s.PrivateChannels = append(s.PrivateChannels, channel)
 	} else {
 		guild, err := s.Guild(channel.GuildID)
@@ -226,22 +241,11 @@ func (s *State) ChannelAdd(channel *Channel) error {
 			return err
 		}
 
-		s.Lock()
-		defer s.Unlock()
-
-		// If the channel exists, replace it.
-		for i, c := range guild.Channels {
-			if c.ID == channel.ID {
-				// Don't stomp on messages.
-				channel.Messages = c.Messages
-				guild.Channels[i] = channel
-				return nil
-			}
-		}
-
 		guild.Channels = append(guild.Channels, channel)
 	}
 
+	s.channelMap[channel.ID] = channel
+
 	return nil
 }
 
@@ -251,6 +255,11 @@ func (s *State) ChannelRemove(channel *Channel) error {
 		return ErrNilState
 	}
 
+	_, err := s.Channel(channel.ID)
+	if err != nil {
+		return err
+	}
+
 	if channel.IsPrivate {
 		s.Lock()
 		defer s.Unlock()
@@ -278,48 +287,21 @@ func (s *State) ChannelRemove(channel *Channel) error {
 		}
 	}
 
-	return errors.New("Channel not found.")
+	delete(s.channelMap, channel.ID)
+
+	return nil
 }
 
 // GuildChannel gets a channel by ID from a guild.
+// This method is Deprecated, use Channel(channelID)
 func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
-	if s == nil {
-		return nil, ErrNilState
-	}
-
-	guild, err := s.Guild(guildID)
-	if err != nil {
-		return nil, err
-	}
-
-	s.RLock()
-	defer s.RUnlock()
-
-	for _, c := range guild.Channels {
-		if c.ID == channelID {
-			return c, nil
-		}
-	}
-
-	return nil, errors.New("Channel not found.")
+	return s.Channel(channelID)
 }
 
 // PrivateChannel gets a private channel by ID.
+// This method is Deprecated, use Channel(channelID)
 func (s *State) PrivateChannel(channelID string) (*Channel, error) {
-	if s == nil {
-		return nil, ErrNilState
-	}
-
-	s.RLock()
-	defer s.RUnlock()
-
-	for _, c := range s.PrivateChannels {
-		if c.ID == channelID {
-			return c, nil
-		}
-	}
-
-	return nil, errors.New("Channel not found.")
+	return s.Channel(channelID)
 }
 
 // Channel gets a channel by ID, it will look in all guilds an private channels.
@@ -328,18 +310,10 @@ func (s *State) Channel(channelID string) (*Channel, error) {
 		return nil, ErrNilState
 	}
 
-	c, err := s.PrivateChannel(channelID)
-	if err == nil {
+	if c, ok := s.channelMap[channelID]; ok {
 		return c, nil
 	}
 
-	for _, g := range s.Guilds {
-		c, err := s.GuildChannel(g.ID, channelID)
-		if err == nil {
-			return c, nil
-		}
-	}
-
 	return nil, errors.New("Channel not found.")
 }
 

+ 3 - 0
structs.go

@@ -369,6 +369,9 @@ type State struct {
 	sync.RWMutex
 	Ready
 	MaxMessageCount int
+
+	guildMap   map[string]*Guild
+	channelMap map[string]*Channel
 }
 
 // Constants for the different bit offsets of text channel permissions