state.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "sync"
  14. )
  var (
  	// ErrNilState is the error State methods return when they are called on
  	// a nil *State. Its message points callers at discordgo.New() or at
  	// assigning Session.State directly.
  	ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  )
  17. // A State contains the current known state.
  18. // As discord sends this in a READY blob, it seems reasonable to simply
  19. // use that struct as the data store.
  20. type State struct {
  21. sync.RWMutex
  22. Ready
  23. MaxMessageCount int
  24. guildMap map[string]*Guild
  25. channelMap map[string]*Channel
  26. }
  27. // NewState creates an empty state.
  28. func NewState() *State {
  29. return &State{
  30. Ready: Ready{
  31. PrivateChannels: []*Channel{},
  32. Guilds: []*Guild{},
  33. },
  34. guildMap: make(map[string]*Guild),
  35. channelMap: make(map[string]*Channel),
  36. }
  37. }
  38. // OnReady takes a Ready event and updates all internal state.
  39. func (s *State) OnReady(r *Ready) error {
  40. if s == nil {
  41. return ErrNilState
  42. }
  43. s.Lock()
  44. defer s.Unlock()
  45. s.Ready = *r
  46. for _, g := range s.Guilds {
  47. s.guildMap[g.ID] = g
  48. for _, c := range g.Channels {
  49. c.GuildID = g.ID
  50. s.channelMap[c.ID] = c
  51. }
  52. }
  53. for _, c := range s.PrivateChannels {
  54. s.channelMap[c.ID] = c
  55. }
  56. return nil
  57. }
  58. // GuildAdd adds a guild to the current world state, or
  59. // updates it if it already exists.
  60. func (s *State) GuildAdd(guild *Guild) error {
  61. if s == nil {
  62. return ErrNilState
  63. }
  64. s.Lock()
  65. defer s.Unlock()
  66. // Update the channels to point to the right guild, adding them to the channelMap as we go
  67. for _, c := range guild.Channels {
  68. c.GuildID = guild.ID
  69. s.channelMap[c.ID] = c
  70. }
  71. // If the guild exists, replace it.
  72. if g, ok := s.guildMap[guild.ID]; ok {
  73. // If this guild already exists with data, don't stomp on props.
  74. if g.Unavailable != nil && !*g.Unavailable {
  75. guild.Members = g.Members
  76. guild.Presences = g.Presences
  77. guild.Channels = g.Channels
  78. guild.VoiceStates = g.VoiceStates
  79. }
  80. *g = *guild
  81. return nil
  82. }
  83. s.Guilds = append(s.Guilds, guild)
  84. s.guildMap[guild.ID] = guild
  85. return nil
  86. }
  87. // GuildRemove removes a guild from current world state.
  88. func (s *State) GuildRemove(guild *Guild) error {
  89. if s == nil {
  90. return ErrNilState
  91. }
  92. _, err := s.Guild(guild.ID)
  93. if err != nil {
  94. return err
  95. }
  96. s.Lock()
  97. defer s.Unlock()
  98. delete(s.guildMap, guild.ID)
  99. for i, g := range s.Guilds {
  100. if g.ID == guild.ID {
  101. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  102. return nil
  103. }
  104. }
  105. return nil
  106. }
  107. // Guild gets a guild by ID.
  108. // Useful for querying if @me is in a guild:
  109. // _, err := discordgo.Session.State.Guild(guildID)
  110. // isInGuild := err == nil
  111. func (s *State) Guild(guildID string) (*Guild, error) {
  112. if s == nil {
  113. return nil, ErrNilState
  114. }
  115. s.RLock()
  116. defer s.RUnlock()
  117. if g, ok := s.guildMap[guildID]; ok {
  118. return g, nil
  119. }
  120. return nil, errors.New("Guild not found.")
  121. }
  122. // TODO: Consider moving Guild state update methods onto *Guild.
  123. // MemberAdd adds a member to the current world state, or
  124. // updates it if it already exists.
  125. func (s *State) MemberAdd(member *Member) error {
  126. if s == nil {
  127. return ErrNilState
  128. }
  129. guild, err := s.Guild(member.GuildID)
  130. if err != nil {
  131. return err
  132. }
  133. s.Lock()
  134. defer s.Unlock()
  135. for i, m := range guild.Members {
  136. if m.User.ID == member.User.ID {
  137. guild.Members[i] = member
  138. return nil
  139. }
  140. }
  141. guild.Members = append(guild.Members, member)
  142. return nil
  143. }
  144. // MemberRemove removes a member from current world state.
  145. func (s *State) MemberRemove(member *Member) error {
  146. if s == nil {
  147. return ErrNilState
  148. }
  149. guild, err := s.Guild(member.GuildID)
  150. if err != nil {
  151. return err
  152. }
  153. s.Lock()
  154. defer s.Unlock()
  155. for i, m := range guild.Members {
  156. if m.User.ID == member.User.ID {
  157. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  158. return nil
  159. }
  160. }
  161. return errors.New("Member not found.")
  162. }
  163. // Member gets a member by ID from a guild.
  164. func (s *State) Member(guildID, userID string) (*Member, error) {
  165. if s == nil {
  166. return nil, ErrNilState
  167. }
  168. guild, err := s.Guild(guildID)
  169. if err != nil {
  170. return nil, err
  171. }
  172. s.RLock()
  173. defer s.RUnlock()
  174. for _, m := range guild.Members {
  175. if m.User.ID == userID {
  176. return m, nil
  177. }
  178. }
  179. return nil, errors.New("Member not found.")
  180. }
  181. // ChannelAdd adds a guild to the current world state, or
  182. // updates it if it already exists.
  183. // Channels may exist either as PrivateChannels or inside
  184. // a guild.
  185. func (s *State) ChannelAdd(channel *Channel) error {
  186. if s == nil {
  187. return ErrNilState
  188. }
  189. s.Lock()
  190. defer s.Unlock()
  191. // If the channel exists, replace it
  192. if c, ok := s.channelMap[channel.ID]; ok {
  193. channel.Messages = c.Messages
  194. channel.PermissionOverwrites = c.PermissionOverwrites
  195. *c = *channel
  196. return nil
  197. }
  198. if channel.IsPrivate {
  199. s.PrivateChannels = append(s.PrivateChannels, channel)
  200. } else {
  201. guild, ok := s.guildMap[channel.GuildID]
  202. if !ok {
  203. return errors.New("Guild for channel not found.")
  204. }
  205. guild.Channels = append(guild.Channels, channel)
  206. }
  207. s.channelMap[channel.ID] = channel
  208. return nil
  209. }
  210. // ChannelRemove removes a channel from current world state.
  211. func (s *State) ChannelRemove(channel *Channel) error {
  212. if s == nil {
  213. return ErrNilState
  214. }
  215. _, err := s.Channel(channel.ID)
  216. if err != nil {
  217. return err
  218. }
  219. if channel.IsPrivate {
  220. s.Lock()
  221. defer s.Unlock()
  222. for i, c := range s.PrivateChannels {
  223. if c.ID == channel.ID {
  224. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  225. break
  226. }
  227. }
  228. } else {
  229. guild, err := s.Guild(channel.GuildID)
  230. if err != nil {
  231. return err
  232. }
  233. s.Lock()
  234. defer s.Unlock()
  235. for i, c := range guild.Channels {
  236. if c.ID == channel.ID {
  237. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  238. break
  239. }
  240. }
  241. }
  242. delete(s.channelMap, channel.ID)
  243. return nil
  244. }
  245. // GuildChannel gets a channel by ID from a guild.
  246. // This method is Deprecated, use Channel(channelID)
  247. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  248. return s.Channel(channelID)
  249. }
  250. // PrivateChannel gets a private channel by ID.
  251. // This method is Deprecated, use Channel(channelID)
  252. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  253. return s.Channel(channelID)
  254. }
  255. // Channel gets a channel by ID, it will look in all guilds an private channels.
  256. func (s *State) Channel(channelID string) (*Channel, error) {
  257. if s == nil {
  258. return nil, ErrNilState
  259. }
  260. if c, ok := s.channelMap[channelID]; ok {
  261. return c, nil
  262. }
  263. return nil, errors.New("Channel not found.")
  264. }
  265. // Emoji returns an emoji for a guild and emoji id.
  266. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  267. if s == nil {
  268. return nil, ErrNilState
  269. }
  270. guild, err := s.Guild(guildID)
  271. if err != nil {
  272. return nil, err
  273. }
  274. s.RLock()
  275. defer s.RUnlock()
  276. for _, e := range guild.Emojis {
  277. if e.ID == emojiID {
  278. return e, nil
  279. }
  280. }
  281. return nil, errors.New("Emoji not found.")
  282. }
  283. // EmojiAdd adds an emoji to the current world state.
  284. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  285. if s == nil {
  286. return ErrNilState
  287. }
  288. guild, err := s.Guild(guildID)
  289. if err != nil {
  290. return err
  291. }
  292. s.Lock()
  293. defer s.Unlock()
  294. for i, e := range guild.Emojis {
  295. if e.ID == emoji.ID {
  296. guild.Emojis[i] = emoji
  297. return nil
  298. }
  299. }
  300. guild.Emojis = append(guild.Emojis, emoji)
  301. return nil
  302. }
  303. // EmojisAdd adds multiple emojis to the world state.
  304. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  305. for _, e := range emojis {
  306. if err := s.EmojiAdd(guildID, e); err != nil {
  307. return err
  308. }
  309. }
  310. return nil
  311. }
  312. // MessageAdd adds a message to the current world state, or updates it if it exists.
  313. // If the channel cannot be found, the message is discarded.
  314. // Messages are kept in state up to s.MaxMessageCount
  315. func (s *State) MessageAdd(message *Message) error {
  316. if s == nil {
  317. return ErrNilState
  318. }
  319. if s.MaxMessageCount == 0 {
  320. return nil
  321. }
  322. c, err := s.Channel(message.ChannelID)
  323. if err != nil {
  324. return err
  325. }
  326. s.Lock()
  327. defer s.Unlock()
  328. // If the message exists, replace it.
  329. for i, m := range c.Messages {
  330. if m.ID == message.ID {
  331. c.Messages[i] = message
  332. return nil
  333. }
  334. }
  335. c.Messages = append(c.Messages, message)
  336. if len(c.Messages) > s.MaxMessageCount {
  337. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  338. }
  339. return nil
  340. }
  341. // MessageRemove removes a message from the world state.
  342. func (s *State) MessageRemove(message *Message) error {
  343. if s == nil {
  344. return ErrNilState
  345. }
  346. if s.MaxMessageCount == 0 {
  347. return nil
  348. }
  349. c, err := s.Channel(message.ChannelID)
  350. if err != nil {
  351. return err
  352. }
  353. s.Lock()
  354. defer s.Unlock()
  355. for i, m := range c.Messages {
  356. if m.ID == message.ID {
  357. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  358. return nil
  359. }
  360. }
  361. return errors.New("Message not found.")
  362. }
  363. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  364. guild, err := s.Guild(update.GuildID)
  365. if err != nil {
  366. return err
  367. }
  368. s.Lock()
  369. defer s.Unlock()
  370. // Handle Leaving Channel
  371. if update.ChannelID == "" {
  372. for i, state := range guild.VoiceStates {
  373. if state.UserID == update.UserID {
  374. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  375. return nil
  376. }
  377. }
  378. } else {
  379. for i, state := range guild.VoiceStates {
  380. if state.UserID == update.UserID {
  381. guild.VoiceStates[i] = update.VoiceState
  382. return nil
  383. }
  384. }
  385. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  386. }
  387. return nil
  388. }
  389. // Message gets a message by channel and message ID.
  390. func (s *State) Message(channelID, messageID string) (*Message, error) {
  391. if s == nil {
  392. return nil, ErrNilState
  393. }
  394. c, err := s.Channel(channelID)
  395. if err != nil {
  396. return nil, err
  397. }
  398. s.RLock()
  399. defer s.RUnlock()
  400. for _, m := range c.Messages {
  401. if m.ID == messageID {
  402. return m, nil
  403. }
  404. }
  405. return nil, errors.New("Message not found.")
  406. }
  407. // onInterface handles all events related to states.
  408. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  409. if s == nil {
  410. return ErrNilState
  411. }
  412. if !se.StateEnabled {
  413. return nil
  414. }
  415. switch t := i.(type) {
  416. case *Ready:
  417. err = s.OnReady(t)
  418. case *GuildCreate:
  419. err = s.GuildAdd(t.Guild)
  420. case *GuildUpdate:
  421. err = s.GuildAdd(t.Guild)
  422. case *GuildDelete:
  423. err = s.GuildRemove(t.Guild)
  424. case *GuildMemberAdd:
  425. err = s.MemberAdd(t.Member)
  426. case *GuildMemberUpdate:
  427. err = s.MemberAdd(t.Member)
  428. case *GuildMemberRemove:
  429. err = s.MemberRemove(t.Member)
  430. case *GuildEmojisUpdate:
  431. err = s.EmojisAdd(t.GuildID, t.Emojis)
  432. case *ChannelCreate:
  433. err = s.ChannelAdd(t.Channel)
  434. case *ChannelUpdate:
  435. err = s.ChannelAdd(t.Channel)
  436. case *ChannelDelete:
  437. err = s.ChannelRemove(t.Channel)
  438. case *MessageCreate:
  439. err = s.MessageAdd(t.Message)
  440. case *MessageUpdate:
  441. err = s.MessageAdd(t.Message)
  442. case *MessageDelete:
  443. err = s.MessageRemove(t.Message)
  444. case *VoiceStateUpdate:
  445. err = s.voiceStateUpdate(t)
  446. }
  447. return
  448. }