state.go 27 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sort"
  14. "sync"
  15. )
  16. // ErrNilState is returned when the state is nil.
  17. var ErrNilState = errors.New("state not instantiated, please use discordgo.New() or assign Session.State")
  18. // ErrStateNotFound is returned when the state cache
  19. // requested is not found
  20. var ErrStateNotFound = errors.New("state cache not found")
  21. // ErrMessageIncompletePermissions is returned when the message
  22. // requested for permissions does not contain enough data to
  23. // generate the permissions.
  24. var ErrMessageIncompletePermissions = errors.New("message incomplete, unable to determine permissions")
// A State contains the current known state.
// As discord sends this in a READY blob, it seems reasonable to simply
// use that struct as the data store.
//
// The embedded sync.RWMutex is what the State methods lock while reading or
// writing the cached data. NewState returns a State with every Track* option
// enabled and MaxMessageCount left at zero.
type State struct {
	sync.RWMutex
	Ready

	// MaxMessageCount represents how many messages per channel the state will store.
	// OnInterface does not cache messages at all while it is 0.
	MaxMessageCount int

	// Each Track* option controls whether OnInterface updates the matching
	// part of the cache when the corresponding gateway events arrive.
	TrackChannels      bool // channel create/update/delete
	TrackThreads       bool // thread create/update/delete, thread member update, thread list sync
	TrackEmojis        bool // guild emoji updates
	TrackMembers       bool // guild member add/update/remove/chunk and presence-driven member updates
	TrackThreadMembers bool // thread members updates
	TrackRoles         bool // guild role create/update/delete
	TrackVoice         bool // voice state updates
	TrackPresences     bool // presence updates and presences in member chunks

	guildMap   map[string]*Guild             // guild ID -> guild; lookup index over Ready.Guilds
	channelMap map[string]*Channel           // channel ID -> guild channel, thread or private channel
	memberMap  map[string]map[string]*Member // guild ID -> user ID -> member
}
  45. // NewState creates an empty state.
  46. func NewState() *State {
  47. return &State{
  48. Ready: Ready{
  49. PrivateChannels: []*Channel{},
  50. Guilds: []*Guild{},
  51. },
  52. TrackChannels: true,
  53. TrackThreads: true,
  54. TrackEmojis: true,
  55. TrackMembers: true,
  56. TrackThreadMembers: true,
  57. TrackRoles: true,
  58. TrackVoice: true,
  59. TrackPresences: true,
  60. guildMap: make(map[string]*Guild),
  61. channelMap: make(map[string]*Channel),
  62. memberMap: make(map[string]map[string]*Member),
  63. }
  64. }
  65. func (s *State) createMemberMap(guild *Guild) {
  66. members := make(map[string]*Member)
  67. for _, m := range guild.Members {
  68. members[m.User.ID] = m
  69. }
  70. s.memberMap[guild.ID] = members
  71. }
  72. // GuildAdd adds a guild to the current world state, or
  73. // updates it if it already exists.
  74. func (s *State) GuildAdd(guild *Guild) error {
  75. if s == nil {
  76. return ErrNilState
  77. }
  78. s.Lock()
  79. defer s.Unlock()
  80. // Update the channels to point to the right guild, adding them to the channelMap as we go
  81. for _, c := range guild.Channels {
  82. s.channelMap[c.ID] = c
  83. }
  84. // Add all the threads to the state in case of thread sync list.
  85. for _, t := range guild.Threads {
  86. s.channelMap[t.ID] = t
  87. }
  88. // If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid
  89. if guild.Members != nil {
  90. s.createMemberMap(guild)
  91. } else if _, ok := s.memberMap[guild.ID]; !ok {
  92. // Even if we have no new member slice, we still initialize the member map for this guild if it doesn't exist
  93. s.memberMap[guild.ID] = make(map[string]*Member)
  94. }
  95. if g, ok := s.guildMap[guild.ID]; ok {
  96. // We are about to replace `g` in the state with `guild`, but first we need to
  97. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
  98. if guild.MemberCount == 0 {
  99. guild.MemberCount = g.MemberCount
  100. }
  101. if guild.Roles == nil {
  102. guild.Roles = g.Roles
  103. }
  104. if guild.Emojis == nil {
  105. guild.Emojis = g.Emojis
  106. }
  107. if guild.Members == nil {
  108. guild.Members = g.Members
  109. }
  110. if guild.Presences == nil {
  111. guild.Presences = g.Presences
  112. }
  113. if guild.Channels == nil {
  114. guild.Channels = g.Channels
  115. }
  116. if guild.Threads == nil {
  117. guild.Threads = g.Threads
  118. }
  119. if guild.VoiceStates == nil {
  120. guild.VoiceStates = g.VoiceStates
  121. }
  122. *g = *guild
  123. return nil
  124. }
  125. s.Guilds = append(s.Guilds, guild)
  126. s.guildMap[guild.ID] = guild
  127. return nil
  128. }
  129. // GuildRemove removes a guild from current world state.
  130. func (s *State) GuildRemove(guild *Guild) error {
  131. if s == nil {
  132. return ErrNilState
  133. }
  134. _, err := s.Guild(guild.ID)
  135. if err != nil {
  136. return err
  137. }
  138. s.Lock()
  139. defer s.Unlock()
  140. delete(s.guildMap, guild.ID)
  141. for i, g := range s.Guilds {
  142. if g.ID == guild.ID {
  143. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  144. return nil
  145. }
  146. }
  147. return nil
  148. }
  149. // Guild gets a guild by ID.
  150. // Useful for querying if @me is in a guild:
  151. // _, err := discordgo.Session.State.Guild(guildID)
  152. // isInGuild := err == nil
  153. func (s *State) Guild(guildID string) (*Guild, error) {
  154. if s == nil {
  155. return nil, ErrNilState
  156. }
  157. s.RLock()
  158. defer s.RUnlock()
  159. if g, ok := s.guildMap[guildID]; ok {
  160. return g, nil
  161. }
  162. return nil, ErrStateNotFound
  163. }
  164. func (s *State) presenceAdd(guildID string, presence *Presence) error {
  165. guild, ok := s.guildMap[guildID]
  166. if !ok {
  167. return ErrStateNotFound
  168. }
  169. for i, p := range guild.Presences {
  170. if p.User.ID == presence.User.ID {
  171. //guild.Presences[i] = presence
  172. //Update status
  173. guild.Presences[i].Activities = presence.Activities
  174. if presence.Status != "" {
  175. guild.Presences[i].Status = presence.Status
  176. }
  177. //Update the optionally sent user information
  178. //ID Is a mandatory field so you should not need to check if it is empty
  179. guild.Presences[i].User.ID = presence.User.ID
  180. if presence.User.Avatar != "" {
  181. guild.Presences[i].User.Avatar = presence.User.Avatar
  182. }
  183. if presence.User.Discriminator != "" {
  184. guild.Presences[i].User.Discriminator = presence.User.Discriminator
  185. }
  186. if presence.User.Email != "" {
  187. guild.Presences[i].User.Email = presence.User.Email
  188. }
  189. if presence.User.Token != "" {
  190. guild.Presences[i].User.Token = presence.User.Token
  191. }
  192. if presence.User.Username != "" {
  193. guild.Presences[i].User.Username = presence.User.Username
  194. }
  195. return nil
  196. }
  197. }
  198. guild.Presences = append(guild.Presences, presence)
  199. return nil
  200. }
  201. // PresenceAdd adds a presence to the current world state, or
  202. // updates it if it already exists.
  203. func (s *State) PresenceAdd(guildID string, presence *Presence) error {
  204. if s == nil {
  205. return ErrNilState
  206. }
  207. s.Lock()
  208. defer s.Unlock()
  209. return s.presenceAdd(guildID, presence)
  210. }
  211. // PresenceRemove removes a presence from the current world state.
  212. func (s *State) PresenceRemove(guildID string, presence *Presence) error {
  213. if s == nil {
  214. return ErrNilState
  215. }
  216. guild, err := s.Guild(guildID)
  217. if err != nil {
  218. return err
  219. }
  220. s.Lock()
  221. defer s.Unlock()
  222. for i, p := range guild.Presences {
  223. if p.User.ID == presence.User.ID {
  224. guild.Presences = append(guild.Presences[:i], guild.Presences[i+1:]...)
  225. return nil
  226. }
  227. }
  228. return ErrStateNotFound
  229. }
  230. // Presence gets a presence by ID from a guild.
  231. func (s *State) Presence(guildID, userID string) (*Presence, error) {
  232. if s == nil {
  233. return nil, ErrNilState
  234. }
  235. guild, err := s.Guild(guildID)
  236. if err != nil {
  237. return nil, err
  238. }
  239. for _, p := range guild.Presences {
  240. if p.User.ID == userID {
  241. return p, nil
  242. }
  243. }
  244. return nil, ErrStateNotFound
  245. }
  246. // TODO: Consider moving Guild state update methods onto *Guild.
  247. func (s *State) memberAdd(member *Member) error {
  248. guild, ok := s.guildMap[member.GuildID]
  249. if !ok {
  250. return ErrStateNotFound
  251. }
  252. members, ok := s.memberMap[member.GuildID]
  253. if !ok {
  254. return ErrStateNotFound
  255. }
  256. m, ok := members[member.User.ID]
  257. if !ok {
  258. members[member.User.ID] = member
  259. guild.Members = append(guild.Members, member)
  260. } else {
  261. // We are about to replace `m` in the state with `member`, but first we need to
  262. // make sure we preserve any fields that the `member` doesn't contain from `m`.
  263. if member.JoinedAt.IsZero() {
  264. member.JoinedAt = m.JoinedAt
  265. }
  266. *m = *member
  267. }
  268. return nil
  269. }
  270. // MemberAdd adds a member to the current world state, or
  271. // updates it if it already exists.
  272. func (s *State) MemberAdd(member *Member) error {
  273. if s == nil {
  274. return ErrNilState
  275. }
  276. s.Lock()
  277. defer s.Unlock()
  278. return s.memberAdd(member)
  279. }
  280. // MemberRemove removes a member from current world state.
  281. func (s *State) MemberRemove(member *Member) error {
  282. if s == nil {
  283. return ErrNilState
  284. }
  285. guild, err := s.Guild(member.GuildID)
  286. if err != nil {
  287. return err
  288. }
  289. s.Lock()
  290. defer s.Unlock()
  291. members, ok := s.memberMap[member.GuildID]
  292. if !ok {
  293. return ErrStateNotFound
  294. }
  295. _, ok = members[member.User.ID]
  296. if !ok {
  297. return ErrStateNotFound
  298. }
  299. delete(members, member.User.ID)
  300. for i, m := range guild.Members {
  301. if m.User.ID == member.User.ID {
  302. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  303. return nil
  304. }
  305. }
  306. return ErrStateNotFound
  307. }
  308. // Member gets a member by ID from a guild.
  309. func (s *State) Member(guildID, userID string) (*Member, error) {
  310. if s == nil {
  311. return nil, ErrNilState
  312. }
  313. s.RLock()
  314. defer s.RUnlock()
  315. members, ok := s.memberMap[guildID]
  316. if !ok {
  317. return nil, ErrStateNotFound
  318. }
  319. m, ok := members[userID]
  320. if ok {
  321. return m, nil
  322. }
  323. return nil, ErrStateNotFound
  324. }
  325. // RoleAdd adds a role to the current world state, or
  326. // updates it if it already exists.
  327. func (s *State) RoleAdd(guildID string, role *Role) error {
  328. if s == nil {
  329. return ErrNilState
  330. }
  331. guild, err := s.Guild(guildID)
  332. if err != nil {
  333. return err
  334. }
  335. s.Lock()
  336. defer s.Unlock()
  337. for i, r := range guild.Roles {
  338. if r.ID == role.ID {
  339. guild.Roles[i] = role
  340. return nil
  341. }
  342. }
  343. guild.Roles = append(guild.Roles, role)
  344. return nil
  345. }
  346. // RoleRemove removes a role from current world state by ID.
  347. func (s *State) RoleRemove(guildID, roleID string) error {
  348. if s == nil {
  349. return ErrNilState
  350. }
  351. guild, err := s.Guild(guildID)
  352. if err != nil {
  353. return err
  354. }
  355. s.Lock()
  356. defer s.Unlock()
  357. for i, r := range guild.Roles {
  358. if r.ID == roleID {
  359. guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
  360. return nil
  361. }
  362. }
  363. return ErrStateNotFound
  364. }
  365. // Role gets a role by ID from a guild.
  366. func (s *State) Role(guildID, roleID string) (*Role, error) {
  367. if s == nil {
  368. return nil, ErrNilState
  369. }
  370. guild, err := s.Guild(guildID)
  371. if err != nil {
  372. return nil, err
  373. }
  374. s.RLock()
  375. defer s.RUnlock()
  376. for _, r := range guild.Roles {
  377. if r.ID == roleID {
  378. return r, nil
  379. }
  380. }
  381. return nil, ErrStateNotFound
  382. }
  383. // ChannelAdd adds a channel to the current world state, or
  384. // updates it if it already exists.
  385. // Channels may exist either as PrivateChannels or inside
  386. // a guild.
  387. func (s *State) ChannelAdd(channel *Channel) error {
  388. if s == nil {
  389. return ErrNilState
  390. }
  391. s.Lock()
  392. defer s.Unlock()
  393. // If the channel exists, replace it
  394. if c, ok := s.channelMap[channel.ID]; ok {
  395. if channel.Messages == nil {
  396. channel.Messages = c.Messages
  397. }
  398. if channel.PermissionOverwrites == nil {
  399. channel.PermissionOverwrites = c.PermissionOverwrites
  400. }
  401. if channel.ThreadMetadata == nil {
  402. channel.ThreadMetadata = c.ThreadMetadata
  403. }
  404. *c = *channel
  405. return nil
  406. }
  407. if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM {
  408. s.PrivateChannels = append(s.PrivateChannels, channel)
  409. s.channelMap[channel.ID] = channel
  410. return nil
  411. }
  412. guild, ok := s.guildMap[channel.GuildID]
  413. if !ok {
  414. return ErrStateNotFound
  415. }
  416. if channel.IsThread() {
  417. guild.Threads = append(guild.Threads, channel)
  418. } else {
  419. guild.Channels = append(guild.Channels, channel)
  420. }
  421. s.channelMap[channel.ID] = channel
  422. return nil
  423. }
  424. // ChannelRemove removes a channel from current world state.
  425. func (s *State) ChannelRemove(channel *Channel) error {
  426. if s == nil {
  427. return ErrNilState
  428. }
  429. _, err := s.Channel(channel.ID)
  430. if err != nil {
  431. return err
  432. }
  433. if channel.Type == ChannelTypeDM || channel.Type == ChannelTypeGroupDM {
  434. s.Lock()
  435. defer s.Unlock()
  436. for i, c := range s.PrivateChannels {
  437. if c.ID == channel.ID {
  438. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  439. break
  440. }
  441. }
  442. delete(s.channelMap, channel.ID)
  443. return nil
  444. }
  445. guild, err := s.Guild(channel.GuildID)
  446. if err != nil {
  447. return err
  448. }
  449. s.Lock()
  450. defer s.Unlock()
  451. if channel.IsThread() {
  452. for i, t := range guild.Threads {
  453. if t.ID == channel.ID {
  454. guild.Threads = append(guild.Threads[:i], guild.Threads[i+1:]...)
  455. break
  456. }
  457. }
  458. } else {
  459. for i, c := range guild.Channels {
  460. if c.ID == channel.ID {
  461. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  462. break
  463. }
  464. }
  465. }
  466. delete(s.channelMap, channel.ID)
  467. return nil
  468. }
  469. // ThreadListSync syncs guild threads with provided ones.
  470. func (s *State) ThreadListSync(tls *ThreadListSync) error {
  471. guild, err := s.Guild(tls.GuildID)
  472. if err != nil {
  473. return err
  474. }
  475. s.Lock()
  476. defer s.Unlock()
  477. // This algorithm filters out archived or
  478. // threads which are children of channels in channelIDs
  479. // and then it adds all synced threads to guild threads and cache
  480. index := 0
  481. outer:
  482. for _, t := range guild.Threads {
  483. if !t.ThreadMetadata.Archived && tls.ChannelIDs != nil {
  484. for _, v := range tls.ChannelIDs {
  485. if t.ParentID == v {
  486. delete(s.channelMap, t.ID)
  487. continue outer
  488. }
  489. }
  490. guild.Threads[index] = t
  491. index++
  492. } else {
  493. delete(s.channelMap, t.ID)
  494. }
  495. }
  496. guild.Threads = guild.Threads[:index]
  497. for _, t := range tls.Threads {
  498. s.channelMap[t.ID] = t
  499. guild.Threads = append(guild.Threads, t)
  500. }
  501. for _, m := range tls.Members {
  502. if c, ok := s.channelMap[m.ID]; ok {
  503. c.Member = m
  504. }
  505. }
  506. return nil
  507. }
  508. // ThreadMembersUpdate updates thread members list
  509. func (s *State) ThreadMembersUpdate(tmu *ThreadMembersUpdate) error {
  510. thread, err := s.Channel(tmu.ID)
  511. if err != nil {
  512. return err
  513. }
  514. s.Lock()
  515. defer s.Unlock()
  516. for idx, member := range thread.Members {
  517. for _, removedMember := range tmu.RemovedMembers {
  518. if member.ID == removedMember {
  519. thread.Members = append(thread.Members[:idx], thread.Members[idx+1:]...)
  520. break
  521. }
  522. }
  523. }
  524. for _, addedMember := range tmu.AddedMembers {
  525. thread.Members = append(thread.Members, addedMember.ThreadMember)
  526. if addedMember.Member != nil {
  527. err = s.memberAdd(addedMember.Member)
  528. if err != nil {
  529. return err
  530. }
  531. }
  532. if addedMember.Presence != nil {
  533. err = s.presenceAdd(tmu.GuildID, addedMember.Presence)
  534. if err != nil {
  535. return err
  536. }
  537. }
  538. }
  539. thread.MemberCount = tmu.MemberCount
  540. return nil
  541. }
  542. // ThreadMemberUpdate sets or updates member data for the current user.
  543. func (s *State) ThreadMemberUpdate(mu *ThreadMemberUpdate) error {
  544. thread, err := s.Channel(mu.ID)
  545. if err != nil {
  546. return err
  547. }
  548. thread.Member = mu.ThreadMember
  549. return nil
  550. }
  551. // GuildChannel gets a channel by ID from a guild.
  552. // This method is Deprecated, use Channel(channelID)
  553. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  554. return s.Channel(channelID)
  555. }
  556. // PrivateChannel gets a private channel by ID.
  557. // This method is Deprecated, use Channel(channelID)
  558. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  559. return s.Channel(channelID)
  560. }
  561. // Channel gets a channel by ID, it will look in all guilds and private channels.
  562. func (s *State) Channel(channelID string) (*Channel, error) {
  563. if s == nil {
  564. return nil, ErrNilState
  565. }
  566. s.RLock()
  567. defer s.RUnlock()
  568. if c, ok := s.channelMap[channelID]; ok {
  569. return c, nil
  570. }
  571. return nil, ErrStateNotFound
  572. }
  573. // Emoji returns an emoji for a guild and emoji id.
  574. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  575. if s == nil {
  576. return nil, ErrNilState
  577. }
  578. guild, err := s.Guild(guildID)
  579. if err != nil {
  580. return nil, err
  581. }
  582. s.RLock()
  583. defer s.RUnlock()
  584. for _, e := range guild.Emojis {
  585. if e.ID == emojiID {
  586. return e, nil
  587. }
  588. }
  589. return nil, ErrStateNotFound
  590. }
  591. // EmojiAdd adds an emoji to the current world state.
  592. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  593. if s == nil {
  594. return ErrNilState
  595. }
  596. guild, err := s.Guild(guildID)
  597. if err != nil {
  598. return err
  599. }
  600. s.Lock()
  601. defer s.Unlock()
  602. for i, e := range guild.Emojis {
  603. if e.ID == emoji.ID {
  604. guild.Emojis[i] = emoji
  605. return nil
  606. }
  607. }
  608. guild.Emojis = append(guild.Emojis, emoji)
  609. return nil
  610. }
  611. // EmojisAdd adds multiple emojis to the world state.
  612. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  613. for _, e := range emojis {
  614. if err := s.EmojiAdd(guildID, e); err != nil {
  615. return err
  616. }
  617. }
  618. return nil
  619. }
  620. // MessageAdd adds a message to the current world state, or updates it if it exists.
  621. // If the channel cannot be found, the message is discarded.
  622. // Messages are kept in state up to s.MaxMessageCount per channel.
  623. func (s *State) MessageAdd(message *Message) error {
  624. if s == nil {
  625. return ErrNilState
  626. }
  627. c, err := s.Channel(message.ChannelID)
  628. if err != nil {
  629. return err
  630. }
  631. s.Lock()
  632. defer s.Unlock()
  633. // If the message exists, merge in the new message contents.
  634. for _, m := range c.Messages {
  635. if m.ID == message.ID {
  636. if message.Content != "" {
  637. m.Content = message.Content
  638. }
  639. if message.EditedTimestamp != nil {
  640. m.EditedTimestamp = message.EditedTimestamp
  641. }
  642. if message.Mentions != nil {
  643. m.Mentions = message.Mentions
  644. }
  645. if message.Embeds != nil {
  646. m.Embeds = message.Embeds
  647. }
  648. if message.Attachments != nil {
  649. m.Attachments = message.Attachments
  650. }
  651. if !message.Timestamp.IsZero() {
  652. m.Timestamp = message.Timestamp
  653. }
  654. if message.Author != nil {
  655. m.Author = message.Author
  656. }
  657. if message.Components != nil {
  658. m.Components = message.Components
  659. }
  660. return nil
  661. }
  662. }
  663. c.Messages = append(c.Messages, message)
  664. if len(c.Messages) > s.MaxMessageCount {
  665. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  666. }
  667. return nil
  668. }
  669. // MessageRemove removes a message from the world state.
  670. func (s *State) MessageRemove(message *Message) error {
  671. if s == nil {
  672. return ErrNilState
  673. }
  674. return s.messageRemoveByID(message.ChannelID, message.ID)
  675. }
  676. // messageRemoveByID removes a message by channelID and messageID from the world state.
  677. func (s *State) messageRemoveByID(channelID, messageID string) error {
  678. c, err := s.Channel(channelID)
  679. if err != nil {
  680. return err
  681. }
  682. s.Lock()
  683. defer s.Unlock()
  684. for i, m := range c.Messages {
  685. if m.ID == messageID {
  686. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  687. return nil
  688. }
  689. }
  690. return ErrStateNotFound
  691. }
  692. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  693. guild, err := s.Guild(update.GuildID)
  694. if err != nil {
  695. return err
  696. }
  697. s.Lock()
  698. defer s.Unlock()
  699. // Handle Leaving Channel
  700. if update.ChannelID == "" {
  701. for i, state := range guild.VoiceStates {
  702. if state.UserID == update.UserID {
  703. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  704. return nil
  705. }
  706. }
  707. } else {
  708. for i, state := range guild.VoiceStates {
  709. if state.UserID == update.UserID {
  710. guild.VoiceStates[i] = update.VoiceState
  711. return nil
  712. }
  713. }
  714. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  715. }
  716. return nil
  717. }
  718. // VoiceState gets a VoiceState by guild and user ID.
  719. func (s *State) VoiceState(guildID, userID string) (*VoiceState, error) {
  720. if s == nil {
  721. return nil, ErrNilState
  722. }
  723. guild, err := s.Guild(guildID)
  724. if err != nil {
  725. return nil, err
  726. }
  727. for _, state := range guild.VoiceStates {
  728. if state.UserID == userID {
  729. return state, nil
  730. }
  731. }
  732. return nil, ErrStateNotFound
  733. }
  734. // Message gets a message by channel and message ID.
  735. func (s *State) Message(channelID, messageID string) (*Message, error) {
  736. if s == nil {
  737. return nil, ErrNilState
  738. }
  739. c, err := s.Channel(channelID)
  740. if err != nil {
  741. return nil, err
  742. }
  743. s.RLock()
  744. defer s.RUnlock()
  745. for _, m := range c.Messages {
  746. if m.ID == messageID {
  747. return m, nil
  748. }
  749. }
  750. return nil, ErrStateNotFound
  751. }
  752. // OnReady takes a Ready event and updates all internal state.
  753. func (s *State) onReady(se *Session, r *Ready) (err error) {
  754. if s == nil {
  755. return ErrNilState
  756. }
  757. s.Lock()
  758. defer s.Unlock()
  759. // We must track at least the current user for Voice, even
  760. // if state is disabled, store the bare essentials.
  761. if !se.StateEnabled {
  762. ready := Ready{
  763. Version: r.Version,
  764. SessionID: r.SessionID,
  765. User: r.User,
  766. }
  767. s.Ready = ready
  768. return nil
  769. }
  770. s.Ready = *r
  771. for _, g := range s.Guilds {
  772. s.guildMap[g.ID] = g
  773. s.createMemberMap(g)
  774. for _, c := range g.Channels {
  775. s.channelMap[c.ID] = c
  776. }
  777. }
  778. for _, c := range s.PrivateChannels {
  779. s.channelMap[c.ID] = c
  780. }
  781. return nil
  782. }
  783. // OnInterface handles all events related to states.
  784. func (s *State) OnInterface(se *Session, i interface{}) (err error) {
  785. if s == nil {
  786. return ErrNilState
  787. }
  788. r, ok := i.(*Ready)
  789. if ok {
  790. return s.onReady(se, r)
  791. }
  792. if !se.StateEnabled {
  793. return nil
  794. }
  795. switch t := i.(type) {
  796. case *GuildCreate:
  797. err = s.GuildAdd(t.Guild)
  798. case *GuildUpdate:
  799. err = s.GuildAdd(t.Guild)
  800. case *GuildDelete:
  801. var old *Guild
  802. old, err = s.Guild(t.ID)
  803. if err == nil {
  804. oldCopy := *old
  805. t.BeforeDelete = &oldCopy
  806. }
  807. err = s.GuildRemove(t.Guild)
  808. case *GuildMemberAdd:
  809. // Updates the MemberCount of the guild.
  810. guild, err := s.Guild(t.Member.GuildID)
  811. if err != nil {
  812. return err
  813. }
  814. guild.MemberCount++
  815. // Caches member if tracking is enabled.
  816. if s.TrackMembers {
  817. err = s.MemberAdd(t.Member)
  818. }
  819. case *GuildMemberUpdate:
  820. if s.TrackMembers {
  821. err = s.MemberAdd(t.Member)
  822. }
  823. case *GuildMemberRemove:
  824. // Updates the MemberCount of the guild.
  825. guild, err := s.Guild(t.Member.GuildID)
  826. if err != nil {
  827. return err
  828. }
  829. guild.MemberCount--
  830. // Removes member from the cache if tracking is enabled.
  831. if s.TrackMembers {
  832. err = s.MemberRemove(t.Member)
  833. }
  834. case *GuildMembersChunk:
  835. if s.TrackMembers {
  836. for i := range t.Members {
  837. t.Members[i].GuildID = t.GuildID
  838. err = s.MemberAdd(t.Members[i])
  839. }
  840. }
  841. if s.TrackPresences {
  842. for _, p := range t.Presences {
  843. err = s.PresenceAdd(t.GuildID, p)
  844. }
  845. }
  846. case *GuildRoleCreate:
  847. if s.TrackRoles {
  848. err = s.RoleAdd(t.GuildID, t.Role)
  849. }
  850. case *GuildRoleUpdate:
  851. if s.TrackRoles {
  852. err = s.RoleAdd(t.GuildID, t.Role)
  853. }
  854. case *GuildRoleDelete:
  855. if s.TrackRoles {
  856. err = s.RoleRemove(t.GuildID, t.RoleID)
  857. }
  858. case *GuildEmojisUpdate:
  859. if s.TrackEmojis {
  860. err = s.EmojisAdd(t.GuildID, t.Emojis)
  861. }
  862. case *ChannelCreate:
  863. if s.TrackChannels {
  864. err = s.ChannelAdd(t.Channel)
  865. }
  866. case *ChannelUpdate:
  867. if s.TrackChannels {
  868. err = s.ChannelAdd(t.Channel)
  869. }
  870. case *ChannelDelete:
  871. if s.TrackChannels {
  872. err = s.ChannelRemove(t.Channel)
  873. }
  874. case *ThreadCreate:
  875. if s.TrackThreads {
  876. err = s.ChannelAdd(t.Channel)
  877. }
  878. case *ThreadUpdate:
  879. if s.TrackThreads {
  880. old, err := s.Channel(t.ID)
  881. if err == nil {
  882. oldCopy := *old
  883. t.BeforeUpdate = &oldCopy
  884. }
  885. err = s.ChannelAdd(t.Channel)
  886. }
  887. case *ThreadDelete:
  888. if s.TrackThreads {
  889. err = s.ChannelRemove(t.Channel)
  890. }
  891. case *ThreadMemberUpdate:
  892. if s.TrackThreads {
  893. err = s.ThreadMemberUpdate(t)
  894. }
  895. case *ThreadMembersUpdate:
  896. if s.TrackThreadMembers {
  897. err = s.ThreadMembersUpdate(t)
  898. }
  899. case *ThreadListSync:
  900. if s.TrackThreads {
  901. err = s.ThreadListSync(t)
  902. }
  903. case *MessageCreate:
  904. if s.MaxMessageCount != 0 {
  905. err = s.MessageAdd(t.Message)
  906. }
  907. case *MessageUpdate:
  908. if s.MaxMessageCount != 0 {
  909. var old *Message
  910. old, err = s.Message(t.ChannelID, t.ID)
  911. if err == nil {
  912. oldCopy := *old
  913. t.BeforeUpdate = &oldCopy
  914. }
  915. err = s.MessageAdd(t.Message)
  916. }
  917. case *MessageDelete:
  918. if s.MaxMessageCount != 0 {
  919. var old *Message
  920. old, err = s.Message(t.ChannelID, t.ID)
  921. if err == nil {
  922. oldCopy := *old
  923. t.BeforeDelete = &oldCopy
  924. }
  925. err = s.MessageRemove(t.Message)
  926. }
  927. case *MessageDeleteBulk:
  928. if s.MaxMessageCount != 0 {
  929. for _, mID := range t.Messages {
  930. s.messageRemoveByID(t.ChannelID, mID)
  931. }
  932. }
  933. case *VoiceStateUpdate:
  934. if s.TrackVoice {
  935. var old *VoiceState
  936. old, err = s.VoiceState(t.GuildID, t.UserID)
  937. if err == nil {
  938. oldCopy := *old
  939. t.BeforeUpdate = &oldCopy
  940. }
  941. err = s.voiceStateUpdate(t)
  942. }
  943. case *PresenceUpdate:
  944. if s.TrackPresences {
  945. s.PresenceAdd(t.GuildID, &t.Presence)
  946. }
  947. if s.TrackMembers {
  948. if t.Status == StatusOffline {
  949. return
  950. }
  951. var m *Member
  952. m, err = s.Member(t.GuildID, t.User.ID)
  953. if err != nil {
  954. // Member not found; this is a user coming online
  955. m = &Member{
  956. GuildID: t.GuildID,
  957. User: t.User,
  958. }
  959. } else {
  960. if t.User.Username != "" {
  961. m.User.Username = t.User.Username
  962. }
  963. }
  964. err = s.MemberAdd(m)
  965. }
  966. }
  967. return
  968. }
  969. // UserChannelPermissions returns the permission of a user in a channel.
  970. // userID : The ID of the user to calculate permissions for.
  971. // channelID : The ID of the channel to calculate permission for.
  972. func (s *State) UserChannelPermissions(userID, channelID string) (apermissions int64, err error) {
  973. if s == nil {
  974. return 0, ErrNilState
  975. }
  976. channel, err := s.Channel(channelID)
  977. if err != nil {
  978. return
  979. }
  980. guild, err := s.Guild(channel.GuildID)
  981. if err != nil {
  982. return
  983. }
  984. member, err := s.Member(guild.ID, userID)
  985. if err != nil {
  986. return
  987. }
  988. return memberPermissions(guild, channel, userID, member.Roles), nil
  989. }
  990. // MessagePermissions returns the permissions of the author of the message
  991. // in the channel in which it was sent.
  992. func (s *State) MessagePermissions(message *Message) (apermissions int64, err error) {
  993. if s == nil {
  994. return 0, ErrNilState
  995. }
  996. if message.Author == nil || message.Member == nil {
  997. return 0, ErrMessageIncompletePermissions
  998. }
  999. channel, err := s.Channel(message.ChannelID)
  1000. if err != nil {
  1001. return
  1002. }
  1003. guild, err := s.Guild(channel.GuildID)
  1004. if err != nil {
  1005. return
  1006. }
  1007. return memberPermissions(guild, channel, message.Author.ID, message.Member.Roles), nil
  1008. }
  1009. // UserColor returns the color of a user in a channel.
  1010. // While colors are defined at a Guild level, determining for a channel is more useful in message handlers.
  1011. // 0 is returned in cases of error, which is the color of @everyone.
  1012. // userID : The ID of the user to calculate the color for.
  1013. // channelID : The ID of the channel to calculate the color for.
  1014. func (s *State) UserColor(userID, channelID string) int {
  1015. if s == nil {
  1016. return 0
  1017. }
  1018. channel, err := s.Channel(channelID)
  1019. if err != nil {
  1020. return 0
  1021. }
  1022. guild, err := s.Guild(channel.GuildID)
  1023. if err != nil {
  1024. return 0
  1025. }
  1026. member, err := s.Member(guild.ID, userID)
  1027. if err != nil {
  1028. return 0
  1029. }
  1030. return firstRoleColorColor(guild, member.Roles)
  1031. }
  1032. // MessageColor returns the color of the author's name as displayed
  1033. // in the client associated with this message.
  1034. func (s *State) MessageColor(message *Message) int {
  1035. if s == nil {
  1036. return 0
  1037. }
  1038. if message.Member == nil || message.Member.Roles == nil {
  1039. return 0
  1040. }
  1041. channel, err := s.Channel(message.ChannelID)
  1042. if err != nil {
  1043. return 0
  1044. }
  1045. guild, err := s.Guild(channel.GuildID)
  1046. if err != nil {
  1047. return 0
  1048. }
  1049. return firstRoleColorColor(guild, message.Member.Roles)
  1050. }
  1051. func firstRoleColorColor(guild *Guild, memberRoles []string) int {
  1052. roles := Roles(guild.Roles)
  1053. sort.Sort(roles)
  1054. for _, role := range roles {
  1055. for _, roleID := range memberRoles {
  1056. if role.ID == roleID {
  1057. if role.Color != 0 {
  1058. return role.Color
  1059. }
  1060. }
  1061. }
  1062. }
  1063. for _, role := range roles {
  1064. if role.ID == guild.ID {
  1065. return role.Color
  1066. }
  1067. }
  1068. return 0
  1069. }