state.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sync"
  14. )
  // ErrNilState is the error returned by State methods when they are invoked
// on a nil *State.
var (
	ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
)
  // A State contains the current known state.
// As discord sends this in a READY blob, it seems reasonable to simply
// use that struct as the data store.
//
// The embedded RWMutex is locked by the State methods around reads and
// writes of the tracked guilds, channels, members, emojis and messages.
type State struct {
	sync.RWMutex

	// Ready holds the root lists of tracked data (Guilds, PrivateChannels);
	// guildMap and channelMap below are ID indexes over their contents.
	Ready

	// MaxMessageCount is the maximum number of messages MessageAdd keeps per
	// channel. A value of 0 disables message tracking.
	MaxMessageCount int

	// guildMap indexes tracked guilds by guild ID.
	guildMap map[string]*Guild
	// channelMap indexes both guild channels and private channels by channel ID.
	channelMap map[string]*Channel
}
  27. // NewState creates an empty state.
  28. func NewState() *State {
  29. return &State{
  30. Ready: Ready{
  31. PrivateChannels: []*Channel{},
  32. Guilds: []*Guild{},
  33. },
  34. guildMap: make(map[string]*Guild),
  35. channelMap: make(map[string]*Channel),
  36. }
  37. }
  38. // OnReady takes a Ready event and updates all internal state.
  39. func (s *State) OnReady(r *Ready) error {
  40. if s == nil {
  41. return ErrNilState
  42. }
  43. s.Lock()
  44. defer s.Unlock()
  45. s.Ready = *r
  46. for _, g := range s.Guilds {
  47. s.guildMap[g.ID] = g
  48. for _, c := range g.Channels {
  49. c.GuildID = g.ID
  50. s.channelMap[c.ID] = c
  51. }
  52. }
  53. for _, c := range s.PrivateChannels {
  54. s.channelMap[c.ID] = c
  55. }
  56. return nil
  57. }
  58. // GuildAdd adds a guild to the current world state, or
  59. // updates it if it already exists.
  60. func (s *State) GuildAdd(guild *Guild) error {
  61. if s == nil {
  62. return ErrNilState
  63. }
  64. s.Lock()
  65. defer s.Unlock()
  66. // Update the channels to point to the right guild, adding them to the channelMap as we go
  67. for _, c := range guild.Channels {
  68. c.GuildID = guild.ID
  69. s.channelMap[c.ID] = c
  70. }
  71. // If the guild exists, replace it.
  72. if g, ok := s.guildMap[guild.ID]; ok {
  73. // If this guild already exists with data, don't stomp on props.
  74. if g.Unavailable != nil && !*g.Unavailable {
  75. guild.Members = g.Members
  76. guild.Presences = g.Presences
  77. guild.Channels = g.Channels
  78. guild.VoiceStates = g.VoiceStates
  79. }
  80. *g = *guild
  81. return nil
  82. }
  83. s.Guilds = append(s.Guilds, guild)
  84. s.guildMap[guild.ID] = guild
  85. return nil
  86. }
  87. // GuildRemove removes a guild from current world state.
  88. func (s *State) GuildRemove(guild *Guild) error {
  89. if s == nil {
  90. return ErrNilState
  91. }
  92. _, err := s.Guild(guild.ID)
  93. if err != nil {
  94. return err
  95. }
  96. s.Lock()
  97. defer s.Unlock()
  98. delete(s.guildMap, guild.ID)
  99. for i, g := range s.Guilds {
  100. if g.ID == guild.ID {
  101. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  102. return nil
  103. }
  104. }
  105. return nil
  106. }
  107. // Guild gets a guild by ID.
  108. // Useful for querying if @me is in a guild:
  109. // _, err := discordgo.Session.State.Guild(guildID)
  110. // isInGuild := err == nil
  111. func (s *State) Guild(guildID string) (*Guild, error) {
  112. if s == nil {
  113. return nil, ErrNilState
  114. }
  115. s.RLock()
  116. defer s.RUnlock()
  117. if g, ok := s.guildMap[guildID]; ok {
  118. return g, nil
  119. }
  120. return nil, errors.New("Guild not found.")
  121. }
  122. // TODO: Consider moving Guild state update methods onto *Guild.
  123. // MemberAdd adds a member to the current world state, or
  124. // updates it if it already exists.
  125. func (s *State) MemberAdd(member *Member) error {
  126. if s == nil {
  127. return ErrNilState
  128. }
  129. guild, err := s.Guild(member.GuildID)
  130. if err != nil {
  131. return err
  132. }
  133. s.Lock()
  134. defer s.Unlock()
  135. for i, m := range guild.Members {
  136. if m.User.ID == member.User.ID {
  137. guild.Members[i] = member
  138. return nil
  139. }
  140. }
  141. guild.Members = append(guild.Members, member)
  142. return nil
  143. }
  144. // MemberRemove removes a member from current world state.
  145. func (s *State) MemberRemove(member *Member) error {
  146. if s == nil {
  147. return ErrNilState
  148. }
  149. guild, err := s.Guild(member.GuildID)
  150. if err != nil {
  151. return err
  152. }
  153. s.Lock()
  154. defer s.Unlock()
  155. for i, m := range guild.Members {
  156. if m.User.ID == member.User.ID {
  157. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  158. return nil
  159. }
  160. }
  161. return errors.New("Member not found.")
  162. }
  163. // Member gets a member by ID from a guild.
  164. func (s *State) Member(guildID, userID string) (*Member, error) {
  165. if s == nil {
  166. return nil, ErrNilState
  167. }
  168. guild, err := s.Guild(guildID)
  169. if err != nil {
  170. return nil, err
  171. }
  172. s.RLock()
  173. defer s.RUnlock()
  174. for _, m := range guild.Members {
  175. if m.User.ID == userID {
  176. return m, nil
  177. }
  178. }
  179. return nil, errors.New("Member not found.")
  180. }
  181. // ChannelAdd adds a guild to the current world state, or
  182. // updates it if it already exists.
  183. // Channels may exist either as PrivateChannels or inside
  184. // a guild.
  185. func (s *State) ChannelAdd(channel *Channel) error {
  186. if s == nil {
  187. return ErrNilState
  188. }
  189. s.Lock()
  190. defer s.Unlock()
  191. // If the channel exists, replace it
  192. if c, ok := s.channelMap[channel.ID]; ok {
  193. channel.Messages = c.Messages
  194. channel.PermissionOverwrites = c.PermissionOverwrites
  195. *c = *channel
  196. return nil
  197. }
  198. if channel.IsPrivate {
  199. s.PrivateChannels = append(s.PrivateChannels, channel)
  200. } else {
  201. guild, ok := s.guildMap[channel.GuildID]
  202. if !ok {
  203. return errors.New("Guild for channel not found.")
  204. }
  205. guild.Channels = append(guild.Channels, channel)
  206. }
  207. s.channelMap[channel.ID] = channel
  208. return nil
  209. }
  210. // ChannelRemove removes a channel from current world state.
  211. func (s *State) ChannelRemove(channel *Channel) error {
  212. if s == nil {
  213. return ErrNilState
  214. }
  215. _, err := s.Channel(channel.ID)
  216. if err != nil {
  217. return err
  218. }
  219. if channel.IsPrivate {
  220. s.Lock()
  221. defer s.Unlock()
  222. for i, c := range s.PrivateChannels {
  223. if c.ID == channel.ID {
  224. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  225. break
  226. }
  227. }
  228. } else {
  229. guild, err := s.Guild(channel.GuildID)
  230. if err != nil {
  231. return err
  232. }
  233. s.Lock()
  234. defer s.Unlock()
  235. for i, c := range guild.Channels {
  236. if c.ID == channel.ID {
  237. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  238. break
  239. }
  240. }
  241. }
  242. delete(s.channelMap, channel.ID)
  243. return nil
  244. }
  245. // GuildChannel gets a channel by ID from a guild.
  246. // This method is Deprecated, use Channel(channelID)
  247. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  248. return s.Channel(channelID)
  249. }
  250. // PrivateChannel gets a private channel by ID.
  251. // This method is Deprecated, use Channel(channelID)
  252. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  253. return s.Channel(channelID)
  254. }
  255. // Channel gets a channel by ID, it will look in all guilds an private channels.
  256. func (s *State) Channel(channelID string) (*Channel, error) {
  257. if s == nil {
  258. return nil, ErrNilState
  259. }
  260. s.RLock()
  261. defer s.RUnlock()
  262. if c, ok := s.channelMap[channelID]; ok {
  263. return c, nil
  264. }
  265. return nil, errors.New("Channel not found.")
  266. }
  267. // Emoji returns an emoji for a guild and emoji id.
  268. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  269. if s == nil {
  270. return nil, ErrNilState
  271. }
  272. guild, err := s.Guild(guildID)
  273. if err != nil {
  274. return nil, err
  275. }
  276. s.RLock()
  277. defer s.RUnlock()
  278. for _, e := range guild.Emojis {
  279. if e.ID == emojiID {
  280. return e, nil
  281. }
  282. }
  283. return nil, errors.New("Emoji not found.")
  284. }
  285. // EmojiAdd adds an emoji to the current world state.
  286. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  287. if s == nil {
  288. return ErrNilState
  289. }
  290. guild, err := s.Guild(guildID)
  291. if err != nil {
  292. return err
  293. }
  294. s.Lock()
  295. defer s.Unlock()
  296. for i, e := range guild.Emojis {
  297. if e.ID == emoji.ID {
  298. guild.Emojis[i] = emoji
  299. return nil
  300. }
  301. }
  302. guild.Emojis = append(guild.Emojis, emoji)
  303. return nil
  304. }
  305. // EmojisAdd adds multiple emojis to the world state.
  306. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  307. for _, e := range emojis {
  308. if err := s.EmojiAdd(guildID, e); err != nil {
  309. return err
  310. }
  311. }
  312. return nil
  313. }
  314. // MessageAdd adds a message to the current world state, or updates it if it exists.
  315. // If the channel cannot be found, the message is discarded.
  316. // Messages are kept in state up to s.MaxMessageCount
  317. func (s *State) MessageAdd(message *Message) error {
  318. if s == nil {
  319. return ErrNilState
  320. }
  321. if s.MaxMessageCount == 0 {
  322. return nil
  323. }
  324. c, err := s.Channel(message.ChannelID)
  325. if err != nil {
  326. return err
  327. }
  328. s.Lock()
  329. defer s.Unlock()
  330. // If the message exists, merge in the new message contents.
  331. for _, m := range c.Messages {
  332. if m.ID == message.ID {
  333. if message.Content != "" {
  334. m.Content = message.Content
  335. }
  336. if message.EditedTimestamp != "" {
  337. m.EditedTimestamp = message.EditedTimestamp
  338. }
  339. if message.Mentions != nil {
  340. m.Mentions = message.Mentions
  341. }
  342. if message.Embeds != nil {
  343. m.Embeds = message.Embeds
  344. }
  345. if message.Attachments != nil {
  346. m.Attachments = message.Attachments
  347. }
  348. return nil
  349. }
  350. }
  351. c.Messages = append(c.Messages, message)
  352. if len(c.Messages) > s.MaxMessageCount {
  353. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  354. }
  355. return nil
  356. }
  357. // MessageRemove removes a message from the world state.
  358. func (s *State) MessageRemove(message *Message) error {
  359. if s == nil {
  360. return ErrNilState
  361. }
  362. if s.MaxMessageCount == 0 {
  363. return nil
  364. }
  365. c, err := s.Channel(message.ChannelID)
  366. if err != nil {
  367. return err
  368. }
  369. s.Lock()
  370. defer s.Unlock()
  371. for i, m := range c.Messages {
  372. if m.ID == message.ID {
  373. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  374. return nil
  375. }
  376. }
  377. return errors.New("Message not found.")
  378. }
  379. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  380. guild, err := s.Guild(update.GuildID)
  381. if err != nil {
  382. return err
  383. }
  384. s.Lock()
  385. defer s.Unlock()
  386. // Handle Leaving Channel
  387. if update.ChannelID == "" {
  388. for i, state := range guild.VoiceStates {
  389. if state.UserID == update.UserID {
  390. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  391. return nil
  392. }
  393. }
  394. } else {
  395. for i, state := range guild.VoiceStates {
  396. if state.UserID == update.UserID {
  397. guild.VoiceStates[i] = update.VoiceState
  398. return nil
  399. }
  400. }
  401. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  402. }
  403. return nil
  404. }
  405. // Message gets a message by channel and message ID.
  406. func (s *State) Message(channelID, messageID string) (*Message, error) {
  407. if s == nil {
  408. return nil, ErrNilState
  409. }
  410. c, err := s.Channel(channelID)
  411. if err != nil {
  412. return nil, err
  413. }
  414. s.RLock()
  415. defer s.RUnlock()
  416. for _, m := range c.Messages {
  417. if m.ID == messageID {
  418. return m, nil
  419. }
  420. }
  421. return nil, errors.New("Message not found.")
  422. }
  423. // onInterface handles all events related to states.
  424. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  425. if s == nil {
  426. return ErrNilState
  427. }
  428. if !se.StateEnabled {
  429. return nil
  430. }
  431. switch t := i.(type) {
  432. case *Ready:
  433. err = s.OnReady(t)
  434. case *GuildCreate:
  435. err = s.GuildAdd(t.Guild)
  436. case *GuildUpdate:
  437. err = s.GuildAdd(t.Guild)
  438. case *GuildDelete:
  439. err = s.GuildRemove(t.Guild)
  440. case *GuildMemberAdd:
  441. err = s.MemberAdd(t.Member)
  442. case *GuildMemberUpdate:
  443. err = s.MemberAdd(t.Member)
  444. case *GuildMemberRemove:
  445. err = s.MemberRemove(t.Member)
  446. case *GuildEmojisUpdate:
  447. err = s.EmojisAdd(t.GuildID, t.Emojis)
  448. case *ChannelCreate:
  449. err = s.ChannelAdd(t.Channel)
  450. case *ChannelUpdate:
  451. err = s.ChannelAdd(t.Channel)
  452. case *ChannelDelete:
  453. err = s.ChannelRemove(t.Channel)
  454. case *MessageCreate:
  455. err = s.MessageAdd(t.Message)
  456. case *MessageUpdate:
  457. err = s.MessageAdd(t.Message)
  458. case *MessageDelete:
  459. err = s.MessageRemove(t.Message)
  460. case *VoiceStateUpdate:
  461. err = s.voiceStateUpdate(t)
  462. }
  463. return
  464. }