state.go 16 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
// This file contains code related to state tracking. If enabled, state
// tracking will capture the initial READY packet and many other websocket
// events and maintain an in-memory state of guilds, channels, users, and
// so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sort"
  14. "sync"
  15. )
  16. // ErrNilState is returned when the state is nil.
  17. var ErrNilState = errors.New("state not instantiated, please use discordgo.New() or assign Session.State")
  18. // A State contains the current known state.
  19. // As discord sends this in a READY blob, it seems reasonable to simply
  20. // use that struct as the data store.
  21. type State struct {
  22. sync.RWMutex
  23. Ready
  24. MaxMessageCount int
  25. TrackChannels bool
  26. TrackEmojis bool
  27. TrackMembers bool
  28. TrackRoles bool
  29. TrackVoice bool
  30. guildMap map[string]*Guild
  31. channelMap map[string]*Channel
  32. }
  33. // NewState creates an empty state.
  34. func NewState() *State {
  35. return &State{
  36. Ready: Ready{
  37. PrivateChannels: []*Channel{},
  38. Guilds: []*Guild{},
  39. },
  40. TrackChannels: true,
  41. TrackEmojis: true,
  42. TrackMembers: true,
  43. TrackRoles: true,
  44. TrackVoice: true,
  45. guildMap: make(map[string]*Guild),
  46. channelMap: make(map[string]*Channel),
  47. }
  48. }
  49. // GuildAdd adds a guild to the current world state, or
  50. // updates it if it already exists.
  51. func (s *State) GuildAdd(guild *Guild) error {
  52. if s == nil {
  53. return ErrNilState
  54. }
  55. s.Lock()
  56. defer s.Unlock()
  57. // Update the channels to point to the right guild, adding them to the channelMap as we go
  58. for _, c := range guild.Channels {
  59. s.channelMap[c.ID] = c
  60. }
  61. if g, ok := s.guildMap[guild.ID]; ok {
  62. // We are about to replace `g` in the state with `guild`, but first we need to
  63. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
  64. if guild.Roles == nil {
  65. guild.Roles = g.Roles
  66. }
  67. if guild.Emojis == nil {
  68. guild.Emojis = g.Emojis
  69. }
  70. if guild.Members == nil {
  71. guild.Members = g.Members
  72. }
  73. if guild.Presences == nil {
  74. guild.Presences = g.Presences
  75. }
  76. if guild.Channels == nil {
  77. guild.Channels = g.Channels
  78. }
  79. if guild.VoiceStates == nil {
  80. guild.VoiceStates = g.VoiceStates
  81. }
  82. *g = *guild
  83. return nil
  84. }
  85. s.Guilds = append(s.Guilds, guild)
  86. s.guildMap[guild.ID] = guild
  87. return nil
  88. }
  89. // GuildRemove removes a guild from current world state.
  90. func (s *State) GuildRemove(guild *Guild) error {
  91. if s == nil {
  92. return ErrNilState
  93. }
  94. _, err := s.Guild(guild.ID)
  95. if err != nil {
  96. return err
  97. }
  98. s.Lock()
  99. defer s.Unlock()
  100. delete(s.guildMap, guild.ID)
  101. for i, g := range s.Guilds {
  102. if g.ID == guild.ID {
  103. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  104. return nil
  105. }
  106. }
  107. return nil
  108. }
  109. // Guild gets a guild by ID.
  110. // Useful for querying if @me is in a guild:
  111. // _, err := discordgo.Session.State.Guild(guildID)
  112. // isInGuild := err == nil
  113. func (s *State) Guild(guildID string) (*Guild, error) {
  114. if s == nil {
  115. return nil, ErrNilState
  116. }
  117. s.RLock()
  118. defer s.RUnlock()
  119. if g, ok := s.guildMap[guildID]; ok {
  120. return g, nil
  121. }
  122. return nil, errors.New("guild not found")
  123. }
  124. // TODO: Consider moving Guild state update methods onto *Guild.
  125. // MemberAdd adds a member to the current world state, or
  126. // updates it if it already exists.
  127. func (s *State) MemberAdd(member *Member) error {
  128. if s == nil {
  129. return ErrNilState
  130. }
  131. guild, err := s.Guild(member.GuildID)
  132. if err != nil {
  133. return err
  134. }
  135. s.Lock()
  136. defer s.Unlock()
  137. for i, m := range guild.Members {
  138. if m.User.ID == member.User.ID {
  139. guild.Members[i] = member
  140. return nil
  141. }
  142. }
  143. guild.Members = append(guild.Members, member)
  144. return nil
  145. }
  146. // MemberRemove removes a member from current world state.
  147. func (s *State) MemberRemove(member *Member) error {
  148. if s == nil {
  149. return ErrNilState
  150. }
  151. guild, err := s.Guild(member.GuildID)
  152. if err != nil {
  153. return err
  154. }
  155. s.Lock()
  156. defer s.Unlock()
  157. for i, m := range guild.Members {
  158. if m.User.ID == member.User.ID {
  159. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  160. return nil
  161. }
  162. }
  163. return errors.New("member not found")
  164. }
  165. // Member gets a member by ID from a guild.
  166. func (s *State) Member(guildID, userID string) (*Member, error) {
  167. if s == nil {
  168. return nil, ErrNilState
  169. }
  170. guild, err := s.Guild(guildID)
  171. if err != nil {
  172. return nil, err
  173. }
  174. s.RLock()
  175. defer s.RUnlock()
  176. for _, m := range guild.Members {
  177. if m.User.ID == userID {
  178. return m, nil
  179. }
  180. }
  181. return nil, errors.New("member not found")
  182. }
  183. // RoleAdd adds a role to the current world state, or
  184. // updates it if it already exists.
  185. func (s *State) RoleAdd(guildID string, role *Role) error {
  186. if s == nil {
  187. return ErrNilState
  188. }
  189. guild, err := s.Guild(guildID)
  190. if err != nil {
  191. return err
  192. }
  193. s.Lock()
  194. defer s.Unlock()
  195. for i, r := range guild.Roles {
  196. if r.ID == role.ID {
  197. guild.Roles[i] = role
  198. return nil
  199. }
  200. }
  201. guild.Roles = append(guild.Roles, role)
  202. return nil
  203. }
  204. // RoleRemove removes a role from current world state by ID.
  205. func (s *State) RoleRemove(guildID, roleID string) error {
  206. if s == nil {
  207. return ErrNilState
  208. }
  209. guild, err := s.Guild(guildID)
  210. if err != nil {
  211. return err
  212. }
  213. s.Lock()
  214. defer s.Unlock()
  215. for i, r := range guild.Roles {
  216. if r.ID == roleID {
  217. guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
  218. return nil
  219. }
  220. }
  221. return errors.New("role not found")
  222. }
  223. // Role gets a role by ID from a guild.
  224. func (s *State) Role(guildID, roleID string) (*Role, error) {
  225. if s == nil {
  226. return nil, ErrNilState
  227. }
  228. guild, err := s.Guild(guildID)
  229. if err != nil {
  230. return nil, err
  231. }
  232. s.RLock()
  233. defer s.RUnlock()
  234. for _, r := range guild.Roles {
  235. if r.ID == roleID {
  236. return r, nil
  237. }
  238. }
  239. return nil, errors.New("role not found")
  240. }
  241. // ChannelAdd adds a channel to the current world state, or
  242. // updates it if it already exists.
  243. // Channels may exist either as PrivateChannels or inside
  244. // a guild.
  245. func (s *State) ChannelAdd(channel *Channel) error {
  246. if s == nil {
  247. return ErrNilState
  248. }
  249. s.Lock()
  250. defer s.Unlock()
  251. // If the channel exists, replace it
  252. if c, ok := s.channelMap[channel.ID]; ok {
  253. if channel.Messages == nil {
  254. channel.Messages = c.Messages
  255. }
  256. if channel.PermissionOverwrites == nil {
  257. channel.PermissionOverwrites = c.PermissionOverwrites
  258. }
  259. *c = *channel
  260. return nil
  261. }
  262. if channel.IsPrivate {
  263. s.PrivateChannels = append(s.PrivateChannels, channel)
  264. } else {
  265. guild, ok := s.guildMap[channel.GuildID]
  266. if !ok {
  267. return errors.New("guild for channel not found")
  268. }
  269. guild.Channels = append(guild.Channels, channel)
  270. }
  271. s.channelMap[channel.ID] = channel
  272. return nil
  273. }
  274. // ChannelRemove removes a channel from current world state.
  275. func (s *State) ChannelRemove(channel *Channel) error {
  276. if s == nil {
  277. return ErrNilState
  278. }
  279. _, err := s.Channel(channel.ID)
  280. if err != nil {
  281. return err
  282. }
  283. if channel.IsPrivate {
  284. s.Lock()
  285. defer s.Unlock()
  286. for i, c := range s.PrivateChannels {
  287. if c.ID == channel.ID {
  288. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  289. break
  290. }
  291. }
  292. } else {
  293. guild, err := s.Guild(channel.GuildID)
  294. if err != nil {
  295. return err
  296. }
  297. s.Lock()
  298. defer s.Unlock()
  299. for i, c := range guild.Channels {
  300. if c.ID == channel.ID {
  301. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  302. break
  303. }
  304. }
  305. }
  306. delete(s.channelMap, channel.ID)
  307. return nil
  308. }
  309. // GuildChannel gets a channel by ID from a guild.
  310. // This method is Deprecated, use Channel(channelID)
  311. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  312. return s.Channel(channelID)
  313. }
  314. // PrivateChannel gets a private channel by ID.
  315. // This method is Deprecated, use Channel(channelID)
  316. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  317. return s.Channel(channelID)
  318. }
  319. // Channel gets a channel by ID, it will look in all guilds an private channels.
  320. func (s *State) Channel(channelID string) (*Channel, error) {
  321. if s == nil {
  322. return nil, ErrNilState
  323. }
  324. s.RLock()
  325. defer s.RUnlock()
  326. if c, ok := s.channelMap[channelID]; ok {
  327. return c, nil
  328. }
  329. return nil, errors.New("channel not found")
  330. }
  331. // Emoji returns an emoji for a guild and emoji id.
  332. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  333. if s == nil {
  334. return nil, ErrNilState
  335. }
  336. guild, err := s.Guild(guildID)
  337. if err != nil {
  338. return nil, err
  339. }
  340. s.RLock()
  341. defer s.RUnlock()
  342. for _, e := range guild.Emojis {
  343. if e.ID == emojiID {
  344. return e, nil
  345. }
  346. }
  347. return nil, errors.New("emoji not found")
  348. }
  349. // EmojiAdd adds an emoji to the current world state.
  350. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  351. if s == nil {
  352. return ErrNilState
  353. }
  354. guild, err := s.Guild(guildID)
  355. if err != nil {
  356. return err
  357. }
  358. s.Lock()
  359. defer s.Unlock()
  360. for i, e := range guild.Emojis {
  361. if e.ID == emoji.ID {
  362. guild.Emojis[i] = emoji
  363. return nil
  364. }
  365. }
  366. guild.Emojis = append(guild.Emojis, emoji)
  367. return nil
  368. }
  369. // EmojisAdd adds multiple emojis to the world state.
  370. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  371. for _, e := range emojis {
  372. if err := s.EmojiAdd(guildID, e); err != nil {
  373. return err
  374. }
  375. }
  376. return nil
  377. }
  378. // MessageAdd adds a message to the current world state, or updates it if it exists.
  379. // If the channel cannot be found, the message is discarded.
  380. // Messages are kept in state up to s.MaxMessageCount
  381. func (s *State) MessageAdd(message *Message) error {
  382. if s == nil {
  383. return ErrNilState
  384. }
  385. c, err := s.Channel(message.ChannelID)
  386. if err != nil {
  387. return err
  388. }
  389. s.Lock()
  390. defer s.Unlock()
  391. // If the message exists, merge in the new message contents.
  392. for _, m := range c.Messages {
  393. if m.ID == message.ID {
  394. if message.Content != "" {
  395. m.Content = message.Content
  396. }
  397. if message.EditedTimestamp != "" {
  398. m.EditedTimestamp = message.EditedTimestamp
  399. }
  400. if message.Mentions != nil {
  401. m.Mentions = message.Mentions
  402. }
  403. if message.Embeds != nil {
  404. m.Embeds = message.Embeds
  405. }
  406. if message.Attachments != nil {
  407. m.Attachments = message.Attachments
  408. }
  409. if message.Timestamp != "" {
  410. m.Timestamp = message.Timestamp
  411. }
  412. if message.Author != nil {
  413. m.Author = message.Author
  414. }
  415. return nil
  416. }
  417. }
  418. c.Messages = append(c.Messages, message)
  419. if len(c.Messages) > s.MaxMessageCount {
  420. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  421. }
  422. return nil
  423. }
  424. // MessageRemove removes a message from the world state.
  425. func (s *State) MessageRemove(message *Message) error {
  426. if s == nil {
  427. return ErrNilState
  428. }
  429. return s.messageRemoveByID(message.ChannelID, message.ID)
  430. }
  431. // messageRemoveByID removes a message by channelID and messageID from the world state.
  432. func (s *State) messageRemoveByID(channelID, messageID string) error {
  433. c, err := s.Channel(channelID)
  434. if err != nil {
  435. return err
  436. }
  437. s.Lock()
  438. defer s.Unlock()
  439. for i, m := range c.Messages {
  440. if m.ID == messageID {
  441. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  442. return nil
  443. }
  444. }
  445. return errors.New("message not found")
  446. }
  447. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  448. guild, err := s.Guild(update.GuildID)
  449. if err != nil {
  450. return err
  451. }
  452. s.Lock()
  453. defer s.Unlock()
  454. // Handle Leaving Channel
  455. if update.ChannelID == "" {
  456. for i, state := range guild.VoiceStates {
  457. if state.UserID == update.UserID {
  458. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  459. return nil
  460. }
  461. }
  462. } else {
  463. for i, state := range guild.VoiceStates {
  464. if state.UserID == update.UserID {
  465. guild.VoiceStates[i] = update.VoiceState
  466. return nil
  467. }
  468. }
  469. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  470. }
  471. return nil
  472. }
  473. // Message gets a message by channel and message ID.
  474. func (s *State) Message(channelID, messageID string) (*Message, error) {
  475. if s == nil {
  476. return nil, ErrNilState
  477. }
  478. c, err := s.Channel(channelID)
  479. if err != nil {
  480. return nil, err
  481. }
  482. s.RLock()
  483. defer s.RUnlock()
  484. for _, m := range c.Messages {
  485. if m.ID == messageID {
  486. return m, nil
  487. }
  488. }
  489. return nil, errors.New("message not found")
  490. }
  491. // OnReady takes a Ready event and updates all internal state.
  492. func (s *State) onReady(se *Session, r *Ready) (err error) {
  493. if s == nil {
  494. return ErrNilState
  495. }
  496. s.Lock()
  497. defer s.Unlock()
  498. // We must track at least the current user for Voice, even
  499. // if state is disabled, store the bare essentials.
  500. if !se.StateEnabled {
  501. ready := Ready{
  502. Version: r.Version,
  503. SessionID: r.SessionID,
  504. User: r.User,
  505. }
  506. s.Ready = ready
  507. return nil
  508. }
  509. s.Ready = *r
  510. for _, g := range s.Guilds {
  511. s.guildMap[g.ID] = g
  512. for _, c := range g.Channels {
  513. s.channelMap[c.ID] = c
  514. }
  515. }
  516. for _, c := range s.PrivateChannels {
  517. s.channelMap[c.ID] = c
  518. }
  519. return nil
  520. }
  521. // onInterface handles all events related to states.
  522. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  523. if s == nil {
  524. return ErrNilState
  525. }
  526. r, ok := i.(*Ready)
  527. if ok {
  528. return s.onReady(se, r)
  529. }
  530. if !se.StateEnabled {
  531. return nil
  532. }
  533. switch t := i.(type) {
  534. case *GuildCreate:
  535. err = s.GuildAdd(t.Guild)
  536. case *GuildUpdate:
  537. err = s.GuildAdd(t.Guild)
  538. case *GuildDelete:
  539. err = s.GuildRemove(t.Guild)
  540. case *GuildMemberAdd:
  541. if s.TrackMembers {
  542. err = s.MemberAdd(t.Member)
  543. }
  544. case *GuildMemberUpdate:
  545. if s.TrackMembers {
  546. err = s.MemberAdd(t.Member)
  547. }
  548. case *GuildMemberRemove:
  549. if s.TrackMembers {
  550. err = s.MemberRemove(t.Member)
  551. }
  552. case *GuildRoleCreate:
  553. if s.TrackRoles {
  554. err = s.RoleAdd(t.GuildID, t.Role)
  555. }
  556. case *GuildRoleUpdate:
  557. if s.TrackRoles {
  558. err = s.RoleAdd(t.GuildID, t.Role)
  559. }
  560. case *GuildRoleDelete:
  561. if s.TrackRoles {
  562. err = s.RoleRemove(t.GuildID, t.RoleID)
  563. }
  564. case *GuildEmojisUpdate:
  565. if s.TrackEmojis {
  566. err = s.EmojisAdd(t.GuildID, t.Emojis)
  567. }
  568. case *ChannelCreate:
  569. if s.TrackChannels {
  570. err = s.ChannelAdd(t.Channel)
  571. }
  572. case *ChannelUpdate:
  573. if s.TrackChannels {
  574. err = s.ChannelAdd(t.Channel)
  575. }
  576. case *ChannelDelete:
  577. if s.TrackChannels {
  578. err = s.ChannelRemove(t.Channel)
  579. }
  580. case *MessageCreate:
  581. if s.MaxMessageCount != 0 {
  582. err = s.MessageAdd(t.Message)
  583. }
  584. case *MessageUpdate:
  585. if s.MaxMessageCount != 0 {
  586. err = s.MessageAdd(t.Message)
  587. }
  588. case *MessageDelete:
  589. if s.MaxMessageCount != 0 {
  590. err = s.MessageRemove(t.Message)
  591. }
  592. case *MessageDeleteBulk:
  593. if s.MaxMessageCount != 0 {
  594. for _, mID := range t.Messages {
  595. s.messageRemoveByID(t.ChannelID, mID)
  596. }
  597. }
  598. case *VoiceStateUpdate:
  599. if s.TrackVoice {
  600. err = s.voiceStateUpdate(t)
  601. }
  602. }
  603. return
  604. }
  605. // UserChannelPermissions returns the permission of a user in a channel.
  606. // userID : The ID of the user to calculate permissions for.
  607. // channelID : The ID of the channel to calculate permission for.
  608. func (s *State) UserChannelPermissions(userID, channelID string) (apermissions int, err error) {
  609. if s == nil {
  610. return 0, ErrNilState
  611. }
  612. channel, err := s.Channel(channelID)
  613. if err != nil {
  614. return
  615. }
  616. guild, err := s.Guild(channel.GuildID)
  617. if err != nil {
  618. return
  619. }
  620. if userID == guild.OwnerID {
  621. apermissions = PermissionAll
  622. return
  623. }
  624. member, err := s.Member(guild.ID, userID)
  625. if err != nil {
  626. return
  627. }
  628. return memberPermissions(guild, channel, member), nil
  629. }
  630. // UserColor returns the color of a user in a channel.
  631. // While colors are defined at a Guild level, determining for a channel is more useful in message handlers.
  632. // 0 is returned in cases of error, which is the color of @everyone.
  633. // userID : The ID of the user to calculate the color for.
  634. // channelID : The ID of the channel to calculate the color for.
  635. func (s *State) UserColor(userID, channelID string) int {
  636. if s == nil {
  637. return 0
  638. }
  639. channel, err := s.Channel(channelID)
  640. if err != nil {
  641. return 0
  642. }
  643. guild, err := s.Guild(channel.GuildID)
  644. if err != nil {
  645. return 0
  646. }
  647. member, err := s.Member(guild.ID, userID)
  648. if err != nil {
  649. return 0
  650. }
  651. roles := Roles(guild.Roles)
  652. sort.Sort(roles)
  653. for _, role := range roles {
  654. for _, roleID := range member.Roles {
  655. if role.ID == roleID {
  656. if role.Color != 0 {
  657. return role.Color
  658. }
  659. }
  660. }
  661. }
  662. return 0
  663. }