state.go 16 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
// events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sync"
  14. )
  15. // ErrNilState is returned when the state is nil.
  16. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
// A State contains the current known state.
// As discord sends this in a READY blob, it seems reasonable to simply
// use that struct as the data store.
type State struct {
	// The embedded RWMutex is taken by this file's accessor and mutator
	// methods around the maps and slices they touch. Code that reads the
	// exported data directly presumably needs to hold it as well.
	sync.RWMutex

	// Ready is the stored READY payload; its Guilds and PrivateChannels
	// lists hold the tracked guilds and private channels.
	Ready

	// MaxMessageCount is how many messages are kept per channel. When it is
	// 0, incoming message events are not recorded.
	MaxMessageCount int

	// The Track* flags control whether the corresponding channel, emoji,
	// member, role and voice events update the state. Guild events are
	// applied regardless of these flags.
	TrackChannels bool
	TrackEmojis   bool
	TrackMembers  bool
	TrackRoles    bool
	TrackVoice    bool

	// guildMap indexes the tracked guilds by guild ID.
	guildMap map[string]*Guild
	// channelMap indexes tracked guild and private channels by channel ID.
	channelMap map[string]*Channel
}
  32. // NewState creates an empty state.
  33. func NewState() *State {
  34. return &State{
  35. Ready: Ready{
  36. PrivateChannels: []*Channel{},
  37. Guilds: []*Guild{},
  38. },
  39. TrackChannels: true,
  40. TrackEmojis: true,
  41. TrackMembers: true,
  42. TrackRoles: true,
  43. TrackVoice: true,
  44. guildMap: make(map[string]*Guild),
  45. channelMap: make(map[string]*Channel),
  46. }
  47. }
  48. // GuildAdd adds a guild to the current world state, or
  49. // updates it if it already exists.
  50. func (s *State) GuildAdd(guild *Guild) error {
  51. if s == nil {
  52. return ErrNilState
  53. }
  54. s.Lock()
  55. defer s.Unlock()
  56. // Update the channels to point to the right guild, adding them to the channelMap as we go
  57. for _, c := range guild.Channels {
  58. c.GuildID = guild.ID
  59. s.channelMap[c.ID] = c
  60. }
  61. if g, ok := s.guildMap[guild.ID]; ok {
  62. // We are about to replace `g` in the state with `guild`, but first we need to
  63. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
  64. if guild.Roles == nil {
  65. guild.Roles = g.Roles
  66. }
  67. if guild.Emojis == nil {
  68. guild.Emojis = g.Emojis
  69. }
  70. if guild.Members == nil {
  71. guild.Members = g.Members
  72. }
  73. if guild.Presences == nil {
  74. guild.Presences = g.Presences
  75. }
  76. if guild.Channels == nil {
  77. guild.Channels = g.Channels
  78. }
  79. if guild.VoiceStates == nil {
  80. guild.VoiceStates = g.VoiceStates
  81. }
  82. *g = *guild
  83. return nil
  84. }
  85. s.Guilds = append(s.Guilds, guild)
  86. s.guildMap[guild.ID] = guild
  87. return nil
  88. }
  89. // GuildRemove removes a guild from current world state.
  90. func (s *State) GuildRemove(guild *Guild) error {
  91. if s == nil {
  92. return ErrNilState
  93. }
  94. _, err := s.Guild(guild.ID)
  95. if err != nil {
  96. return err
  97. }
  98. s.Lock()
  99. defer s.Unlock()
  100. delete(s.guildMap, guild.ID)
  101. for i, g := range s.Guilds {
  102. if g.ID == guild.ID {
  103. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  104. return nil
  105. }
  106. }
  107. return nil
  108. }
  109. // Guild gets a guild by ID.
  110. // Useful for querying if @me is in a guild:
  111. // _, err := discordgo.Session.State.Guild(guildID)
  112. // isInGuild := err == nil
  113. func (s *State) Guild(guildID string) (*Guild, error) {
  114. if s == nil {
  115. return nil, ErrNilState
  116. }
  117. s.RLock()
  118. defer s.RUnlock()
  119. if g, ok := s.guildMap[guildID]; ok {
  120. return g, nil
  121. }
  122. return nil, errors.New("Guild not found.")
  123. }
  124. // TODO: Consider moving Guild state update methods onto *Guild.
  125. // MemberAdd adds a member to the current world state, or
  126. // updates it if it already exists.
  127. func (s *State) MemberAdd(member *Member) error {
  128. if s == nil {
  129. return ErrNilState
  130. }
  131. guild, err := s.Guild(member.GuildID)
  132. if err != nil {
  133. return err
  134. }
  135. s.Lock()
  136. defer s.Unlock()
  137. for i, m := range guild.Members {
  138. if m.User.ID == member.User.ID {
  139. guild.Members[i] = member
  140. return nil
  141. }
  142. }
  143. guild.Members = append(guild.Members, member)
  144. return nil
  145. }
  146. // MemberRemove removes a member from current world state.
  147. func (s *State) MemberRemove(member *Member) error {
  148. if s == nil {
  149. return ErrNilState
  150. }
  151. guild, err := s.Guild(member.GuildID)
  152. if err != nil {
  153. return err
  154. }
  155. s.Lock()
  156. defer s.Unlock()
  157. for i, m := range guild.Members {
  158. if m.User.ID == member.User.ID {
  159. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  160. return nil
  161. }
  162. }
  163. return errors.New("Member not found.")
  164. }
  165. // Member gets a member by ID from a guild.
  166. func (s *State) Member(guildID, userID string) (*Member, error) {
  167. if s == nil {
  168. return nil, ErrNilState
  169. }
  170. guild, err := s.Guild(guildID)
  171. if err != nil {
  172. return nil, err
  173. }
  174. s.RLock()
  175. defer s.RUnlock()
  176. for _, m := range guild.Members {
  177. if m.User.ID == userID {
  178. return m, nil
  179. }
  180. }
  181. return nil, errors.New("Member not found.")
  182. }
  183. // RoleAdd adds a role to the current world state, or
  184. // updates it if it already exists.
  185. func (s *State) RoleAdd(guildID string, role *Role) error {
  186. if s == nil {
  187. return ErrNilState
  188. }
  189. guild, err := s.Guild(guildID)
  190. if err != nil {
  191. return err
  192. }
  193. s.Lock()
  194. defer s.Unlock()
  195. for i, r := range guild.Roles {
  196. if r.ID == role.ID {
  197. guild.Roles[i] = role
  198. return nil
  199. }
  200. }
  201. guild.Roles = append(guild.Roles, role)
  202. return nil
  203. }
  204. // RoleRemove removes a role from current world state by ID.
  205. func (s *State) RoleRemove(guildID, roleID string) error {
  206. if s == nil {
  207. return ErrNilState
  208. }
  209. guild, err := s.Guild(guildID)
  210. if err != nil {
  211. return err
  212. }
  213. s.Lock()
  214. defer s.Unlock()
  215. for i, r := range guild.Roles {
  216. if r.ID == roleID {
  217. guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
  218. return nil
  219. }
  220. }
  221. return errors.New("Role not found.")
  222. }
  223. // Role gets a role by ID from a guild.
  224. func (s *State) Role(guildID, roleID string) (*Role, error) {
  225. if s == nil {
  226. return nil, ErrNilState
  227. }
  228. guild, err := s.Guild(guildID)
  229. if err != nil {
  230. return nil, err
  231. }
  232. s.RLock()
  233. defer s.RUnlock()
  234. for _, r := range guild.Roles {
  235. if r.ID == roleID {
  236. return r, nil
  237. }
  238. }
  239. return nil, errors.New("Role not found.")
  240. }
  241. // ChannelAdd adds a guild to the current world state, or
  242. // updates it if it already exists.
  243. // Channels may exist either as PrivateChannels or inside
  244. // a guild.
  245. func (s *State) ChannelAdd(channel *Channel) error {
  246. if s == nil {
  247. return ErrNilState
  248. }
  249. s.Lock()
  250. defer s.Unlock()
  251. // If the channel exists, replace it
  252. if c, ok := s.channelMap[channel.ID]; ok {
  253. channel.Messages = c.Messages
  254. channel.PermissionOverwrites = c.PermissionOverwrites
  255. *c = *channel
  256. return nil
  257. }
  258. if channel.IsPrivate {
  259. s.PrivateChannels = append(s.PrivateChannels, channel)
  260. } else {
  261. guild, ok := s.guildMap[channel.GuildID]
  262. if !ok {
  263. return errors.New("Guild for channel not found.")
  264. }
  265. guild.Channels = append(guild.Channels, channel)
  266. }
  267. s.channelMap[channel.ID] = channel
  268. return nil
  269. }
  270. // ChannelRemove removes a channel from current world state.
  271. func (s *State) ChannelRemove(channel *Channel) error {
  272. if s == nil {
  273. return ErrNilState
  274. }
  275. _, err := s.Channel(channel.ID)
  276. if err != nil {
  277. return err
  278. }
  279. if channel.IsPrivate {
  280. s.Lock()
  281. defer s.Unlock()
  282. for i, c := range s.PrivateChannels {
  283. if c.ID == channel.ID {
  284. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  285. break
  286. }
  287. }
  288. } else {
  289. guild, err := s.Guild(channel.GuildID)
  290. if err != nil {
  291. return err
  292. }
  293. s.Lock()
  294. defer s.Unlock()
  295. for i, c := range guild.Channels {
  296. if c.ID == channel.ID {
  297. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  298. break
  299. }
  300. }
  301. }
  302. delete(s.channelMap, channel.ID)
  303. return nil
  304. }
  305. // GuildChannel gets a channel by ID from a guild.
  306. // This method is Deprecated, use Channel(channelID)
  307. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  308. return s.Channel(channelID)
  309. }
  310. // PrivateChannel gets a private channel by ID.
  311. // This method is Deprecated, use Channel(channelID)
  312. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  313. return s.Channel(channelID)
  314. }
  315. // Channel gets a channel by ID, it will look in all guilds an private channels.
  316. func (s *State) Channel(channelID string) (*Channel, error) {
  317. if s == nil {
  318. return nil, ErrNilState
  319. }
  320. s.RLock()
  321. defer s.RUnlock()
  322. if c, ok := s.channelMap[channelID]; ok {
  323. return c, nil
  324. }
  325. return nil, errors.New("Channel not found.")
  326. }
  327. // Emoji returns an emoji for a guild and emoji id.
  328. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  329. if s == nil {
  330. return nil, ErrNilState
  331. }
  332. guild, err := s.Guild(guildID)
  333. if err != nil {
  334. return nil, err
  335. }
  336. s.RLock()
  337. defer s.RUnlock()
  338. for _, e := range guild.Emojis {
  339. if e.ID == emojiID {
  340. return e, nil
  341. }
  342. }
  343. return nil, errors.New("Emoji not found.")
  344. }
  345. // EmojiAdd adds an emoji to the current world state.
  346. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  347. if s == nil {
  348. return ErrNilState
  349. }
  350. guild, err := s.Guild(guildID)
  351. if err != nil {
  352. return err
  353. }
  354. s.Lock()
  355. defer s.Unlock()
  356. for i, e := range guild.Emojis {
  357. if e.ID == emoji.ID {
  358. guild.Emojis[i] = emoji
  359. return nil
  360. }
  361. }
  362. guild.Emojis = append(guild.Emojis, emoji)
  363. return nil
  364. }
  365. // EmojisAdd adds multiple emojis to the world state.
  366. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  367. for _, e := range emojis {
  368. if err := s.EmojiAdd(guildID, e); err != nil {
  369. return err
  370. }
  371. }
  372. return nil
  373. }
  374. // MessageAdd adds a message to the current world state, or updates it if it exists.
  375. // If the channel cannot be found, the message is discarded.
  376. // Messages are kept in state up to s.MaxMessageCount
  377. func (s *State) MessageAdd(message *Message) error {
  378. if s == nil {
  379. return ErrNilState
  380. }
  381. c, err := s.Channel(message.ChannelID)
  382. if err != nil {
  383. return err
  384. }
  385. s.Lock()
  386. defer s.Unlock()
  387. // If the message exists, merge in the new message contents.
  388. for _, m := range c.Messages {
  389. if m.ID == message.ID {
  390. if message.Content != "" {
  391. m.Content = message.Content
  392. }
  393. if message.EditedTimestamp != "" {
  394. m.EditedTimestamp = message.EditedTimestamp
  395. }
  396. if message.Mentions != nil {
  397. m.Mentions = message.Mentions
  398. }
  399. if message.Embeds != nil {
  400. m.Embeds = message.Embeds
  401. }
  402. if message.Attachments != nil {
  403. m.Attachments = message.Attachments
  404. }
  405. if message.Timestamp != "" {
  406. m.Timestamp = message.Timestamp
  407. }
  408. if message.Author != nil {
  409. m.Author = message.Author
  410. }
  411. return nil
  412. }
  413. }
  414. c.Messages = append(c.Messages, message)
  415. if len(c.Messages) > s.MaxMessageCount {
  416. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  417. }
  418. return nil
  419. }
  420. // MessageRemove removes a message from the world state.
  421. func (s *State) MessageRemove(message *Message) error {
  422. if s == nil {
  423. return ErrNilState
  424. }
  425. c, err := s.Channel(message.ChannelID)
  426. if err != nil {
  427. return err
  428. }
  429. s.Lock()
  430. defer s.Unlock()
  431. for i, m := range c.Messages {
  432. if m.ID == message.ID {
  433. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  434. return nil
  435. }
  436. }
  437. return errors.New("Message not found.")
  438. }
  439. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  440. guild, err := s.Guild(update.GuildID)
  441. if err != nil {
  442. return err
  443. }
  444. s.Lock()
  445. defer s.Unlock()
  446. // Handle Leaving Channel
  447. if update.ChannelID == "" {
  448. for i, state := range guild.VoiceStates {
  449. if state.UserID == update.UserID {
  450. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  451. return nil
  452. }
  453. }
  454. } else {
  455. for i, state := range guild.VoiceStates {
  456. if state.UserID == update.UserID {
  457. guild.VoiceStates[i] = update.VoiceState
  458. return nil
  459. }
  460. }
  461. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  462. }
  463. return nil
  464. }
  465. // Message gets a message by channel and message ID.
  466. func (s *State) Message(channelID, messageID string) (*Message, error) {
  467. if s == nil {
  468. return nil, ErrNilState
  469. }
  470. c, err := s.Channel(channelID)
  471. if err != nil {
  472. return nil, err
  473. }
  474. s.RLock()
  475. defer s.RUnlock()
  476. for _, m := range c.Messages {
  477. if m.ID == messageID {
  478. return m, nil
  479. }
  480. }
  481. return nil, errors.New("Message not found.")
  482. }
  483. // OnReady takes a Ready event and updates all internal state.
  484. func (s *State) onReady(se *Session, r *Ready) (err error) {
  485. if s == nil {
  486. return ErrNilState
  487. }
  488. s.Lock()
  489. defer s.Unlock()
  490. // We must track at least the current user for Voice, even
  491. // if state is disabled, store the bare essentials.
  492. if !se.StateEnabled {
  493. ready := Ready{
  494. Version: r.Version,
  495. SessionID: r.SessionID,
  496. HeartbeatInterval: r.HeartbeatInterval,
  497. User: r.User,
  498. }
  499. s.Ready = ready
  500. return nil
  501. }
  502. s.Ready = *r
  503. for _, g := range s.Guilds {
  504. s.guildMap[g.ID] = g
  505. for _, c := range g.Channels {
  506. c.GuildID = g.ID
  507. s.channelMap[c.ID] = c
  508. }
  509. }
  510. for _, c := range s.PrivateChannels {
  511. s.channelMap[c.ID] = c
  512. }
  513. return nil
  514. }
  515. // onInterface handles all events related to states.
  516. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  517. if s == nil {
  518. return ErrNilState
  519. }
  520. if !se.StateEnabled {
  521. return nil
  522. }
  523. switch t := i.(type) {
  524. case *GuildCreate:
  525. err = s.GuildAdd(t.Guild)
  526. case *GuildUpdate:
  527. err = s.GuildAdd(t.Guild)
  528. case *GuildDelete:
  529. err = s.GuildRemove(t.Guild)
  530. case *GuildMemberAdd:
  531. if s.TrackMembers {
  532. err = s.MemberAdd(t.Member)
  533. }
  534. case *GuildMemberUpdate:
  535. if s.TrackMembers {
  536. err = s.MemberAdd(t.Member)
  537. }
  538. case *GuildMemberRemove:
  539. if s.TrackMembers {
  540. err = s.MemberRemove(t.Member)
  541. }
  542. case *GuildRoleCreate:
  543. if s.TrackRoles {
  544. err = s.RoleAdd(t.GuildID, t.Role)
  545. }
  546. case *GuildRoleUpdate:
  547. if s.TrackRoles {
  548. err = s.RoleAdd(t.GuildID, t.Role)
  549. }
  550. case *GuildRoleDelete:
  551. if s.TrackRoles {
  552. err = s.RoleRemove(t.GuildID, t.RoleID)
  553. }
  554. case *GuildEmojisUpdate:
  555. if s.TrackEmojis {
  556. err = s.EmojisAdd(t.GuildID, t.Emojis)
  557. }
  558. case *ChannelCreate:
  559. if s.TrackChannels {
  560. err = s.ChannelAdd(t.Channel)
  561. }
  562. case *ChannelUpdate:
  563. if s.TrackChannels {
  564. err = s.ChannelAdd(t.Channel)
  565. }
  566. case *ChannelDelete:
  567. if s.TrackChannels {
  568. err = s.ChannelRemove(t.Channel)
  569. }
  570. case *MessageCreate:
  571. if s.MaxMessageCount != 0 {
  572. err = s.MessageAdd(t.Message)
  573. }
  574. case *MessageUpdate:
  575. if s.MaxMessageCount != 0 {
  576. err = s.MessageAdd(t.Message)
  577. }
  578. case *MessageDelete:
  579. if s.MaxMessageCount != 0 {
  580. err = s.MessageRemove(t.Message)
  581. }
  582. case *VoiceStateUpdate:
  583. if s.TrackVoice {
  584. err = s.voiceStateUpdate(t)
  585. }
  586. }
  587. return
  588. }
  589. // UserChannelPermissions returns the permission of a user in a channel.
  590. // userID : The ID of the user to calculate permissions for.
  591. // channelID : The ID of the channel to calculate permission for.
  592. func (s *State) UserChannelPermissions(userID, channelID string) (apermissions int, err error) {
  593. if s == nil {
  594. return 0, ErrNilState
  595. }
  596. channel, err := s.Channel(channelID)
  597. if err != nil {
  598. return
  599. }
  600. guild, err := s.Guild(channel.GuildID)
  601. if err != nil {
  602. return
  603. }
  604. if userID == guild.OwnerID {
  605. apermissions = PermissionAll
  606. return
  607. }
  608. member, err := s.Member(guild.ID, userID)
  609. if err != nil {
  610. return
  611. }
  612. for _, role := range guild.Roles {
  613. if role.ID == guild.ID {
  614. apermissions |= role.Permissions
  615. break
  616. }
  617. }
  618. for _, role := range guild.Roles {
  619. for _, roleID := range member.Roles {
  620. if role.ID == roleID {
  621. apermissions |= role.Permissions
  622. break
  623. }
  624. }
  625. }
  626. if apermissions&PermissionAdministrator > 0 {
  627. apermissions |= PermissionAll
  628. }
  629. // Member overwrites can override role overrides, so do two passes
  630. for _, overwrite := range channel.PermissionOverwrites {
  631. for _, roleID := range member.Roles {
  632. if overwrite.Type == "role" && roleID == overwrite.ID {
  633. apermissions &= ^overwrite.Deny
  634. apermissions |= overwrite.Allow
  635. break
  636. }
  637. }
  638. }
  639. for _, overwrite := range channel.PermissionOverwrites {
  640. if overwrite.Type == "member" && overwrite.ID == userID {
  641. apermissions &= ^overwrite.Deny
  642. apermissions |= overwrite.Allow
  643. break
  644. }
  645. }
  646. if apermissions&PermissionManageRoles > 0 {
  647. apermissions |= PermissionAllChannel
  648. }
  649. return
  650. }