state.go 16 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
// This file contains code related to state tracking. If enabled, state
// tracking will capture the initial READY packet and many other websocket
// events and maintain an in-memory state of guilds, channels, users, and
// so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sync"
  14. )
  15. // ErrNilState is returned when the state is nil.
  16. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
// A State contains the current known state.
// As discord sends this in a READY blob, it seems reasonable to simply
// use that struct as the data store.
//
// The embedded RWMutex is taken by the State methods themselves (Lock for
// mutations, RLock for lookups) around access to the embedded Ready data
// and the lookup maps below.
type State struct {
	sync.RWMutex
	Ready

	// MaxMessageCount bounds how many messages MessageAdd keeps per channel.
	// onInterface does not track messages at all while it is 0.
	MaxMessageCount int

	// Each Track* flag gates the corresponding group of websocket events in
	// onInterface. NewState enables all of them.
	TrackChannels bool
	TrackEmojis   bool
	TrackMembers  bool
	TrackRoles    bool
	TrackVoice    bool

	// guildMap indexes the guilds in Guilds by ID; channelMap indexes every
	// known channel (guild channels and private channels) by ID.
	guildMap   map[string]*Guild
	channelMap map[string]*Channel
}
  32. // NewState creates an empty state.
  33. func NewState() *State {
  34. return &State{
  35. Ready: Ready{
  36. PrivateChannels: []*Channel{},
  37. Guilds: []*Guild{},
  38. },
  39. TrackChannels: true,
  40. TrackEmojis: true,
  41. TrackMembers: true,
  42. TrackRoles: true,
  43. TrackVoice: true,
  44. guildMap: make(map[string]*Guild),
  45. channelMap: make(map[string]*Channel),
  46. }
  47. }
  48. // GuildAdd adds a guild to the current world state, or
  49. // updates it if it already exists.
  50. func (s *State) GuildAdd(guild *Guild) error {
  51. if s == nil {
  52. return ErrNilState
  53. }
  54. s.Lock()
  55. defer s.Unlock()
  56. // Update the channels to point to the right guild, adding them to the channelMap as we go
  57. for _, c := range guild.Channels {
  58. c.GuildID = guild.ID
  59. s.channelMap[c.ID] = c
  60. }
  61. if g, ok := s.guildMap[guild.ID]; ok {
  62. // We are about to replace `g` in the state with `guild`, but first we need to
  63. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
  64. if guild.Roles == nil {
  65. guild.Roles = g.Roles
  66. }
  67. if guild.Emojis == nil {
  68. guild.Emojis = g.Emojis
  69. }
  70. if guild.Members == nil {
  71. for _, m := range g.Members {
  72. m.GuildID = guild.ID
  73. }
  74. guild.Members = g.Members
  75. }
  76. if guild.Presences == nil {
  77. guild.Presences = g.Presences
  78. }
  79. if guild.Channels == nil {
  80. for _, c := range g.Channels {
  81. c.GuildID = guild.ID
  82. }
  83. guild.Channels = g.Channels
  84. }
  85. if guild.VoiceStates == nil {
  86. for _, g := range g.VoiceStates {
  87. g.GuildID = guild.ID
  88. }
  89. guild.VoiceStates = g.VoiceStates
  90. }
  91. *g = *guild
  92. return nil
  93. }
  94. s.Guilds = append(s.Guilds, guild)
  95. s.guildMap[guild.ID] = guild
  96. return nil
  97. }
  98. // GuildRemove removes a guild from current world state.
  99. func (s *State) GuildRemove(guild *Guild) error {
  100. if s == nil {
  101. return ErrNilState
  102. }
  103. _, err := s.Guild(guild.ID)
  104. if err != nil {
  105. return err
  106. }
  107. s.Lock()
  108. defer s.Unlock()
  109. delete(s.guildMap, guild.ID)
  110. for i, g := range s.Guilds {
  111. if g.ID == guild.ID {
  112. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  113. return nil
  114. }
  115. }
  116. return nil
  117. }
  118. // Guild gets a guild by ID.
  119. // Useful for querying if @me is in a guild:
  120. // _, err := discordgo.Session.State.Guild(guildID)
  121. // isInGuild := err == nil
  122. func (s *State) Guild(guildID string) (*Guild, error) {
  123. if s == nil {
  124. return nil, ErrNilState
  125. }
  126. s.RLock()
  127. defer s.RUnlock()
  128. if g, ok := s.guildMap[guildID]; ok {
  129. return g, nil
  130. }
  131. return nil, errors.New("Guild not found.")
  132. }
  133. // TODO: Consider moving Guild state update methods onto *Guild.
  134. // MemberAdd adds a member to the current world state, or
  135. // updates it if it already exists.
  136. func (s *State) MemberAdd(member *Member) error {
  137. if s == nil {
  138. return ErrNilState
  139. }
  140. guild, err := s.Guild(member.GuildID)
  141. if err != nil {
  142. return err
  143. }
  144. s.Lock()
  145. defer s.Unlock()
  146. for i, m := range guild.Members {
  147. if m.User.ID == member.User.ID {
  148. guild.Members[i] = member
  149. return nil
  150. }
  151. }
  152. guild.Members = append(guild.Members, member)
  153. return nil
  154. }
  155. // MemberRemove removes a member from current world state.
  156. func (s *State) MemberRemove(member *Member) error {
  157. if s == nil {
  158. return ErrNilState
  159. }
  160. guild, err := s.Guild(member.GuildID)
  161. if err != nil {
  162. return err
  163. }
  164. s.Lock()
  165. defer s.Unlock()
  166. for i, m := range guild.Members {
  167. if m.User.ID == member.User.ID {
  168. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  169. return nil
  170. }
  171. }
  172. return errors.New("Member not found.")
  173. }
  174. // Member gets a member by ID from a guild.
  175. func (s *State) Member(guildID, userID string) (*Member, error) {
  176. if s == nil {
  177. return nil, ErrNilState
  178. }
  179. guild, err := s.Guild(guildID)
  180. if err != nil {
  181. return nil, err
  182. }
  183. s.RLock()
  184. defer s.RUnlock()
  185. for _, m := range guild.Members {
  186. if m.User.ID == userID {
  187. return m, nil
  188. }
  189. }
  190. return nil, errors.New("Member not found.")
  191. }
  192. // RoleAdd adds a role to the current world state, or
  193. // updates it if it already exists.
  194. func (s *State) RoleAdd(guildID string, role *Role) error {
  195. if s == nil {
  196. return ErrNilState
  197. }
  198. guild, err := s.Guild(guildID)
  199. if err != nil {
  200. return err
  201. }
  202. s.Lock()
  203. defer s.Unlock()
  204. for i, r := range guild.Roles {
  205. if r.ID == role.ID {
  206. guild.Roles[i] = role
  207. return nil
  208. }
  209. }
  210. guild.Roles = append(guild.Roles, role)
  211. return nil
  212. }
  213. // RoleRemove removes a role from current world state by ID.
  214. func (s *State) RoleRemove(guildID, roleID string) error {
  215. if s == nil {
  216. return ErrNilState
  217. }
  218. guild, err := s.Guild(guildID)
  219. if err != nil {
  220. return err
  221. }
  222. s.Lock()
  223. defer s.Unlock()
  224. for i, r := range guild.Roles {
  225. if r.ID == roleID {
  226. guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
  227. return nil
  228. }
  229. }
  230. return errors.New("Role not found.")
  231. }
  232. // Role gets a role by ID from a guild.
  233. func (s *State) Role(guildID, roleID string) (*Role, error) {
  234. if s == nil {
  235. return nil, ErrNilState
  236. }
  237. guild, err := s.Guild(guildID)
  238. if err != nil {
  239. return nil, err
  240. }
  241. s.RLock()
  242. defer s.RUnlock()
  243. for _, r := range guild.Roles {
  244. if r.ID == roleID {
  245. return r, nil
  246. }
  247. }
  248. return nil, errors.New("Role not found.")
  249. }
  250. // ChannelAdd adds a guild to the current world state, or
  251. // updates it if it already exists.
  252. // Channels may exist either as PrivateChannels or inside
  253. // a guild.
  254. func (s *State) ChannelAdd(channel *Channel) error {
  255. if s == nil {
  256. return ErrNilState
  257. }
  258. s.Lock()
  259. defer s.Unlock()
  260. // If the channel exists, replace it
  261. if c, ok := s.channelMap[channel.ID]; ok {
  262. channel.Messages = c.Messages
  263. channel.PermissionOverwrites = c.PermissionOverwrites
  264. *c = *channel
  265. return nil
  266. }
  267. if channel.IsPrivate {
  268. s.PrivateChannels = append(s.PrivateChannels, channel)
  269. } else {
  270. guild, ok := s.guildMap[channel.GuildID]
  271. if !ok {
  272. return errors.New("Guild for channel not found.")
  273. }
  274. guild.Channels = append(guild.Channels, channel)
  275. }
  276. s.channelMap[channel.ID] = channel
  277. return nil
  278. }
  279. // ChannelRemove removes a channel from current world state.
  280. func (s *State) ChannelRemove(channel *Channel) error {
  281. if s == nil {
  282. return ErrNilState
  283. }
  284. _, err := s.Channel(channel.ID)
  285. if err != nil {
  286. return err
  287. }
  288. if channel.IsPrivate {
  289. s.Lock()
  290. defer s.Unlock()
  291. for i, c := range s.PrivateChannels {
  292. if c.ID == channel.ID {
  293. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  294. break
  295. }
  296. }
  297. } else {
  298. guild, err := s.Guild(channel.GuildID)
  299. if err != nil {
  300. return err
  301. }
  302. s.Lock()
  303. defer s.Unlock()
  304. for i, c := range guild.Channels {
  305. if c.ID == channel.ID {
  306. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  307. break
  308. }
  309. }
  310. }
  311. delete(s.channelMap, channel.ID)
  312. return nil
  313. }
  314. // GuildChannel gets a channel by ID from a guild.
  315. // This method is Deprecated, use Channel(channelID)
  316. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  317. return s.Channel(channelID)
  318. }
  319. // PrivateChannel gets a private channel by ID.
  320. // This method is Deprecated, use Channel(channelID)
  321. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  322. return s.Channel(channelID)
  323. }
  324. // Channel gets a channel by ID, it will look in all guilds an private channels.
  325. func (s *State) Channel(channelID string) (*Channel, error) {
  326. if s == nil {
  327. return nil, ErrNilState
  328. }
  329. s.RLock()
  330. defer s.RUnlock()
  331. if c, ok := s.channelMap[channelID]; ok {
  332. return c, nil
  333. }
  334. return nil, errors.New("Channel not found.")
  335. }
  336. // Emoji returns an emoji for a guild and emoji id.
  337. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  338. if s == nil {
  339. return nil, ErrNilState
  340. }
  341. guild, err := s.Guild(guildID)
  342. if err != nil {
  343. return nil, err
  344. }
  345. s.RLock()
  346. defer s.RUnlock()
  347. for _, e := range guild.Emojis {
  348. if e.ID == emojiID {
  349. return e, nil
  350. }
  351. }
  352. return nil, errors.New("Emoji not found.")
  353. }
  354. // EmojiAdd adds an emoji to the current world state.
  355. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  356. if s == nil {
  357. return ErrNilState
  358. }
  359. guild, err := s.Guild(guildID)
  360. if err != nil {
  361. return err
  362. }
  363. s.Lock()
  364. defer s.Unlock()
  365. for i, e := range guild.Emojis {
  366. if e.ID == emoji.ID {
  367. guild.Emojis[i] = emoji
  368. return nil
  369. }
  370. }
  371. guild.Emojis = append(guild.Emojis, emoji)
  372. return nil
  373. }
  374. // EmojisAdd adds multiple emojis to the world state.
  375. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  376. for _, e := range emojis {
  377. if err := s.EmojiAdd(guildID, e); err != nil {
  378. return err
  379. }
  380. }
  381. return nil
  382. }
  383. // MessageAdd adds a message to the current world state, or updates it if it exists.
  384. // If the channel cannot be found, the message is discarded.
  385. // Messages are kept in state up to s.MaxMessageCount
  386. func (s *State) MessageAdd(message *Message) error {
  387. if s == nil {
  388. return ErrNilState
  389. }
  390. c, err := s.Channel(message.ChannelID)
  391. if err != nil {
  392. return err
  393. }
  394. s.Lock()
  395. defer s.Unlock()
  396. // If the message exists, merge in the new message contents.
  397. for _, m := range c.Messages {
  398. if m.ID == message.ID {
  399. if message.Content != "" {
  400. m.Content = message.Content
  401. }
  402. if message.EditedTimestamp != "" {
  403. m.EditedTimestamp = message.EditedTimestamp
  404. }
  405. if message.Mentions != nil {
  406. m.Mentions = message.Mentions
  407. }
  408. if message.Embeds != nil {
  409. m.Embeds = message.Embeds
  410. }
  411. if message.Attachments != nil {
  412. m.Attachments = message.Attachments
  413. }
  414. if message.Timestamp != "" {
  415. m.Timestamp = message.Timestamp
  416. }
  417. if message.Author != nil {
  418. m.Author = message.Author
  419. }
  420. return nil
  421. }
  422. }
  423. c.Messages = append(c.Messages, message)
  424. if len(c.Messages) > s.MaxMessageCount {
  425. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  426. }
  427. return nil
  428. }
  429. // MessageRemove removes a message from the world state.
  430. func (s *State) MessageRemove(message *Message) error {
  431. if s == nil {
  432. return ErrNilState
  433. }
  434. c, err := s.Channel(message.ChannelID)
  435. if err != nil {
  436. return err
  437. }
  438. s.Lock()
  439. defer s.Unlock()
  440. for i, m := range c.Messages {
  441. if m.ID == message.ID {
  442. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  443. return nil
  444. }
  445. }
  446. return errors.New("Message not found.")
  447. }
  448. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  449. guild, err := s.Guild(update.GuildID)
  450. if err != nil {
  451. return err
  452. }
  453. s.Lock()
  454. defer s.Unlock()
  455. // Handle Leaving Channel
  456. if update.ChannelID == "" {
  457. for i, state := range guild.VoiceStates {
  458. if state.UserID == update.UserID {
  459. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  460. return nil
  461. }
  462. }
  463. } else {
  464. for i, state := range guild.VoiceStates {
  465. if state.UserID == update.UserID {
  466. guild.VoiceStates[i] = update.VoiceState
  467. return nil
  468. }
  469. }
  470. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  471. }
  472. return nil
  473. }
  474. // Message gets a message by channel and message ID.
  475. func (s *State) Message(channelID, messageID string) (*Message, error) {
  476. if s == nil {
  477. return nil, ErrNilState
  478. }
  479. c, err := s.Channel(channelID)
  480. if err != nil {
  481. return nil, err
  482. }
  483. s.RLock()
  484. defer s.RUnlock()
  485. for _, m := range c.Messages {
  486. if m.ID == messageID {
  487. return m, nil
  488. }
  489. }
  490. return nil, errors.New("Message not found.")
  491. }
  492. // OnReady takes a Ready event and updates all internal state.
  493. func (s *State) onReady(se *Session, r *Ready) (err error) {
  494. if s == nil {
  495. return ErrNilState
  496. }
  497. s.Lock()
  498. defer s.Unlock()
  499. // We must track at least the current user for Voice, even
  500. // if state is disabled, store the bare essentials.
  501. if !se.StateEnabled {
  502. ready := Ready{
  503. Version: r.Version,
  504. SessionID: r.SessionID,
  505. HeartbeatInterval: r.HeartbeatInterval,
  506. User: r.User,
  507. }
  508. s.Ready = ready
  509. return nil
  510. }
  511. s.Ready = *r
  512. for _, g := range s.Guilds {
  513. s.guildMap[g.ID] = g
  514. for _, c := range g.Channels {
  515. c.GuildID = g.ID
  516. s.channelMap[c.ID] = c
  517. }
  518. for _, m := range g.Members {
  519. m.GuildID = g.ID
  520. }
  521. for _, vs := range g.VoiceStates {
  522. vs.GuildID = g.ID
  523. }
  524. }
  525. for _, c := range s.PrivateChannels {
  526. s.channelMap[c.ID] = c
  527. }
  528. return nil
  529. }
  530. // onInterface handles all events related to states.
  531. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  532. if s == nil {
  533. return ErrNilState
  534. }
  535. r, ok := i.(*Ready)
  536. if ok {
  537. return s.onReady(se, r)
  538. }
  539. if !se.StateEnabled {
  540. return nil
  541. }
  542. switch t := i.(type) {
  543. case *GuildCreate:
  544. err = s.GuildAdd(t.Guild)
  545. case *GuildUpdate:
  546. err = s.GuildAdd(t.Guild)
  547. case *GuildDelete:
  548. err = s.GuildRemove(t.Guild)
  549. case *GuildMemberAdd:
  550. if s.TrackMembers {
  551. err = s.MemberAdd(t.Member)
  552. }
  553. case *GuildMemberUpdate:
  554. if s.TrackMembers {
  555. err = s.MemberAdd(t.Member)
  556. }
  557. case *GuildMemberRemove:
  558. if s.TrackMembers {
  559. err = s.MemberRemove(t.Member)
  560. }
  561. case *GuildRoleCreate:
  562. if s.TrackRoles {
  563. err = s.RoleAdd(t.GuildID, t.Role)
  564. }
  565. case *GuildRoleUpdate:
  566. if s.TrackRoles {
  567. err = s.RoleAdd(t.GuildID, t.Role)
  568. }
  569. case *GuildRoleDelete:
  570. if s.TrackRoles {
  571. err = s.RoleRemove(t.GuildID, t.RoleID)
  572. }
  573. case *GuildEmojisUpdate:
  574. if s.TrackEmojis {
  575. err = s.EmojisAdd(t.GuildID, t.Emojis)
  576. }
  577. case *ChannelCreate:
  578. if s.TrackChannels {
  579. err = s.ChannelAdd(t.Channel)
  580. }
  581. case *ChannelUpdate:
  582. if s.TrackChannels {
  583. err = s.ChannelAdd(t.Channel)
  584. }
  585. case *ChannelDelete:
  586. if s.TrackChannels {
  587. err = s.ChannelRemove(t.Channel)
  588. }
  589. case *MessageCreate:
  590. if s.MaxMessageCount != 0 {
  591. err = s.MessageAdd(t.Message)
  592. }
  593. case *MessageUpdate:
  594. if s.MaxMessageCount != 0 {
  595. err = s.MessageAdd(t.Message)
  596. }
  597. case *MessageDelete:
  598. if s.MaxMessageCount != 0 {
  599. err = s.MessageRemove(t.Message)
  600. }
  601. case *VoiceStateUpdate:
  602. if s.TrackVoice {
  603. err = s.voiceStateUpdate(t)
  604. }
  605. }
  606. return
  607. }
  608. // UserChannelPermissions returns the permission of a user in a channel.
  609. // userID : The ID of the user to calculate permissions for.
  610. // channelID : The ID of the channel to calculate permission for.
  611. func (s *State) UserChannelPermissions(userID, channelID string) (apermissions int, err error) {
  612. if s == nil {
  613. return 0, ErrNilState
  614. }
  615. channel, err := s.Channel(channelID)
  616. if err != nil {
  617. return
  618. }
  619. guild, err := s.Guild(channel.GuildID)
  620. if err != nil {
  621. return
  622. }
  623. if userID == guild.OwnerID {
  624. apermissions = PermissionAll
  625. return
  626. }
  627. member, err := s.Member(guild.ID, userID)
  628. if err != nil {
  629. return
  630. }
  631. for _, role := range guild.Roles {
  632. if role.ID == guild.ID {
  633. apermissions |= role.Permissions
  634. break
  635. }
  636. }
  637. for _, role := range guild.Roles {
  638. for _, roleID := range member.Roles {
  639. if role.ID == roleID {
  640. apermissions |= role.Permissions
  641. break
  642. }
  643. }
  644. }
  645. if apermissions&PermissionAdministrator > 0 {
  646. apermissions |= PermissionAll
  647. }
  648. // Member overwrites can override role overrides, so do two passes
  649. for _, overwrite := range channel.PermissionOverwrites {
  650. for _, roleID := range member.Roles {
  651. if overwrite.Type == "role" && roleID == overwrite.ID {
  652. apermissions &= ^overwrite.Deny
  653. apermissions |= overwrite.Allow
  654. break
  655. }
  656. }
  657. }
  658. for _, overwrite := range channel.PermissionOverwrites {
  659. if overwrite.Type == "member" && overwrite.ID == userID {
  660. apermissions &= ^overwrite.Deny
  661. apermissions |= overwrite.Allow
  662. break
  663. }
  664. }
  665. if apermissions&PermissionAdministrator > 0 {
  666. apermissions |= PermissionAllChannel
  667. }
  668. return
  669. }