|
@@ -38,13 +38,15 @@ type State struct {
|
|
Ready
|
|
Ready
|
|
|
|
|
|
// MaxMessageCount represents how many messages per channel the state will store.
|
|
// MaxMessageCount represents how many messages per channel the state will store.
|
|
- MaxMessageCount int
|
|
|
|
- TrackChannels bool
|
|
|
|
- TrackEmojis bool
|
|
|
|
- TrackMembers bool
|
|
|
|
- TrackRoles bool
|
|
|
|
- TrackVoice bool
|
|
|
|
- TrackPresences bool
|
|
|
|
|
|
+ MaxMessageCount int
|
|
|
|
+ TrackChannels bool
|
|
|
|
+ TrackThreads bool
|
|
|
|
+ TrackEmojis bool
|
|
|
|
+ TrackMembers bool
|
|
|
|
+ TrackThreadMembers bool
|
|
|
|
+ TrackRoles bool
|
|
|
|
+ TrackVoice bool
|
|
|
|
+ TrackPresences bool
|
|
|
|
|
|
guildMap map[string]*Guild
|
|
guildMap map[string]*Guild
|
|
channelMap map[string]*Channel
|
|
channelMap map[string]*Channel
|
|
@@ -58,15 +60,17 @@ func NewState() *State {
|
|
PrivateChannels: []*Channel{},
|
|
PrivateChannels: []*Channel{},
|
|
Guilds: []*Guild{},
|
|
Guilds: []*Guild{},
|
|
},
|
|
},
|
|
- TrackChannels: true,
|
|
|
|
- TrackEmojis: true,
|
|
|
|
- TrackMembers: true,
|
|
|
|
- TrackRoles: true,
|
|
|
|
- TrackVoice: true,
|
|
|
|
- TrackPresences: true,
|
|
|
|
- guildMap: make(map[string]*Guild),
|
|
|
|
- channelMap: make(map[string]*Channel),
|
|
|
|
- memberMap: make(map[string]map[string]*Member),
|
|
|
|
|
|
+ TrackChannels: true,
|
|
|
|
+ TrackThreads: true,
|
|
|
|
+ TrackEmojis: true,
|
|
|
|
+ TrackMembers: true,
|
|
|
|
+ TrackThreadMembers: true,
|
|
|
|
+ TrackRoles: true,
|
|
|
|
+ TrackVoice: true,
|
|
|
|
+ TrackPresences: true,
|
|
|
|
+ guildMap: make(map[string]*Guild),
|
|
|
|
+ channelMap: make(map[string]*Channel),
|
|
|
|
+ memberMap: make(map[string]map[string]*Member),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -93,6 +97,11 @@ func (s *State) GuildAdd(guild *Guild) error {
|
|
s.channelMap[c.ID] = c
|
|
s.channelMap[c.ID] = c
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // Add all the threads to the state in case of thread sync list.
|
|
|
|
+ for _, t := range guild.Threads {
|
|
|
|
+ s.channelMap[t.ID] = t
|
|
|
|
+ }
|
|
|
|
+
|
|
// If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid
|
|
// If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid
|
|
if guild.Members != nil {
|
|
if guild.Members != nil {
|
|
s.createMemberMap(guild)
|
|
s.createMemberMap(guild)
|
|
@@ -122,6 +131,9 @@ func (s *State) GuildAdd(guild *Guild) error {
|
|
if guild.Channels == nil {
|
|
if guild.Channels == nil {
|
|
guild.Channels = g.Channels
|
|
guild.Channels = g.Channels
|
|
}
|
|
}
|
|
|
|
+ if guild.Threads == nil {
|
|
|
|
+ guild.Threads = g.Threads
|
|
|
|
+ }
|
|
if guild.VoiceStates == nil {
|
|
if guild.VoiceStates == nil {
|
|
guild.VoiceStates = g.VoiceStates
|
|
guild.VoiceStates = g.VoiceStates
|
|
}
|
|
}
|
|
@@ -180,21 +192,12 @@ func (s *State) Guild(guildID string) (*Guild, error) {
|
|
return nil, ErrStateNotFound
|
|
return nil, ErrStateNotFound
|
|
}
|
|
}
|
|
|
|
|
|
-// PresenceAdd adds a presence to the current world state, or
|
|
|
|
-// updates it if it already exists.
|
|
|
|
-func (s *State) PresenceAdd(guildID string, presence *Presence) error {
|
|
|
|
- if s == nil {
|
|
|
|
- return ErrNilState
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- guild, err := s.Guild(guildID)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
|
|
+func (s *State) presenceAdd(guildID string, presence *Presence) error {
|
|
|
|
+ guild, ok := s.guildMap[guildID]
|
|
|
|
+ if !ok {
|
|
|
|
+ return ErrStateNotFound
|
|
}
|
|
}
|
|
|
|
|
|
- s.Lock()
|
|
|
|
- defer s.Unlock()
|
|
|
|
-
|
|
|
|
for i, p := range guild.Presences {
|
|
for i, p := range guild.Presences {
|
|
if p.User.ID == presence.User.ID {
|
|
if p.User.ID == presence.User.ID {
|
|
//guild.Presences[i] = presence
|
|
//guild.Presences[i] = presence
|
|
@@ -233,6 +236,19 @@ func (s *State) PresenceAdd(guildID string, presence *Presence) error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// PresenceAdd adds a presence to the current world state, or
|
|
|
|
+// updates it if it already exists.
|
|
|
|
+func (s *State) PresenceAdd(guildID string, presence *Presence) error {
|
|
|
|
+ if s == nil {
|
|
|
|
+ return ErrNilState
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.Lock()
|
|
|
|
+ defer s.Unlock()
|
|
|
|
+
|
|
|
|
+ return s.presenceAdd(guildID, presence)
|
|
|
|
+}
|
|
|
|
+
|
|
// PresenceRemove removes a presence from the current world state.
|
|
// PresenceRemove removes a presence from the current world state.
|
|
func (s *State) PresenceRemove(guildID string, presence *Presence) error {
|
|
func (s *State) PresenceRemove(guildID string, presence *Presence) error {
|
|
if s == nil {
|
|
if s == nil {
|
|
@@ -279,21 +295,12 @@ func (s *State) Presence(guildID, userID string) (*Presence, error) {
|
|
|
|
|
|
// TODO: Consider moving Guild state update methods onto *Guild.
|
|
// TODO: Consider moving Guild state update methods onto *Guild.
|
|
|
|
|
|
-// MemberAdd adds a member to the current world state, or
|
|
|
|
-// updates it if it already exists.
|
|
|
|
-func (s *State) MemberAdd(member *Member) error {
|
|
|
|
- if s == nil {
|
|
|
|
- return ErrNilState
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- guild, err := s.Guild(member.GuildID)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
|
|
+func (s *State) memberAdd(member *Member) error {
|
|
|
|
+ guild, ok := s.guildMap[member.GuildID]
|
|
|
|
+ if !ok {
|
|
|
|
+ return ErrStateNotFound
|
|
}
|
|
}
|
|
|
|
|
|
- s.Lock()
|
|
|
|
- defer s.Unlock()
|
|
|
|
-
|
|
|
|
members, ok := s.memberMap[member.GuildID]
|
|
members, ok := s.memberMap[member.GuildID]
|
|
if !ok {
|
|
if !ok {
|
|
return ErrStateNotFound
|
|
return ErrStateNotFound
|
|
@@ -311,10 +318,22 @@ func (s *State) MemberAdd(member *Member) error {
|
|
}
|
|
}
|
|
*m = *member
|
|
*m = *member
|
|
}
|
|
}
|
|
-
|
|
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// MemberAdd adds a member to the current world state, or
|
|
|
|
+// updates it if it already exists.
|
|
|
|
+func (s *State) MemberAdd(member *Member) error {
|
|
|
|
+ if s == nil {
|
|
|
|
+ return ErrNilState
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.Lock()
|
|
|
|
+ defer s.Unlock()
|
|
|
|
+
|
|
|
|
+ return s.memberAdd(member)
|
|
|
|
+}
|
|
|
|
+
|
|
// MemberRemove removes a member from current world state.
|
|
// MemberRemove removes a member from current world state.
|
|
func (s *State) MemberRemove(member *Member) error {
|
|
func (s *State) MemberRemove(member *Member) error {
|
|
if s == nil {
|
|
if s == nil {
|
|
@@ -465,6 +484,9 @@ func (s *State) ChannelAdd(channel *Channel) error {
|
|
if channel.PermissionOverwrites == nil {
|
|
if channel.PermissionOverwrites == nil {
|
|
channel.PermissionOverwrites = c.PermissionOverwrites
|
|
channel.PermissionOverwrites = c.PermissionOverwrites
|
|
}
|
|
}
|
|
|
|
+ if channel.ThreadMetadata == nil {
|
|
|
|
+ channel.ThreadMetadata = c.ThreadMetadata
|
|
|
|
+ }
|
|
|
|
|
|
*c = *channel
|
|
*c = *channel
|
|
return nil
|
|
return nil
|
|
@@ -472,12 +494,18 @@ func (s *State) ChannelAdd(channel *Channel) error {
|
|
|
|
|
|
if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM {
|
|
if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM {
|
|
s.PrivateChannels = append(s.PrivateChannels, channel)
|
|
s.PrivateChannels = append(s.PrivateChannels, channel)
|
|
- } else {
|
|
|
|
- guild, ok := s.guildMap[channel.GuildID]
|
|
|
|
- if !ok {
|
|
|
|
- return ErrStateNotFound
|
|
|
|
- }
|
|
|
|
|
|
+ s.channelMap[channel.ID] = channel
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ guild, ok := s.guildMap[channel.GuildID]
|
|
|
|
+ if !ok {
|
|
|
|
+ return ErrStateNotFound
|
|
|
|
+ }
|
|
|
|
|
|
|
|
+ if channel.IsThread() {
|
|
|
|
+ guild.Threads = append(guild.Threads, channel)
|
|
|
|
+ } else {
|
|
guild.Channels = append(guild.Channels, channel)
|
|
guild.Channels = append(guild.Channels, channel)
|
|
}
|
|
}
|
|
|
|
|
|
@@ -507,15 +535,26 @@ func (s *State) ChannelRemove(channel *Channel) error {
|
|
break
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- } else {
|
|
|
|
- guild, err := s.Guild(channel.GuildID)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
|
|
+ delete(s.channelMap, channel.ID)
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
|
|
- s.Lock()
|
|
|
|
- defer s.Unlock()
|
|
|
|
|
|
+ guild, err := s.Guild(channel.GuildID)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
|
|
|
|
+ s.Lock()
|
|
|
|
+ defer s.Unlock()
|
|
|
|
+
|
|
|
|
+ if channel.IsThread() {
|
|
|
|
+ for i, t := range guild.Threads {
|
|
|
|
+ if t.ID == channel.ID {
|
|
|
|
+ guild.Threads = append(guild.Threads[:i], guild.Threads[i+1:]...)
|
|
|
|
+ break
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
for i, c := range guild.Channels {
|
|
for i, c := range guild.Channels {
|
|
if c.ID == channel.ID {
|
|
if c.ID == channel.ID {
|
|
guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
|
|
guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
|
|
@@ -529,6 +568,99 @@ func (s *State) ChannelRemove(channel *Channel) error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// ThreadListSync syncs guild threads with provided ones.
|
|
|
|
+func (s *State) ThreadListSync(tls *ThreadListSync) error {
|
|
|
|
+ guild, err := s.Guild(tls.GuildID)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.Lock()
|
|
|
|
+ defer s.Unlock()
|
|
|
|
+
|
|
|
|
+ // This algorithm filters out archived or
|
|
|
|
+ // threads which are children of channels in channelIDs
|
|
|
|
+ // and then it adds all synced threads to guild threads and cache
|
|
|
|
+ index := 0
|
|
|
|
+outer:
|
|
|
|
+ for _, t := range guild.Threads {
|
|
|
|
+ if !t.ThreadMetadata.Archived && tls.ChannelIDs != nil {
|
|
|
|
+ for _, v := range tls.ChannelIDs {
|
|
|
|
+ if t.ParentID == v {
|
|
|
|
+ delete(s.channelMap, t.ID)
|
|
|
|
+ continue outer
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ guild.Threads[index] = t
|
|
|
|
+ index++
|
|
|
|
+ } else {
|
|
|
|
+ delete(s.channelMap, t.ID)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ guild.Threads = guild.Threads[:index]
|
|
|
|
+ for _, t := range tls.Threads {
|
|
|
|
+ s.channelMap[t.ID] = t
|
|
|
|
+ guild.Threads = append(guild.Threads, t)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, m := range tls.Members {
|
|
|
|
+ if c, ok := s.channelMap[m.ID]; ok {
|
|
|
|
+ c.Member = m
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// ThreadMembersUpdate updates thread members list
|
|
|
|
+func (s *State) ThreadMembersUpdate(tmu *ThreadMembersUpdate) error {
|
|
|
|
+ thread, err := s.Channel(tmu.ID)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ s.Lock()
|
|
|
|
+ defer s.Unlock()
|
|
|
|
+
|
|
|
|
+ for idx, member := range thread.Members {
|
|
|
|
+ for _, removedMember := range tmu.RemovedMembers {
|
|
|
|
+ if member.ID == removedMember {
|
|
|
|
+ thread.Members = append(thread.Members[:idx], thread.Members[idx+1:]...)
|
|
|
|
+ break
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, addedMember := range tmu.AddedMembers {
|
|
|
|
+ thread.Members = append(thread.Members, addedMember.ThreadMember)
|
|
|
|
+ if addedMember.Member != nil {
|
|
|
|
+ err = s.memberAdd(addedMember.Member)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if addedMember.Presence != nil {
|
|
|
|
+ err = s.presenceAdd(tmu.GuildID, addedMember.Presence)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ thread.MemberCount = tmu.MemberCount
|
|
|
|
+
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// ThreadMemberUpdate sets or updates member data for the current user.
|
|
|
|
+func (s *State) ThreadMemberUpdate(mu *ThreadMemberUpdate) error {
|
|
|
|
+ thread, err := s.Channel(mu.ID)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ thread.Member = mu.ThreadMember
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
// GuildChannel gets a channel by ID from a guild.
|
|
// GuildChannel gets a channel by ID from a guild.
|
|
// This method is Deprecated, use Channel(channelID)
|
|
// This method is Deprecated, use Channel(channelID)
|
|
func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
|
|
func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
|
|
@@ -668,6 +800,7 @@ func (s *State) MessageAdd(message *Message) error {
|
|
if len(c.Messages) > s.MaxMessageCount {
|
|
if len(c.Messages) > s.MaxMessageCount {
|
|
c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
|
|
c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
|
|
}
|
|
}
|
|
|
|
+
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
@@ -693,6 +826,7 @@ func (s *State) messageRemoveByID(channelID, messageID string) error {
|
|
for i, m := range c.Messages {
|
|
for i, m := range c.Messages {
|
|
if m.ID == messageID {
|
|
if m.ID == messageID {
|
|
c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
|
|
c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
|
|
|
|
+
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -913,6 +1047,35 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) {
|
|
if s.TrackChannels {
|
|
if s.TrackChannels {
|
|
err = s.ChannelRemove(t.Channel)
|
|
err = s.ChannelRemove(t.Channel)
|
|
}
|
|
}
|
|
|
|
+ case *ThreadCreate:
|
|
|
|
+ if s.TrackThreads {
|
|
|
|
+ err = s.ChannelAdd(t.Channel)
|
|
|
|
+ }
|
|
|
|
+ case *ThreadUpdate:
|
|
|
|
+ if s.TrackThreads {
|
|
|
|
+ old, err := s.Channel(t.ID)
|
|
|
|
+ if err == nil {
|
|
|
|
+ oldCopy := *old
|
|
|
|
+ t.BeforeUpdate = &oldCopy
|
|
|
|
+ }
|
|
|
|
+ err = s.ChannelAdd(t.Channel)
|
|
|
|
+ }
|
|
|
|
+ case *ThreadDelete:
|
|
|
|
+ if s.TrackThreads {
|
|
|
|
+ err = s.ChannelRemove(t.Channel)
|
|
|
|
+ }
|
|
|
|
+ case *ThreadMemberUpdate:
|
|
|
|
+ if s.TrackThreads {
|
|
|
|
+ err = s.ThreadMemberUpdate(t)
|
|
|
|
+ }
|
|
|
|
+ case *ThreadMembersUpdate:
|
|
|
|
+ if s.TrackThreadMembers {
|
|
|
|
+ err = s.ThreadMembersUpdate(t)
|
|
|
|
+ }
|
|
|
|
+ case *ThreadListSync:
|
|
|
|
+ if s.TrackThreads {
|
|
|
|
+ err = s.ThreadListSync(t)
|
|
|
|
+ }
|
|
case *MessageCreate:
|
|
case *MessageCreate:
|
|
if s.MaxMessageCount != 0 {
|
|
if s.MaxMessageCount != 0 {
|
|
err = s.MessageAdd(t.Message)
|
|
err = s.MessageAdd(t.Message)
|