state.go 16 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sort"
  14. "sync"
  15. )
  // ErrNilState is the error handed back by State methods that are invoked
  // on a nil *State receiver.
  var ErrNilState = errors.New(
  	"State not instantiated, please use discordgo.New() or assign Session.State.",
  )
  18. // A State contains the current known state.
  19. // As discord sends this in a READY blob, it seems reasonable to simply
  20. // use that struct as the data store.
  21. type State struct {
  22. sync.RWMutex
  23. Ready
  24. MaxMessageCount int
  25. TrackChannels bool
  26. TrackEmojis bool
  27. TrackMembers bool
  28. TrackRoles bool
  29. TrackVoice bool
  30. guildMap map[string]*Guild
  31. channelMap map[string]*Channel
  32. }
  33. // NewState creates an empty state.
  34. func NewState() *State {
  35. return &State{
  36. Ready: Ready{
  37. PrivateChannels: []*Channel{},
  38. Guilds: []*Guild{},
  39. },
  40. TrackChannels: true,
  41. TrackEmojis: true,
  42. TrackMembers: true,
  43. TrackRoles: true,
  44. TrackVoice: true,
  45. guildMap: make(map[string]*Guild),
  46. channelMap: make(map[string]*Channel),
  47. }
  48. }
  49. // GuildAdd adds a guild to the current world state, or
  50. // updates it if it already exists.
  51. func (s *State) GuildAdd(guild *Guild) error {
  52. if s == nil {
  53. return ErrNilState
  54. }
  55. s.Lock()
  56. defer s.Unlock()
  57. // Update the channels to point to the right guild, adding them to the channelMap as we go
  58. for _, c := range guild.Channels {
  59. s.channelMap[c.ID] = c
  60. }
  61. if g, ok := s.guildMap[guild.ID]; ok {
  62. // We are about to replace `g` in the state with `guild`, but first we need to
  63. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
  64. if guild.Roles == nil {
  65. guild.Roles = g.Roles
  66. }
  67. if guild.Emojis == nil {
  68. guild.Emojis = g.Emojis
  69. }
  70. if guild.Members == nil {
  71. guild.Members = g.Members
  72. }
  73. if guild.Presences == nil {
  74. guild.Presences = g.Presences
  75. }
  76. if guild.Channels == nil {
  77. guild.Channels = g.Channels
  78. }
  79. if guild.VoiceStates == nil {
  80. guild.VoiceStates = g.VoiceStates
  81. }
  82. *g = *guild
  83. return nil
  84. }
  85. s.Guilds = append(s.Guilds, guild)
  86. s.guildMap[guild.ID] = guild
  87. return nil
  88. }
  89. // GuildRemove removes a guild from current world state.
  90. func (s *State) GuildRemove(guild *Guild) error {
  91. if s == nil {
  92. return ErrNilState
  93. }
  94. _, err := s.Guild(guild.ID)
  95. if err != nil {
  96. return err
  97. }
  98. s.Lock()
  99. defer s.Unlock()
  100. delete(s.guildMap, guild.ID)
  101. for i, g := range s.Guilds {
  102. if g.ID == guild.ID {
  103. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  104. return nil
  105. }
  106. }
  107. return nil
  108. }
  109. // Guild gets a guild by ID.
  110. // Useful for querying if @me is in a guild:
  111. // _, err := discordgo.Session.State.Guild(guildID)
  112. // isInGuild := err == nil
  113. func (s *State) Guild(guildID string) (*Guild, error) {
  114. if s == nil {
  115. return nil, ErrNilState
  116. }
  117. s.RLock()
  118. defer s.RUnlock()
  119. if g, ok := s.guildMap[guildID]; ok {
  120. return g, nil
  121. }
  122. return nil, errors.New("Guild not found.")
  123. }
  124. // TODO: Consider moving Guild state update methods onto *Guild.
  125. // MemberAdd adds a member to the current world state, or
  126. // updates it if it already exists.
  127. func (s *State) MemberAdd(member *Member) error {
  128. if s == nil {
  129. return ErrNilState
  130. }
  131. guild, err := s.Guild(member.GuildID)
  132. if err != nil {
  133. return err
  134. }
  135. s.Lock()
  136. defer s.Unlock()
  137. for i, m := range guild.Members {
  138. if m.User.ID == member.User.ID {
  139. guild.Members[i] = member
  140. return nil
  141. }
  142. }
  143. guild.Members = append(guild.Members, member)
  144. return nil
  145. }
  146. // MemberRemove removes a member from current world state.
  147. func (s *State) MemberRemove(member *Member) error {
  148. if s == nil {
  149. return ErrNilState
  150. }
  151. guild, err := s.Guild(member.GuildID)
  152. if err != nil {
  153. return err
  154. }
  155. s.Lock()
  156. defer s.Unlock()
  157. for i, m := range guild.Members {
  158. if m.User.ID == member.User.ID {
  159. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  160. return nil
  161. }
  162. }
  163. return errors.New("Member not found.")
  164. }
  165. // Member gets a member by ID from a guild.
  166. func (s *State) Member(guildID, userID string) (*Member, error) {
  167. if s == nil {
  168. return nil, ErrNilState
  169. }
  170. guild, err := s.Guild(guildID)
  171. if err != nil {
  172. return nil, err
  173. }
  174. s.RLock()
  175. defer s.RUnlock()
  176. for _, m := range guild.Members {
  177. if m.User.ID == userID {
  178. return m, nil
  179. }
  180. }
  181. return nil, errors.New("Member not found.")
  182. }
  183. // RoleAdd adds a role to the current world state, or
  184. // updates it if it already exists.
  185. func (s *State) RoleAdd(guildID string, role *Role) error {
  186. if s == nil {
  187. return ErrNilState
  188. }
  189. guild, err := s.Guild(guildID)
  190. if err != nil {
  191. return err
  192. }
  193. s.Lock()
  194. defer s.Unlock()
  195. for i, r := range guild.Roles {
  196. if r.ID == role.ID {
  197. guild.Roles[i] = role
  198. return nil
  199. }
  200. }
  201. guild.Roles = append(guild.Roles, role)
  202. return nil
  203. }
  204. // RoleRemove removes a role from current world state by ID.
  205. func (s *State) RoleRemove(guildID, roleID string) error {
  206. if s == nil {
  207. return ErrNilState
  208. }
  209. guild, err := s.Guild(guildID)
  210. if err != nil {
  211. return err
  212. }
  213. s.Lock()
  214. defer s.Unlock()
  215. for i, r := range guild.Roles {
  216. if r.ID == roleID {
  217. guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
  218. return nil
  219. }
  220. }
  221. return errors.New("Role not found.")
  222. }
  223. // Role gets a role by ID from a guild.
  224. func (s *State) Role(guildID, roleID string) (*Role, error) {
  225. if s == nil {
  226. return nil, ErrNilState
  227. }
  228. guild, err := s.Guild(guildID)
  229. if err != nil {
  230. return nil, err
  231. }
  232. s.RLock()
  233. defer s.RUnlock()
  234. for _, r := range guild.Roles {
  235. if r.ID == roleID {
  236. return r, nil
  237. }
  238. }
  239. return nil, errors.New("Role not found.")
  240. }
  241. // ChannelAdd adds a guild to the current world state, or
  242. // updates it if it already exists.
  243. // Channels may exist either as PrivateChannels or inside
  244. // a guild.
  245. func (s *State) ChannelAdd(channel *Channel) error {
  246. if s == nil {
  247. return ErrNilState
  248. }
  249. s.Lock()
  250. defer s.Unlock()
  251. // If the channel exists, replace it
  252. if c, ok := s.channelMap[channel.ID]; ok {
  253. if channel.Messages == nil {
  254. channel.Messages = c.Messages
  255. }
  256. if channel.PermissionOverwrites == nil {
  257. channel.PermissionOverwrites = c.PermissionOverwrites
  258. }
  259. *c = *channel
  260. return nil
  261. }
  262. if channel.IsPrivate {
  263. s.PrivateChannels = append(s.PrivateChannels, channel)
  264. } else {
  265. guild, ok := s.guildMap[channel.GuildID]
  266. if !ok {
  267. return errors.New("Guild for channel not found.")
  268. }
  269. guild.Channels = append(guild.Channels, channel)
  270. }
  271. s.channelMap[channel.ID] = channel
  272. return nil
  273. }
  274. // ChannelRemove removes a channel from current world state.
  275. func (s *State) ChannelRemove(channel *Channel) error {
  276. if s == nil {
  277. return ErrNilState
  278. }
  279. _, err := s.Channel(channel.ID)
  280. if err != nil {
  281. return err
  282. }
  283. if channel.IsPrivate {
  284. s.Lock()
  285. defer s.Unlock()
  286. for i, c := range s.PrivateChannels {
  287. if c.ID == channel.ID {
  288. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  289. break
  290. }
  291. }
  292. } else {
  293. guild, err := s.Guild(channel.GuildID)
  294. if err != nil {
  295. return err
  296. }
  297. s.Lock()
  298. defer s.Unlock()
  299. for i, c := range guild.Channels {
  300. if c.ID == channel.ID {
  301. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  302. break
  303. }
  304. }
  305. }
  306. delete(s.channelMap, channel.ID)
  307. return nil
  308. }
  309. // GuildChannel gets a channel by ID from a guild.
  310. // This method is Deprecated, use Channel(channelID)
  311. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  312. return s.Channel(channelID)
  313. }
  314. // PrivateChannel gets a private channel by ID.
  315. // This method is Deprecated, use Channel(channelID)
  316. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  317. return s.Channel(channelID)
  318. }
  319. // Channel gets a channel by ID, it will look in all guilds an private channels.
  320. func (s *State) Channel(channelID string) (*Channel, error) {
  321. if s == nil {
  322. return nil, ErrNilState
  323. }
  324. s.RLock()
  325. defer s.RUnlock()
  326. if c, ok := s.channelMap[channelID]; ok {
  327. return c, nil
  328. }
  329. return nil, errors.New("Channel not found.")
  330. }
  331. // Emoji returns an emoji for a guild and emoji id.
  332. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  333. if s == nil {
  334. return nil, ErrNilState
  335. }
  336. guild, err := s.Guild(guildID)
  337. if err != nil {
  338. return nil, err
  339. }
  340. s.RLock()
  341. defer s.RUnlock()
  342. for _, e := range guild.Emojis {
  343. if e.ID == emojiID {
  344. return e, nil
  345. }
  346. }
  347. return nil, errors.New("Emoji not found.")
  348. }
  349. // EmojiAdd adds an emoji to the current world state.
  350. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  351. if s == nil {
  352. return ErrNilState
  353. }
  354. guild, err := s.Guild(guildID)
  355. if err != nil {
  356. return err
  357. }
  358. s.Lock()
  359. defer s.Unlock()
  360. for i, e := range guild.Emojis {
  361. if e.ID == emoji.ID {
  362. guild.Emojis[i] = emoji
  363. return nil
  364. }
  365. }
  366. guild.Emojis = append(guild.Emojis, emoji)
  367. return nil
  368. }
  369. // EmojisAdd adds multiple emojis to the world state.
  370. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  371. for _, e := range emojis {
  372. if err := s.EmojiAdd(guildID, e); err != nil {
  373. return err
  374. }
  375. }
  376. return nil
  377. }
  378. // MessageAdd adds a message to the current world state, or updates it if it exists.
  379. // If the channel cannot be found, the message is discarded.
  380. // Messages are kept in state up to s.MaxMessageCount
  381. func (s *State) MessageAdd(message *Message) error {
  382. if s == nil {
  383. return ErrNilState
  384. }
  385. c, err := s.Channel(message.ChannelID)
  386. if err != nil {
  387. return err
  388. }
  389. s.Lock()
  390. defer s.Unlock()
  391. // If the message exists, merge in the new message contents.
  392. for _, m := range c.Messages {
  393. if m.ID == message.ID {
  394. if message.Content != "" {
  395. m.Content = message.Content
  396. }
  397. if message.EditedTimestamp != "" {
  398. m.EditedTimestamp = message.EditedTimestamp
  399. }
  400. if message.Mentions != nil {
  401. m.Mentions = message.Mentions
  402. }
  403. if message.Embeds != nil {
  404. m.Embeds = message.Embeds
  405. }
  406. if message.Attachments != nil {
  407. m.Attachments = message.Attachments
  408. }
  409. if message.Timestamp != "" {
  410. m.Timestamp = message.Timestamp
  411. }
  412. if message.Author != nil {
  413. m.Author = message.Author
  414. }
  415. return nil
  416. }
  417. }
  418. c.Messages = append(c.Messages, message)
  419. if len(c.Messages) > s.MaxMessageCount {
  420. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  421. }
  422. return nil
  423. }
  424. // MessageRemove removes a message from the world state.
  425. func (s *State) MessageRemove(message *Message) error {
  426. if s == nil {
  427. return ErrNilState
  428. }
  429. c, err := s.Channel(message.ChannelID)
  430. if err != nil {
  431. return err
  432. }
  433. s.Lock()
  434. defer s.Unlock()
  435. for i, m := range c.Messages {
  436. if m.ID == message.ID {
  437. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  438. return nil
  439. }
  440. }
  441. return errors.New("Message not found.")
  442. }
  443. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  444. guild, err := s.Guild(update.GuildID)
  445. if err != nil {
  446. return err
  447. }
  448. s.Lock()
  449. defer s.Unlock()
  450. // Handle Leaving Channel
  451. if update.ChannelID == "" {
  452. for i, state := range guild.VoiceStates {
  453. if state.UserID == update.UserID {
  454. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  455. return nil
  456. }
  457. }
  458. } else {
  459. for i, state := range guild.VoiceStates {
  460. if state.UserID == update.UserID {
  461. guild.VoiceStates[i] = update.VoiceState
  462. return nil
  463. }
  464. }
  465. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  466. }
  467. return nil
  468. }
  469. // Message gets a message by channel and message ID.
  470. func (s *State) Message(channelID, messageID string) (*Message, error) {
  471. if s == nil {
  472. return nil, ErrNilState
  473. }
  474. c, err := s.Channel(channelID)
  475. if err != nil {
  476. return nil, err
  477. }
  478. s.RLock()
  479. defer s.RUnlock()
  480. for _, m := range c.Messages {
  481. if m.ID == messageID {
  482. return m, nil
  483. }
  484. }
  485. return nil, errors.New("Message not found.")
  486. }
  487. // OnReady takes a Ready event and updates all internal state.
  488. func (s *State) onReady(se *Session, r *Ready) (err error) {
  489. if s == nil {
  490. return ErrNilState
  491. }
  492. s.Lock()
  493. defer s.Unlock()
  494. // We must track at least the current user for Voice, even
  495. // if state is disabled, store the bare essentials.
  496. if !se.StateEnabled {
  497. ready := Ready{
  498. Version: r.Version,
  499. SessionID: r.SessionID,
  500. HeartbeatInterval: r.HeartbeatInterval,
  501. User: r.User,
  502. }
  503. s.Ready = ready
  504. return nil
  505. }
  506. s.Ready = *r
  507. for _, g := range s.Guilds {
  508. s.guildMap[g.ID] = g
  509. for _, c := range g.Channels {
  510. s.channelMap[c.ID] = c
  511. }
  512. }
  513. for _, c := range s.PrivateChannels {
  514. s.channelMap[c.ID] = c
  515. }
  516. return nil
  517. }
  518. // onInterface handles all events related to states.
  519. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  520. if s == nil {
  521. return ErrNilState
  522. }
  523. r, ok := i.(*Ready)
  524. if ok {
  525. return s.onReady(se, r)
  526. }
  527. if !se.StateEnabled {
  528. return nil
  529. }
  530. switch t := i.(type) {
  531. case *GuildCreate:
  532. err = s.GuildAdd(t.Guild)
  533. case *GuildUpdate:
  534. err = s.GuildAdd(t.Guild)
  535. case *GuildDelete:
  536. err = s.GuildRemove(t.Guild)
  537. case *GuildMemberAdd:
  538. if s.TrackMembers {
  539. err = s.MemberAdd(t.Member)
  540. }
  541. case *GuildMemberUpdate:
  542. if s.TrackMembers {
  543. err = s.MemberAdd(t.Member)
  544. }
  545. case *GuildMemberRemove:
  546. if s.TrackMembers {
  547. err = s.MemberRemove(t.Member)
  548. }
  549. case *GuildRoleCreate:
  550. if s.TrackRoles {
  551. err = s.RoleAdd(t.GuildID, t.Role)
  552. }
  553. case *GuildRoleUpdate:
  554. if s.TrackRoles {
  555. err = s.RoleAdd(t.GuildID, t.Role)
  556. }
  557. case *GuildRoleDelete:
  558. if s.TrackRoles {
  559. err = s.RoleRemove(t.GuildID, t.RoleID)
  560. }
  561. case *GuildEmojisUpdate:
  562. if s.TrackEmojis {
  563. err = s.EmojisAdd(t.GuildID, t.Emojis)
  564. }
  565. case *ChannelCreate:
  566. if s.TrackChannels {
  567. err = s.ChannelAdd(t.Channel)
  568. }
  569. case *ChannelUpdate:
  570. if s.TrackChannels {
  571. err = s.ChannelAdd(t.Channel)
  572. }
  573. case *ChannelDelete:
  574. if s.TrackChannels {
  575. err = s.ChannelRemove(t.Channel)
  576. }
  577. case *MessageCreate:
  578. if s.MaxMessageCount != 0 {
  579. err = s.MessageAdd(t.Message)
  580. }
  581. case *MessageUpdate:
  582. if s.MaxMessageCount != 0 {
  583. err = s.MessageAdd(t.Message)
  584. }
  585. case *MessageDelete:
  586. if s.MaxMessageCount != 0 {
  587. err = s.MessageRemove(t.Message)
  588. }
  589. case *VoiceStateUpdate:
  590. if s.TrackVoice {
  591. err = s.voiceStateUpdate(t)
  592. }
  593. }
  594. return
  595. }
  596. // UserChannelPermissions returns the permission of a user in a channel.
  597. // userID : The ID of the user to calculate permissions for.
  598. // channelID : The ID of the channel to calculate permission for.
  599. func (s *State) UserChannelPermissions(userID, channelID string) (apermissions int, err error) {
  600. if s == nil {
  601. return 0, ErrNilState
  602. }
  603. channel, err := s.Channel(channelID)
  604. if err != nil {
  605. return
  606. }
  607. guild, err := s.Guild(channel.GuildID)
  608. if err != nil {
  609. return
  610. }
  611. if userID == guild.OwnerID {
  612. apermissions = PermissionAll
  613. return
  614. }
  615. member, err := s.Member(guild.ID, userID)
  616. if err != nil {
  617. return
  618. }
  619. return memberPermissions(guild, channel, member), nil
  620. }
  621. // UserColor returns the color of a user in a channel.
  622. // While colors are defined at a Guild level, determining for a channel is more useful in message handlers.
  623. // 0 is returned in cases of error, which is the color of @everyone.
  624. // userID : The ID of the user to calculate the color for.
  625. // channelID : The ID of the channel to calculate the color for.
  626. func (s *State) UserColor(userID, channelID string) int {
  627. if s == nil {
  628. return 0
  629. }
  630. channel, err := s.Channel(channelID)
  631. if err != nil {
  632. return 0
  633. }
  634. guild, err := s.Guild(channel.GuildID)
  635. if err != nil {
  636. return 0
  637. }
  638. member, err := s.Member(guild.ID, userID)
  639. if err != nil {
  640. return 0
  641. }
  642. roles := Roles(guild.Roles)
  643. sort.Sort(roles)
  644. for _, role := range roles {
  645. for _, roleID := range member.Roles {
  646. if role.ID == roleID {
  647. if role.Color != 0 {
  648. return role.Color
  649. }
  650. }
  651. }
  652. }
  653. return 0
  654. }