Przeglądaj źródła

Message tracking.

Chris Rhodes 9 lat temu
rodzic
commit
6b73b588ba
2 zmienionych plików z 46 dodań i 23 usunięć
  1. 36 16
      state.go
  2. 10 7
      wsapi.go

+ 36 - 16
state.go

@@ -50,6 +50,11 @@ func (s *State) GuildAdd(guild *Guild) error {
 	// If the guild exists, replace it.
 	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
+			// Don't stomp on properties that don't come in updates.
+			guild.Members = g.Members
+			guild.Presences = g.Presences
+			guild.Channels = g.Channels
+			guild.VoiceStates = g.VoiceStates
 			s.Guilds[i] = guild
 			return nil
 		}
@@ -105,14 +110,15 @@ func (s *State) MemberAdd(member *Member) error {
 	if s == nil {
 		return nilError
 	}
-	s.Lock()
-	defer s.Unlock()
 
 	guild, err := s.Guild(member.GuildID)
 	if err != nil {
 		return err
 	}
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, m := range guild.Members {
 		if m.User.ID == member.User.ID {
 			guild.Members[i] = member
@@ -129,14 +135,15 @@ func (s *State) MemberRemove(member *Member) error {
 	if s == nil {
 		return nilError
 	}
-	s.Lock()
-	defer s.Unlock()
 
 	guild, err := s.Guild(member.GuildID)
 	if err != nil {
 		return err
 	}
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, m := range guild.Members {
 		if m.User.ID == member.User.ID {
 			guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
@@ -152,14 +159,15 @@ func (s *State) Member(guildID, userID string) (*Member, error) {
 	if s == nil {
 		return nil, nilError
 	}
-	s.RLock()
-	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
 		return nil, err
 	}
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, m := range guild.Members {
 		if m.User.ID == userID {
 			return m, nil
@@ -177,13 +185,16 @@ func (s *State) ChannelAdd(channel *Channel) error {
 	if s == nil {
 		return nilError
 	}
-	s.Lock()
-	defer s.Unlock()
 
 	if channel.IsPrivate {
+		s.Lock()
+		defer s.Unlock()
+
 		// If the channel exists, replace it.
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
+				// Don't stomp on messages.
+				channel.Messages = c.Messages
 				s.PrivateChannels[i] = channel
 				return nil
 			}
@@ -196,9 +207,14 @@ func (s *State) ChannelAdd(channel *Channel) error {
 			return err
 		}
 
+		s.Lock()
+		defer s.Unlock()
+
 		// If the channel exists, replace it.
 		for i, c := range guild.Channels {
 			if c.ID == channel.ID {
+				// Don't stomp on messages.
+				channel.Messages = c.Messages
 				guild.Channels[i] = channel
 				return nil
 			}
@@ -215,8 +231,6 @@ func (s *State) ChannelRemove(channel *Channel) error {
 	if s == nil {
 		return nilError
 	}
-	s.Lock()
-	defer s.Unlock()
 
 	if channel.IsPrivate {
 		for i, c := range s.PrivateChannels {
@@ -231,6 +245,9 @@ func (s *State) ChannelRemove(channel *Channel) error {
 			return err
 		}
 
+		s.Lock()
+		defer s.Unlock()
+
 		for i, c := range guild.Channels {
 			if c.ID == channel.ID {
 				guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
@@ -247,14 +264,15 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
 	if s == nil {
 		return nil, nilError
 	}
-	s.RLock()
-	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
 		return nil, err
 	}
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, c := range guild.Channels {
 		if c.ID == channelID {
 			return c, nil
@@ -307,14 +325,15 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
 	if s == nil {
 		return nil, nilError
 	}
-	s.RLock()
-	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
 		return nil, err
 	}
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, e := range guild.Emojis {
 		if e.ID == emojiID {
 			return e, nil
@@ -329,14 +348,15 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
 	if s == nil {
 		return nilError
 	}
-	s.Lock()
-	defer s.Unlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
 		return err
 	}
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, e := range guild.Emojis {
 		if e.ID == emoji.ID {
 			guild.Emojis[i] = emoji

+ 10 - 7
wsapi.go

@@ -251,13 +251,14 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 			}
 		*/
 	case "MESSAGE_CREATE":
-		if !s.StateEnabled && s.OnMessageCreate == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageCreate == nil {
 			break
 		}
 		var st *Message
 		if err = unmarshalEvent(e, &st); err == nil {
-			if s.StateEnabled {
-				fmt.Println(s.State.MessageAdd(st))
+			if stateEnabled {
+				s.State.MessageAdd(st)
 			}
 			if s.OnMessageCreate != nil {
 				s.OnMessageCreate(s, st)
@@ -267,12 +268,13 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 			return
 		}
 	case "MESSAGE_UPDATE":
-		if !s.StateEnabled && s.OnMessageUpdate == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageUpdate == nil {
 			break
 		}
 		var st *Message
 		if err = unmarshalEvent(e, &st); err == nil {
-			if s.StateEnabled {
+			if stateEnabled {
 				s.State.MessageAdd(st)
 			}
 			if s.OnMessageUpdate != nil {
@@ -284,12 +286,13 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 			return
 		}
 	case "MESSAGE_DELETE":
-		if !s.StateEnabled && s.OnMessageDelete == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageDelete == nil {
 			break
 		}
 		var st *Message
 		if err = unmarshalEvent(e, &st); err == nil {
-			if s.StateEnabled {
+			if stateEnabled {
 				s.State.MessageRemove(st)
 			}
 			if s.OnMessageDelete != nil {