Browse Source

Merge pull request #90 from iopred/develop

Add support for message state.
Bruce 9 years ago
parent
commit
bf20ffffa8
4 changed files with 184 additions and 27 deletions
  1. 3 0
      message.go
  2. 135 8
      state.go
  3. 8 10
      structs.go
  4. 38 9
      wsapi.go

+ 3 - 0
message.go

@@ -70,6 +70,9 @@ type Embed struct {
 // ContentWithMentionsReplaced will replace all @<id> mentions with the
 // ContentWithMentionsReplaced will replace all @<id> mentions with the
 // username of the mention.
 // username of the mention.
 func (m *Message) ContentWithMentionsReplaced() string {
 func (m *Message) ContentWithMentionsReplaced() string {
+	if m.Mentions == nil {
+		return m.Content
+	}
 	content := m.Content
 	content := m.Content
 	for _, user := range m.Mentions {
 	for _, user := range m.Mentions {
 		content = strings.Replace(content, fmt.Sprintf("<@%s>", user.ID),
 		content = strings.Replace(content, fmt.Sprintf("<@%s>", user.ID),

+ 135 - 8
state.go

@@ -31,6 +31,8 @@ func (s *State) OnReady(r *Ready) error {
 	if s == nil {
 	if s == nil {
 		return nilError
 		return nilError
 	}
 	}
+	s.Lock()
+	defer s.Unlock()
 
 
 	s.Ready = *r
 	s.Ready = *r
 	return nil
 	return nil
@@ -42,16 +44,18 @@ func (s *State) GuildAdd(guild *Guild) error {
 	if s == nil {
 	if s == nil {
 		return nilError
 		return nilError
 	}
 	}
+	s.Lock()
+	defer s.Unlock()
 
 
-	for _, g := range s.Guilds {
+	// If the guild exists, replace it.
+	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
 		if g.ID == guild.ID {
-			// This could be a little faster ;)
-			for _, m := range guild.Members {
-				s.MemberAdd(m)
-			}
-			for _, c := range guild.Channels {
-				s.ChannelAdd(c)
-			}
+			// Don't stomp on properties that don't come in updates.
+			guild.Members = g.Members
+			guild.Presences = g.Presences
+			guild.Channels = g.Channels
+			guild.VoiceStates = g.VoiceStates
+			s.Guilds[i] = guild
 			return nil
 			return nil
 		}
 		}
 	}
 	}
@@ -65,6 +69,8 @@ func (s *State) GuildRemove(guild *Guild) error {
 	if s == nil {
 	if s == nil {
 		return nilError
 		return nilError
 	}
 	}
+	s.Lock()
+	defer s.Unlock()
 
 
 	for i, g := range s.Guilds {
 	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
 		if g.ID == guild.ID {
@@ -84,6 +90,8 @@ func (s *State) Guild(guildID string) (*Guild, error) {
 	if s == nil {
 	if s == nil {
 		return nil, nilError
 		return nil, nilError
 	}
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 
 	for _, g := range s.Guilds {
 	for _, g := range s.Guilds {
 		if g.ID == guildID {
 		if g.ID == guildID {
@@ -108,6 +116,9 @@ func (s *State) MemberAdd(member *Member) error {
 		return err
 		return err
 	}
 	}
 
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, m := range guild.Members {
 	for i, m := range guild.Members {
 		if m.User.ID == member.User.ID {
 		if m.User.ID == member.User.ID {
 			guild.Members[i] = member
 			guild.Members[i] = member
@@ -130,6 +141,9 @@ func (s *State) MemberRemove(member *Member) error {
 		return err
 		return err
 	}
 	}
 
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, m := range guild.Members {
 	for i, m := range guild.Members {
 		if m.User.ID == member.User.ID {
 		if m.User.ID == member.User.ID {
 			guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
 			guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
@@ -151,6 +165,9 @@ func (s *State) Member(guildID, userID string) (*Member, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, m := range guild.Members {
 	for _, m := range guild.Members {
 		if m.User.ID == userID {
 		if m.User.ID == userID {
 			return m, nil
 			return m, nil
@@ -170,8 +187,14 @@ func (s *State) ChannelAdd(channel *Channel) error {
 	}
 	}
 
 
 	if channel.IsPrivate {
 	if channel.IsPrivate {
+		s.Lock()
+		defer s.Unlock()
+
+		// If the channel exists, replace it.
 		for i, c := range s.PrivateChannels {
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
 			if c.ID == channel.ID {
+				// Don't stomp on messages.
+				channel.Messages = c.Messages
 				s.PrivateChannels[i] = channel
 				s.PrivateChannels[i] = channel
 				return nil
 				return nil
 			}
 			}
@@ -184,8 +207,14 @@ func (s *State) ChannelAdd(channel *Channel) error {
 			return err
 			return err
 		}
 		}
 
 
+		s.Lock()
+		defer s.Unlock()
+
+		// If the channel exists, replace it.
 		for i, c := range guild.Channels {
 		for i, c := range guild.Channels {
 			if c.ID == channel.ID {
 			if c.ID == channel.ID {
+				// Don't stomp on messages.
+				channel.Messages = c.Messages
 				guild.Channels[i] = channel
 				guild.Channels[i] = channel
 				return nil
 				return nil
 			}
 			}
@@ -204,6 +233,9 @@ func (s *State) ChannelRemove(channel *Channel) error {
 	}
 	}
 
 
 	if channel.IsPrivate {
 	if channel.IsPrivate {
+		s.Lock()
+		defer s.Unlock()
+
 		for i, c := range s.PrivateChannels {
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
 			if c.ID == channel.ID {
 				s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
 				s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
@@ -216,6 +248,9 @@ func (s *State) ChannelRemove(channel *Channel) error {
 			return err
 			return err
 		}
 		}
 
 
+		s.Lock()
+		defer s.Unlock()
+
 		for i, c := range guild.Channels {
 		for i, c := range guild.Channels {
 			if c.ID == channel.ID {
 			if c.ID == channel.ID {
 				guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
 				guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
@@ -238,6 +273,9 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, c := range guild.Channels {
 	for _, c := range guild.Channels {
 		if c.ID == channelID {
 		if c.ID == channelID {
 			return c, nil
 			return c, nil
@@ -252,6 +290,8 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) {
 	if s == nil {
 	if s == nil {
 		return nil, nilError
 		return nil, nilError
 	}
 	}
+	s.RLock()
+	defer s.RUnlock()
 
 
 	for _, c := range s.PrivateChannels {
 	for _, c := range s.PrivateChannels {
 		if c.ID == channelID {
 		if c.ID == channelID {
@@ -294,6 +334,9 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	s.RLock()
+	defer s.RUnlock()
+
 	for _, e := range guild.Emojis {
 	for _, e := range guild.Emojis {
 		if e.ID == emojiID {
 		if e.ID == emojiID {
 			return e, nil
 			return e, nil
@@ -314,6 +357,9 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
 		return err
 		return err
 	}
 	}
 
 
+	s.Lock()
+	defer s.Unlock()
+
 	for i, e := range guild.Emojis {
 	for i, e := range guild.Emojis {
 		if e.ID == emoji.ID {
 		if e.ID == emoji.ID {
 			guild.Emojis[i] = emoji
 			guild.Emojis[i] = emoji
@@ -334,3 +380,84 @@ func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+// MessageAdd adds a message to the current world state, or updates it if it exists.
+// If the channel cannot be found, the message is discarded.
+// Messages are kept in state up to s.MaxMessageCount
+func (s *State) MessageAdd(message *Message) error {
+	if s == nil {
+		return nilError
+	}
+
+	c, err := s.Channel(message.ChannelID)
+	if err != nil {
+		return err
+	}
+
+	s.Lock()
+	defer s.Unlock()
+
+	// If the message exists, replace it.
+	for i, m := range c.Messages {
+		if m.ID == message.ID {
+			c.Messages[i] = message
+			return nil
+		}
+	}
+
+	c.Messages = append(c.Messages, message)
+
+	if len(c.Messages) > s.MaxMessageCount {
+		s.Unlock()
+		for len(c.Messages) > s.MaxMessageCount {
+			s.MessageRemove(c.Messages[0])
+		}
+		s.Lock()
+	}
+	return nil
+}
+
+// MessageRemove removes a message from the world state.
+func (s *State) MessageRemove(message *Message) error {
+	if s == nil {
+		return nilError
+	}
+	c, err := s.Channel(message.ChannelID)
+	if err != nil {
+		return err
+	}
+
+	s.Lock()
+	defer s.Unlock()
+
+	for i, m := range c.Messages {
+		if m.ID == message.ID {
+			c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
+			return nil
+		}
+	}
+
+	return errors.New("Message not found.")
+}
+
+// Message gets a message by channel and message ID.
+func (s *State) Message(channelID, messageID string) (*Message, error) {
+	if s == nil {
+		return nil, nilError
+	}
+	c, err := s.Channel(channelID)
+	if err != nil {
+		return nil, err
+	}
+
+	s.RLock()
+	defer s.RUnlock()
+
+	for _, m := range c.Messages {
+		if m.ID == messageID {
+			return m, nil
+		}
+	}
+
+	return nil, errors.New("Message not found.")
+}

+ 8 - 10
structs.go

@@ -33,7 +33,7 @@ type Session struct {
 	OnTypingStart             func(*Session, *TypingStart)
 	OnTypingStart             func(*Session, *TypingStart)
 	OnMessageCreate           func(*Session, *Message)
 	OnMessageCreate           func(*Session, *Message)
 	OnMessageUpdate           func(*Session, *Message)
 	OnMessageUpdate           func(*Session, *Message)
-	OnMessageDelete           func(*Session, *MessageDelete)
+	OnMessageDelete           func(*Session, *Message)
 	OnMessageAck              func(*Session, *MessageAck)
 	OnMessageAck              func(*Session, *MessageAck)
 	OnUserUpdate              func(*Session, *User)
 	OnUserUpdate              func(*Session, *User)
 	OnPresenceUpdate          func(*Session, *PresenceUpdate)
 	OnPresenceUpdate          func(*Session, *PresenceUpdate)
@@ -46,7 +46,7 @@ type Session struct {
 	OnGuildDelete             func(*Session, *Guild)
 	OnGuildDelete             func(*Session, *Guild)
 	OnGuildMemberAdd          func(*Session, *Member)
 	OnGuildMemberAdd          func(*Session, *Member)
 	OnGuildMemberRemove       func(*Session, *Member)
 	OnGuildMemberRemove       func(*Session, *Member)
-	OnGuildMemberDelete       func(*Session, *Member) // which is it?
+	OnGuildMemberDelete       func(*Session, *Member)
 	OnGuildMemberUpdate       func(*Session, *Member)
 	OnGuildMemberUpdate       func(*Session, *Member)
 	OnGuildRoleCreate         func(*Session, *GuildRole)
 	OnGuildRoleCreate         func(*Session, *GuildRole)
 	OnGuildRoleUpdate         func(*Session, *GuildRole)
 	OnGuildRoleUpdate         func(*Session, *GuildRole)
@@ -77,8 +77,9 @@ type Session struct {
 	Voice *Voice // Stores all details related to voice connections
 	Voice *Voice // Stores all details related to voice connections
 
 
 	// Managed state object, updated with events.
 	// Managed state object, updated with events.
-	State        *State
-	StateEnabled bool
+	State                *State
+	StateEnabled         bool
+	StateMaxMessageCount int
 
 
 	// Mutex/Bools for locks that prevent accidents.
 	// Mutex/Bools for locks that prevent accidents.
 	// TODO: Add channels.
 	// TODO: Add channels.
@@ -138,6 +139,7 @@ type Channel struct {
 	IsPrivate            bool                   `json:"is_private"`
 	IsPrivate            bool                   `json:"is_private"`
 	LastMessageID        string                 `json:"last_message_id"`
 	LastMessageID        string                 `json:"last_message_id"`
 	Recipient            *User                  `json:"recipient"`
 	Recipient            *User                  `json:"recipient"`
+	Messages             []*Message             `json:"-"`
 }
 }
 
 
 // A PermissionOverwrite holds permission overwrite data for a Channel
 // A PermissionOverwrite holds permission overwrite data for a Channel
@@ -309,12 +311,6 @@ type MessageAck struct {
 	ChannelID string `json:"channel_id"`
 	ChannelID string `json:"channel_id"`
 }
 }
 
 
-// A MessageDelete stores data for the message delete websocket event.
-type MessageDelete struct {
-	ID        string `json:"id"`
-	ChannelID string `json:"channel_id"`
-} // so much like MessageAck..
-
 // A GuildIntegrationsUpdate stores data for the guild integrations update
 // A GuildIntegrationsUpdate stores data for the guild integrations update
 // websocket event.
 // websocket event.
 type GuildIntegrationsUpdate struct {
 type GuildIntegrationsUpdate struct {
@@ -349,5 +345,7 @@ type GuildEmojisUpdate struct {
 // As discord sends this in a READY blob, it seems reasonable to simply
 // As discord sends this in a READY blob, it seems reasonable to simply
 // use that struct as the data store.
 // use that struct as the data store.
 type State struct {
 type State struct {
+	sync.RWMutex
 	Ready
 	Ready
+	MaxMessageCount int
 }
 }

+ 38 - 9
wsapi.go

@@ -251,27 +251,56 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 			}
 			}
 		*/
 		*/
 	case "MESSAGE_CREATE":
 	case "MESSAGE_CREATE":
-		if s.OnMessageCreate != nil {
-			var st *Message
-			if err = unmarshalEvent(e, &st); err == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageCreate == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if stateEnabled {
+				s.State.MessageAdd(st)
+			}
+			if s.OnMessageCreate != nil {
 				s.OnMessageCreate(s, st)
 				s.OnMessageCreate(s, st)
 			}
 			}
+		}
+		if s.OnMessageCreate != nil {
 			return
 			return
 		}
 		}
 	case "MESSAGE_UPDATE":
 	case "MESSAGE_UPDATE":
-		if s.OnMessageUpdate != nil {
-			var st *Message
-			if err = unmarshalEvent(e, &st); err == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageUpdate == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if stateEnabled {
+				s.State.MessageAdd(st)
+			}
+			if s.OnMessageUpdate != nil {
 				s.OnMessageUpdate(s, st)
 				s.OnMessageUpdate(s, st)
 			}
 			}
+		}
+		return
+		if s.OnMessageUpdate != nil {
 			return
 			return
 		}
 		}
 	case "MESSAGE_DELETE":
 	case "MESSAGE_DELETE":
-		if s.OnMessageDelete != nil {
-			var st *MessageDelete
-			if err = unmarshalEvent(e, &st); err == nil {
+		stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0
+		if !stateEnabled && s.OnMessageDelete == nil {
+			break
+		}
+		var st *Message
+		if err = unmarshalEvent(e, &st); err == nil {
+			if stateEnabled {
+				s.State.MessageRemove(st)
+			}
+			if s.OnMessageDelete != nil {
 				s.OnMessageDelete(s, st)
 				s.OnMessageDelete(s, st)
 			}
 			}
+		}
+		return
+		if s.OnMessageDelete != nil {
 			return
 			return
 		}
 		}
 	case "MESSAGE_ACK":
 	case "MESSAGE_ACK":