state.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import "errors"
  12. // ErrNilState is returned when the state is nil.
  13. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  14. // NewState creates an empty state.
  15. func NewState() *State {
  16. return &State{
  17. Ready: Ready{
  18. PrivateChannels: []*Channel{},
  19. Guilds: []*Guild{},
  20. },
  21. guildMap: make(map[string]*Guild),
  22. channelMap: make(map[string]*Channel),
  23. }
  24. }
  25. // OnReady takes a Ready event and updates all internal state.
  26. func (s *State) OnReady(r *Ready) error {
  27. if s == nil {
  28. return ErrNilState
  29. }
  30. s.Lock()
  31. defer s.Unlock()
  32. s.Ready = *r
  33. for _, g := range s.Guilds {
  34. for _, c := range g.Channels {
  35. c.GuildID = g.ID
  36. }
  37. s.guildMap[g.ID] = g
  38. }
  39. return nil
  40. }
  41. // GuildAdd adds a guild to the current world state, or
  42. // updates it if it already exists.
  43. func (s *State) GuildAdd(guild *Guild) error {
  44. if s == nil {
  45. return ErrNilState
  46. }
  47. // Update the channels to point to the right guild
  48. for _, c := range guild.Channels {
  49. c.GuildID = guild.ID
  50. }
  51. // If the guild exists, replace it.
  52. if g, err := s.Guild(guild.ID); err == nil {
  53. s.Lock()
  54. defer s.Unlock()
  55. // If this guild already exists with data, don't stomp on props.
  56. if g.Unavailable != nil && !*g.Unavailable {
  57. guild.Members = g.Members
  58. guild.Presences = g.Presences
  59. guild.Channels = g.Channels
  60. guild.VoiceStates = g.VoiceStates
  61. }
  62. *g = *guild
  63. return nil
  64. }
  65. s.Lock()
  66. defer s.Unlock()
  67. s.Guilds = append(s.Guilds, guild)
  68. s.guildMap[guild.ID] = guild
  69. return nil
  70. }
  71. // GuildRemove removes a guild from current world state.
  72. func (s *State) GuildRemove(guild *Guild) error {
  73. if s == nil {
  74. return ErrNilState
  75. }
  76. _, err := s.Guild(guild.ID)
  77. if err != nil {
  78. return err
  79. }
  80. s.Lock()
  81. defer s.Unlock()
  82. for i, g := range s.Guilds {
  83. if g.ID == guild.ID {
  84. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  85. return nil
  86. }
  87. }
  88. delete(s.guildMap, guild.ID)
  89. return nil
  90. }
  91. // Guild gets a guild by ID.
  92. // Useful for querying if @me is in a guild:
  93. // _, err := discordgo.Session.State.Guild(guildID)
  94. // isInGuild := err == nil
  95. func (s *State) Guild(guildID string) (*Guild, error) {
  96. if s == nil {
  97. return nil, ErrNilState
  98. }
  99. s.RLock()
  100. defer s.RUnlock()
  101. if g, ok := s.guildMap[guildID]; ok {
  102. return g, nil
  103. }
  104. return nil, errors.New("Guild not found.")
  105. }
  106. // TODO: Consider moving Guild state update methods onto *Guild.
  107. // MemberAdd adds a member to the current world state, or
  108. // updates it if it already exists.
  109. func (s *State) MemberAdd(member *Member) error {
  110. if s == nil {
  111. return ErrNilState
  112. }
  113. guild, err := s.Guild(member.GuildID)
  114. if err != nil {
  115. return err
  116. }
  117. s.Lock()
  118. defer s.Unlock()
  119. for i, m := range guild.Members {
  120. if m.User.ID == member.User.ID {
  121. guild.Members[i] = member
  122. return nil
  123. }
  124. }
  125. guild.Members = append(guild.Members, member)
  126. return nil
  127. }
  128. // MemberRemove removes a member from current world state.
  129. func (s *State) MemberRemove(member *Member) error {
  130. if s == nil {
  131. return ErrNilState
  132. }
  133. guild, err := s.Guild(member.GuildID)
  134. if err != nil {
  135. return err
  136. }
  137. s.Lock()
  138. defer s.Unlock()
  139. for i, m := range guild.Members {
  140. if m.User.ID == member.User.ID {
  141. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  142. return nil
  143. }
  144. }
  145. return errors.New("Member not found.")
  146. }
  147. // Member gets a member by ID from a guild.
  148. func (s *State) Member(guildID, userID string) (*Member, error) {
  149. if s == nil {
  150. return nil, ErrNilState
  151. }
  152. guild, err := s.Guild(guildID)
  153. if err != nil {
  154. return nil, err
  155. }
  156. s.RLock()
  157. defer s.RUnlock()
  158. for _, m := range guild.Members {
  159. if m.User.ID == userID {
  160. return m, nil
  161. }
  162. }
  163. return nil, errors.New("Member not found.")
  164. }
  165. // ChannelAdd adds a guild to the current world state, or
  166. // updates it if it already exists.
  167. // Channels may exist either as PrivateChannels or inside
  168. // a guild.
  169. func (s *State) ChannelAdd(channel *Channel) error {
  170. if s == nil {
  171. return ErrNilState
  172. }
  173. // If the channel exists, replace it.
  174. if c, err := s.Channel(channel.ID); err == nil {
  175. s.Lock()
  176. defer s.Unlock()
  177. channel.Messages = c.Messages
  178. channel.PermissionOverwrites = c.PermissionOverwrites
  179. *c = *channel
  180. return nil
  181. }
  182. s.Lock()
  183. defer s.Unlock()
  184. if channel.IsPrivate {
  185. s.PrivateChannels = append(s.PrivateChannels, channel)
  186. } else {
  187. guild, err := s.Guild(channel.GuildID)
  188. if err != nil {
  189. return err
  190. }
  191. guild.Channels = append(guild.Channels, channel)
  192. }
  193. s.channelMap[channel.ID] = channel
  194. return nil
  195. }
  196. // ChannelRemove removes a channel from current world state.
  197. func (s *State) ChannelRemove(channel *Channel) error {
  198. if s == nil {
  199. return ErrNilState
  200. }
  201. _, err := s.Channel(channel.ID)
  202. if err != nil {
  203. return err
  204. }
  205. if channel.IsPrivate {
  206. s.Lock()
  207. defer s.Unlock()
  208. for i, c := range s.PrivateChannels {
  209. if c.ID == channel.ID {
  210. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  211. return nil
  212. }
  213. }
  214. } else {
  215. guild, err := s.Guild(channel.GuildID)
  216. if err != nil {
  217. return err
  218. }
  219. s.Lock()
  220. defer s.Unlock()
  221. for i, c := range guild.Channels {
  222. if c.ID == channel.ID {
  223. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  224. return nil
  225. }
  226. }
  227. }
  228. delete(s.channelMap, channel.ID)
  229. return nil
  230. }
  231. // GuildChannel gets a channel by ID from a guild.
  232. // This method is Deprecated, use Channel(channelID)
  233. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  234. return s.Channel(channelID)
  235. }
  236. // PrivateChannel gets a private channel by ID.
  237. // This method is Deprecated, use Channel(channelID)
  238. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  239. return s.Channel(channelID)
  240. }
  241. // Channel gets a channel by ID, it will look in all guilds an private channels.
  242. func (s *State) Channel(channelID string) (*Channel, error) {
  243. if s == nil {
  244. return nil, ErrNilState
  245. }
  246. if c, ok := s.channelMap[channelID]; ok {
  247. return c, nil
  248. }
  249. return nil, errors.New("Channel not found.")
  250. }
  251. // Emoji returns an emoji for a guild and emoji id.
  252. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  253. if s == nil {
  254. return nil, ErrNilState
  255. }
  256. guild, err := s.Guild(guildID)
  257. if err != nil {
  258. return nil, err
  259. }
  260. s.RLock()
  261. defer s.RUnlock()
  262. for _, e := range guild.Emojis {
  263. if e.ID == emojiID {
  264. return e, nil
  265. }
  266. }
  267. return nil, errors.New("Emoji not found.")
  268. }
  269. // EmojiAdd adds an emoji to the current world state.
  270. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  271. if s == nil {
  272. return ErrNilState
  273. }
  274. guild, err := s.Guild(guildID)
  275. if err != nil {
  276. return err
  277. }
  278. s.Lock()
  279. defer s.Unlock()
  280. for i, e := range guild.Emojis {
  281. if e.ID == emoji.ID {
  282. guild.Emojis[i] = emoji
  283. return nil
  284. }
  285. }
  286. guild.Emojis = append(guild.Emojis, emoji)
  287. return nil
  288. }
  289. // EmojisAdd adds multiple emojis to the world state.
  290. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  291. for _, e := range emojis {
  292. if err := s.EmojiAdd(guildID, e); err != nil {
  293. return err
  294. }
  295. }
  296. return nil
  297. }
  298. // MessageAdd adds a message to the current world state, or updates it if it exists.
  299. // If the channel cannot be found, the message is discarded.
  300. // Messages are kept in state up to s.MaxMessageCount
  301. func (s *State) MessageAdd(message *Message) error {
  302. if s == nil {
  303. return ErrNilState
  304. }
  305. if s.MaxMessageCount == 0 {
  306. return nil
  307. }
  308. c, err := s.Channel(message.ChannelID)
  309. if err != nil {
  310. return err
  311. }
  312. s.Lock()
  313. defer s.Unlock()
  314. // If the message exists, replace it.
  315. for i, m := range c.Messages {
  316. if m.ID == message.ID {
  317. c.Messages[i] = message
  318. return nil
  319. }
  320. }
  321. c.Messages = append(c.Messages, message)
  322. if len(c.Messages) > s.MaxMessageCount {
  323. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  324. }
  325. return nil
  326. }
  327. // MessageRemove removes a message from the world state.
  328. func (s *State) MessageRemove(message *Message) error {
  329. if s == nil {
  330. return ErrNilState
  331. }
  332. if s.MaxMessageCount == 0 {
  333. return nil
  334. }
  335. c, err := s.Channel(message.ChannelID)
  336. if err != nil {
  337. return err
  338. }
  339. s.Lock()
  340. defer s.Unlock()
  341. for i, m := range c.Messages {
  342. if m.ID == message.ID {
  343. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  344. return nil
  345. }
  346. }
  347. return errors.New("Message not found.")
  348. }
  349. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  350. guild, err := s.Guild(update.GuildID)
  351. if err != nil {
  352. return err
  353. }
  354. s.Lock()
  355. defer s.Unlock()
  356. // Handle Leaving Channel
  357. if update.ChannelID == "" {
  358. for i, state := range guild.VoiceStates {
  359. if state.UserID == update.UserID {
  360. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  361. return nil
  362. }
  363. }
  364. } else {
  365. for i, state := range guild.VoiceStates {
  366. if state.UserID == update.UserID {
  367. guild.VoiceStates[i] = update.VoiceState
  368. return nil
  369. }
  370. }
  371. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  372. }
  373. return nil
  374. }
  375. // Message gets a message by channel and message ID.
  376. func (s *State) Message(channelID, messageID string) (*Message, error) {
  377. if s == nil {
  378. return nil, ErrNilState
  379. }
  380. c, err := s.Channel(channelID)
  381. if err != nil {
  382. return nil, err
  383. }
  384. s.RLock()
  385. defer s.RUnlock()
  386. for _, m := range c.Messages {
  387. if m.ID == messageID {
  388. return m, nil
  389. }
  390. }
  391. return nil, errors.New("Message not found.")
  392. }
  393. // onInterface handles all events related to states.
  394. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  395. if s == nil {
  396. return ErrNilState
  397. }
  398. if !se.StateEnabled {
  399. return nil
  400. }
  401. switch t := i.(type) {
  402. case *Ready:
  403. err = s.OnReady(t)
  404. case *GuildCreate:
  405. err = s.GuildAdd(t.Guild)
  406. case *GuildUpdate:
  407. err = s.GuildAdd(t.Guild)
  408. case *GuildDelete:
  409. err = s.GuildRemove(t.Guild)
  410. case *GuildMemberAdd:
  411. err = s.MemberAdd(t.Member)
  412. case *GuildMemberUpdate:
  413. err = s.MemberAdd(t.Member)
  414. case *GuildMemberRemove:
  415. err = s.MemberRemove(t.Member)
  416. case *GuildEmojisUpdate:
  417. err = s.EmojisAdd(t.GuildID, t.Emojis)
  418. case *ChannelCreate:
  419. err = s.ChannelAdd(t.Channel)
  420. case *ChannelUpdate:
  421. err = s.ChannelAdd(t.Channel)
  422. case *ChannelDelete:
  423. err = s.ChannelRemove(t.Channel)
  424. case *MessageCreate:
  425. err = s.MessageAdd(t.Message)
  426. case *MessageUpdate:
  427. err = s.MessageAdd(t.Message)
  428. case *MessageDelete:
  429. err = s.MessageRemove(t.Message)
  430. case *VoiceStateUpdate:
  431. err = s.voiceStateUpdate(t)
  432. }
  433. return
  434. }