Browse Source

Message state tracking.

Chris Rhodes 9 years ago
parent
commit
0f38b22ca1
4 changed files with 158 additions and 27 deletions
  1. 3 0
      message.go
  2. 112 8
      state.go
  3. 8 10
      structs.go
  4. 35 9
      wsapi.go

+ 3 - 0
message.go

@@ -70,6 +70,9 @@ type Embed struct {
 // ContentWithMentionsReplaced will replace all @<id> mentions with the
 // username of the mention.
 func (m *Message) ContentWithMentionsReplaced() string {
+	if m.Mentions == nil {
+		return m.Content
+	}
 	content := m.Content
 	for _, user := range m.Mentions {
 		content = strings.Replace(content, fmt.Sprintf("<@%s>", user.ID),

+ 112 - 8
state.go

@@ -31,6 +31,8 @@ func (s *State) OnReady(r *Ready) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	s.Ready = *r
 	return nil
@@ -42,16 +44,13 @@ func (s *State) GuildAdd(guild *Guild) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
-	for _, g := range s.Guilds {
+	// If the guild exists, replace it.
+	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
-			// This could be a little faster ;)
-			for _, m := range guild.Members {
-				s.MemberAdd(m)
-			}
-			for _, c := range guild.Channels {
-				s.ChannelAdd(c)
-			}
+			s.Guilds[i] = guild
 			return nil
 		}
 	}
@@ -65,6 +64,8 @@ func (s *State) GuildRemove(guild *Guild) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
@@ -84,6 +85,8 @@ func (s *State) Guild(guildID string) (*Guild, error) {
 	if s == nil {
 		return nil, nilError
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 	for _, g := range s.Guilds {
 		if g.ID == guildID {
@@ -102,6 +105,8 @@ func (s *State) MemberAdd(member *Member) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	guild, err := s.Guild(member.GuildID)
 	if err != nil {
@@ -124,6 +129,8 @@ func (s *State) MemberRemove(member *Member) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	guild, err := s.Guild(member.GuildID)
 	if err != nil {
@@ -145,6 +152,8 @@ func (s *State) Member(guildID, userID string) (*Member, error) {
 	if s == nil {
 		return nil, nilError
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
@@ -168,8 +177,11 @@ func (s *State) ChannelAdd(channel *Channel) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	if channel.IsPrivate {
+		// If the channel exists, replace it.
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
 				s.PrivateChannels[i] = channel
@@ -184,6 +196,7 @@ func (s *State) ChannelAdd(channel *Channel) error {
 			return err
 		}
 
+		// If the channel exists, replace it.
 		for i, c := range guild.Channels {
 			if c.ID == channel.ID {
 				guild.Channels[i] = channel
@@ -202,6 +215,8 @@ func (s *State) ChannelRemove(channel *Channel) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	if channel.IsPrivate {
 		for i, c := range s.PrivateChannels {
@@ -232,6 +247,8 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
 	if s == nil {
 		return nil, nilError
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
@@ -252,6 +269,8 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) {
 	if s == nil {
 		return nil, nilError
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 	for _, c := range s.PrivateChannels {
 		if c.ID == channelID {
@@ -288,6 +307,8 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
 	if s == nil {
 		return nil, nilError
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
@@ -308,6 +329,8 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
 	if s == nil {
 		return nilError
 	}
+	s.Lock()
+	defer s.Unlock()
 
 	guild, err := s.Guild(guildID)
 	if err != nil {
@@ -334,3 +357,84 @@ func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
 	}
 	return nil
 }
+
+// MessageAdd adds a message to the current world state, or updates it if it exists.
+// If the channel cannot be found, the message is discarded.
+// Messages are kept in state up to s.MaxMessageCount
+func (s *State) MessageAdd(message *Message) error {
+	if s == nil {
+		return nilError
+	}
+
+	c, err := s.Channel(message.ChannelID)
+	if err != nil {
+		return err
+	}
+
+	s.Lock()
+	defer s.Unlock()
+
+	// If the message exists, replace it.
+	for i, m := range c.Messages {
+		if m.ID == message.ID {
+			c.Messages[i] = message
+			return nil
+		}
+	}
+
+	c.Messages = append(c.Messages, message)
+
+	if len(c.Messages) > s.MaxMessageCount {
+		s.Unlock()
+		for len(c.Messages) > s.MaxMessageCount {
+			s.MessageRemove(c.Messages[0])
+		}
+		s.Lock()
+	}
+	return nil
+}
+
+// MessageRemove removes a message from the world state.
+func (s *State) MessageRemove(message *Message) error {
+	if s == nil {
+		return nilError
+	}
+	c, err := s.Channel(message.ChannelID)
+	if err != nil {
+		return err
+	}
+
+	s.Lock()
+	defer s.Unlock()
+
+	for i, m := range c.Messages {
+		if m.ID == message.ID {
+			c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
+			return nil
+		}
+	}
+
+	return errors.New("Message not found.")
+}
+
+// Message gets a message by channel and message ID.
+func (s *State) Message(channelID, messageID string) (*Message, error) {
+	if s == nil {
+		return nil, nilError
+	}
+	c, err := s.Channel(channelID)
+	if err != nil {
+		return nil, err
+	}
+
+	s.RLock()
+	defer s.RUnlock()
+
+	for _, m := range c.Messages {
+		if m.ID == messageID {
+			return m, nil
+		}
+	}
+
+	return nil, errors.New("Message not found.")
+}

+ 8 - 10
structs.go

@@ -33,7 +33,7 @@ type Session struct {
 	OnTypingStart             func(*Session, *TypingStart)
 	OnMessageCreate           func(*Session, *Message)
 	OnMessageUpdate           func(*Session, *Message)
-	OnMessageDelete           func(*Session, *MessageDelete)
+	OnMessageDelete           func(*Session, *Message)
 	OnMessageAck              func(*Session, *MessageAck)
 	OnUserUpdate              func(*Session, *User)
 	OnPresenceUpdate          func(*Session, *PresenceUpdate)
@@ -46,7 +46,7 @@ type Session struct {
 	OnGuildDelete             func(*Session, *Guild)
 	OnGuildMemberAdd          func(*Session, *Member)
 	OnGuildMemberRemove       func(*Session, *Member)
-	OnGuildMemberDelete       func(*Session, *Member) // which is it?
+	OnGuildMemberDelete       func(*Session, *Member)
 	OnGuildMemberUpdate       func(*Session, *Member)
 	OnGuildRoleCreate         func(*Session, *GuildRole)
 	OnGuildRoleUpdate         func(*Session, *GuildRole)
@@ -77,8 +77,9 @@ type Session struct {
 	Voice *Voice // Stores all details related to voice connections
 
 	// Managed state object, updated with events.
-	State        *State
-	StateEnabled bool
+	State                *State
+	StateEnabled         bool
+	StateMaxMessageCount int
 
 	// Mutex/Bools for locks that prevent accidents.
 	// TODO: Add channels.
@@ -138,6 +139,7 @@ type Channel struct {
 	IsPrivate            bool                   `json:"is_private"`
 	LastMessageID        string                 `json:"last_message_id"`
 	Recipient            *User                  `json:"recipient"`
+	Messages             []*Message             `json:"-"`
 }
 
 // A PermissionOverwrite holds permission overwrite data for a Channel
@@ -309,12 +311,6 @@ type MessageAck struct {
 	ChannelID string `json:"channel_id"`
 }
 
-// A MessageDelete stores data for the message delete websocket event.
-type MessageDelete struct {
-	ID        string `json:"id"`
-	ChannelID string `json:"channel_id"`
-} // so much like MessageAck..
-
 // A GuildIntegrationsUpdate stores data for the guild integrations update
 // websocket event.
 type GuildIntegrationsUpdate struct {
@@ -349,5 +345,7 @@ type GuildEmojisUpdate struct {
 // As discord sends this in a READY blob, it seems reasonable to simply
 // use that struct as the data store.
 type State struct {
+	sync.RWMutex
 	Ready
+	MaxMessageCount int
 }

+ 35 - 9
wsapi.go

@@ -251,27 +251,53 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 			}
 		*/
 	case "MESSAGE_CREATE":
-		if s.OnMessageCreate != nil {
-			var st *Message
-			if err = unmarshalEvent(e, &st); err == nil {
+		if !s.StateEnabled && s.OnMessageCreate == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if s.StateEnabled {
+				fmt.Println(s.State.MessageAdd(st))
+			}
+			if s.OnMessageCreate != nil {
 				s.OnMessageCreate(s, st)
 			}
+		}
+		if s.OnMessageCreate != nil {
 			return
 		}
 	case "MESSAGE_UPDATE":
-		if s.OnMessageUpdate != nil {
-			var st *Message
-			if err = unmarshalEvent(e, &st); err == nil {
+		if !s.StateEnabled && s.OnMessageUpdate == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if s.StateEnabled {
+				s.State.MessageAdd(st)
+			}
+			if s.OnMessageUpdate != nil {
 				s.OnMessageUpdate(s, st)
 			}
+		}
+		return
+		if s.OnMessageUpdate != nil {
 			return
 		}
 	case "MESSAGE_DELETE":
-		if s.OnMessageDelete != nil {
-			var st *MessageDelete
-			if err = unmarshalEvent(e, &st); err == nil {
+		if !s.StateEnabled && s.OnMessageDelete == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if s.StateEnabled {
+				s.State.MessageRemove(st)
+			}
+			if s.OnMessageDelete != nil {
 				s.OnMessageDelete(s, st)
 			}
+		}
+		return
+		if s.OnMessageDelete != nil {
 			return
 		}
 	case "MESSAGE_ACK":