state.go 11 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "log"
  14. )
  15. // ErrNilState is returned when the state is nil.
  16. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  17. // NewState creates an empty state.
  18. func NewState() *State {
  19. return &State{
  20. Ready: Ready{
  21. PrivateChannels: []*Channel{},
  22. Guilds: []*Guild{},
  23. },
  24. }
  25. }
  26. // OnReady takes a Ready event and updates all internal state.
  27. func (s *State) OnReady(r *Ready) error {
  28. if s == nil {
  29. return ErrNilState
  30. }
  31. s.Lock()
  32. defer s.Unlock()
  33. s.Ready = *r
  34. for _, g := range s.Guilds {
  35. for _, c := range g.Channels {
  36. c.GuildID = g.ID
  37. }
  38. }
  39. return nil
  40. }
  41. // GuildAdd adds a guild to the current world state, or
  42. // updates it if it already exists.
  43. func (s *State) GuildAdd(guild *Guild) error {
  44. if s == nil {
  45. return ErrNilState
  46. }
  47. s.Lock()
  48. defer s.Unlock()
  49. // If the guild exists, replace it.
  50. for i, g := range s.Guilds {
  51. if g.ID == guild.ID {
  52. // Don't stomp on properties that don't come in updates.
  53. guild.Members = g.Members
  54. guild.Presences = g.Presences
  55. guild.Channels = g.Channels
  56. guild.VoiceStates = g.VoiceStates
  57. s.Guilds[i] = guild
  58. return nil
  59. }
  60. }
  61. // Otherwise, update the channels to point to the right guild
  62. for _, c := range guild.Channels {
  63. c.GuildID = guild.ID
  64. }
  65. s.Guilds = append(s.Guilds, guild)
  66. return nil
  67. }
  68. // GuildRemove removes a guild from current world state.
  69. func (s *State) GuildRemove(guild *Guild) error {
  70. if s == nil {
  71. return ErrNilState
  72. }
  73. s.Lock()
  74. defer s.Unlock()
  75. for i, g := range s.Guilds {
  76. if g.ID == guild.ID {
  77. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  78. return nil
  79. }
  80. }
  81. return errors.New("Guild not found.")
  82. }
  83. // Guild gets a guild by ID.
  84. // Useful for querying if @me is in a guild:
  85. // _, err := discordgo.Session.State.Guild(guildID)
  86. // isInGuild := err == nil
  87. func (s *State) Guild(guildID string) (*Guild, error) {
  88. if s == nil {
  89. return nil, ErrNilState
  90. }
  91. s.RLock()
  92. defer s.RUnlock()
  93. for _, g := range s.Guilds {
  94. if g.ID == guildID {
  95. return g, nil
  96. }
  97. }
  98. return nil, errors.New("Guild not found.")
  99. }
  100. // TODO: Consider moving Guild state update methods onto *Guild.
  101. // MemberAdd adds a member to the current world state, or
  102. // updates it if it already exists.
  103. func (s *State) MemberAdd(member *Member) error {
  104. if s == nil {
  105. return ErrNilState
  106. }
  107. guild, err := s.Guild(member.GuildID)
  108. if err != nil {
  109. return err
  110. }
  111. s.Lock()
  112. defer s.Unlock()
  113. for i, m := range guild.Members {
  114. if m.User.ID == member.User.ID {
  115. guild.Members[i] = member
  116. return nil
  117. }
  118. }
  119. guild.Members = append(guild.Members, member)
  120. return nil
  121. }
  122. // MemberRemove removes a member from current world state.
  123. func (s *State) MemberRemove(member *Member) error {
  124. if s == nil {
  125. return ErrNilState
  126. }
  127. guild, err := s.Guild(member.GuildID)
  128. if err != nil {
  129. return err
  130. }
  131. s.Lock()
  132. defer s.Unlock()
  133. for i, m := range guild.Members {
  134. if m.User.ID == member.User.ID {
  135. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  136. return nil
  137. }
  138. }
  139. return errors.New("Member not found.")
  140. }
  141. // Member gets a member by ID from a guild.
  142. func (s *State) Member(guildID, userID string) (*Member, error) {
  143. if s == nil {
  144. return nil, ErrNilState
  145. }
  146. guild, err := s.Guild(guildID)
  147. if err != nil {
  148. return nil, err
  149. }
  150. s.RLock()
  151. defer s.RUnlock()
  152. for _, m := range guild.Members {
  153. if m.User.ID == userID {
  154. return m, nil
  155. }
  156. }
  157. return nil, errors.New("Member not found.")
  158. }
  159. // ChannelAdd adds a guild to the current world state, or
  160. // updates it if it already exists.
  161. // Channels may exist either as PrivateChannels or inside
  162. // a guild.
  163. func (s *State) ChannelAdd(channel *Channel) error {
  164. if s == nil {
  165. return ErrNilState
  166. }
  167. if channel.IsPrivate {
  168. s.Lock()
  169. defer s.Unlock()
  170. // If the channel exists, replace it.
  171. for i, c := range s.PrivateChannels {
  172. if c.ID == channel.ID {
  173. // Don't stomp on messages.
  174. channel.Messages = c.Messages
  175. s.PrivateChannels[i] = channel
  176. return nil
  177. }
  178. }
  179. s.PrivateChannels = append(s.PrivateChannels, channel)
  180. } else {
  181. guild, err := s.Guild(channel.GuildID)
  182. if err != nil {
  183. return err
  184. }
  185. s.Lock()
  186. defer s.Unlock()
  187. // If the channel exists, replace it.
  188. for i, c := range guild.Channels {
  189. if c.ID == channel.ID {
  190. // Don't stomp on messages.
  191. channel.Messages = c.Messages
  192. guild.Channels[i] = channel
  193. return nil
  194. }
  195. }
  196. guild.Channels = append(guild.Channels, channel)
  197. }
  198. return nil
  199. }
  200. // ChannelRemove removes a channel from current world state.
  201. func (s *State) ChannelRemove(channel *Channel) error {
  202. if s == nil {
  203. return ErrNilState
  204. }
  205. if channel.IsPrivate {
  206. s.Lock()
  207. defer s.Unlock()
  208. for i, c := range s.PrivateChannels {
  209. if c.ID == channel.ID {
  210. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  211. return nil
  212. }
  213. }
  214. } else {
  215. guild, err := s.Guild(channel.GuildID)
  216. if err != nil {
  217. return err
  218. }
  219. s.Lock()
  220. defer s.Unlock()
  221. for i, c := range guild.Channels {
  222. if c.ID == channel.ID {
  223. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  224. return nil
  225. }
  226. }
  227. }
  228. return errors.New("Channel not found.")
  229. }
  230. // GuildChannel gets a channel by ID from a guild.
  231. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  232. if s == nil {
  233. return nil, ErrNilState
  234. }
  235. guild, err := s.Guild(guildID)
  236. if err != nil {
  237. return nil, err
  238. }
  239. s.RLock()
  240. defer s.RUnlock()
  241. for _, c := range guild.Channels {
  242. if c.ID == channelID {
  243. return c, nil
  244. }
  245. }
  246. return nil, errors.New("Channel not found.")
  247. }
  248. // PrivateChannel gets a private channel by ID.
  249. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  250. if s == nil {
  251. return nil, ErrNilState
  252. }
  253. s.RLock()
  254. defer s.RUnlock()
  255. for _, c := range s.PrivateChannels {
  256. if c.ID == channelID {
  257. return c, nil
  258. }
  259. }
  260. return nil, errors.New("Channel not found.")
  261. }
  262. // Channel gets a channel by ID, it will look in all guilds an private channels.
  263. func (s *State) Channel(channelID string) (*Channel, error) {
  264. if s == nil {
  265. return nil, ErrNilState
  266. }
  267. c, err := s.PrivateChannel(channelID)
  268. if err == nil {
  269. return c, nil
  270. }
  271. for _, g := range s.Guilds {
  272. c, err := s.GuildChannel(g.ID, channelID)
  273. if err == nil {
  274. return c, nil
  275. }
  276. }
  277. return nil, errors.New("Channel not found.")
  278. }
  279. // Emoji returns an emoji for a guild and emoji id.
  280. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  281. if s == nil {
  282. return nil, ErrNilState
  283. }
  284. guild, err := s.Guild(guildID)
  285. if err != nil {
  286. return nil, err
  287. }
  288. s.RLock()
  289. defer s.RUnlock()
  290. for _, e := range guild.Emojis {
  291. if e.ID == emojiID {
  292. return e, nil
  293. }
  294. }
  295. return nil, errors.New("Emoji not found.")
  296. }
  297. // EmojiAdd adds an emoji to the current world state.
  298. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  299. if s == nil {
  300. return ErrNilState
  301. }
  302. guild, err := s.Guild(guildID)
  303. if err != nil {
  304. return err
  305. }
  306. s.Lock()
  307. defer s.Unlock()
  308. for i, e := range guild.Emojis {
  309. if e.ID == emoji.ID {
  310. guild.Emojis[i] = emoji
  311. return nil
  312. }
  313. }
  314. guild.Emojis = append(guild.Emojis, emoji)
  315. return nil
  316. }
  317. // EmojisAdd adds multiple emojis to the world state.
  318. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  319. for _, e := range emojis {
  320. if err := s.EmojiAdd(guildID, e); err != nil {
  321. return err
  322. }
  323. }
  324. return nil
  325. }
  326. // MessageAdd adds a message to the current world state, or updates it if it exists.
  327. // If the channel cannot be found, the message is discarded.
  328. // Messages are kept in state up to s.MaxMessageCount
  329. func (s *State) MessageAdd(message *Message) error {
  330. if s == nil {
  331. return ErrNilState
  332. }
  333. c, err := s.Channel(message.ChannelID)
  334. if err != nil {
  335. return err
  336. }
  337. s.Lock()
  338. defer s.Unlock()
  339. // If the message exists, replace it.
  340. for i, m := range c.Messages {
  341. if m.ID == message.ID {
  342. c.Messages[i] = message
  343. return nil
  344. }
  345. }
  346. c.Messages = append(c.Messages, message)
  347. if len(c.Messages) > s.MaxMessageCount {
  348. s.Unlock()
  349. for len(c.Messages) > s.MaxMessageCount {
  350. err := s.MessageRemove(c.Messages[0])
  351. if err != nil {
  352. log.Println("message remove error: ", err)
  353. }
  354. }
  355. s.Lock()
  356. }
  357. return nil
  358. }
  359. // MessageRemove removes a message from the world state.
  360. func (s *State) MessageRemove(message *Message) error {
  361. if s == nil {
  362. return ErrNilState
  363. }
  364. c, err := s.Channel(message.ChannelID)
  365. if err != nil {
  366. return err
  367. }
  368. s.Lock()
  369. defer s.Unlock()
  370. for i, m := range c.Messages {
  371. if m.ID == message.ID {
  372. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  373. return nil
  374. }
  375. }
  376. return errors.New("Message not found.")
  377. }
  378. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  379. guild, err := s.Guild(update.GuildID)
  380. if err != nil {
  381. return err
  382. }
  383. s.Lock()
  384. defer s.Unlock()
  385. // Handle Leaving Channel
  386. if update.ChannelID == "" {
  387. for i, state := range guild.VoiceStates {
  388. if state.UserID == update.UserID {
  389. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  390. return nil
  391. }
  392. }
  393. } else {
  394. for i, state := range guild.VoiceStates {
  395. if state.UserID == update.UserID {
  396. guild.VoiceStates[i] = update.VoiceState
  397. return nil
  398. }
  399. }
  400. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  401. }
  402. return nil
  403. }
  404. // Message gets a message by channel and message ID.
  405. func (s *State) Message(channelID, messageID string) (*Message, error) {
  406. if s == nil {
  407. return nil, ErrNilState
  408. }
  409. c, err := s.Channel(channelID)
  410. if err != nil {
  411. return nil, err
  412. }
  413. s.RLock()
  414. defer s.RUnlock()
  415. for _, m := range c.Messages {
  416. if m.ID == messageID {
  417. return m, nil
  418. }
  419. }
  420. return nil, errors.New("Message not found.")
  421. }
  422. // onInterface handles all events related to states.
  423. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  424. if s == nil {
  425. return ErrNilState
  426. }
  427. if !se.StateEnabled {
  428. return nil
  429. }
  430. switch t := i.(type) {
  431. case *Ready:
  432. err = s.OnReady(t)
  433. case *GuildCreate:
  434. err = s.GuildAdd(t.Guild)
  435. case *GuildUpdate:
  436. err = s.GuildAdd(t.Guild)
  437. case *GuildDelete:
  438. err = s.GuildRemove(t.Guild)
  439. case *GuildMemberAdd:
  440. err = s.MemberAdd(t.Member)
  441. case *GuildMemberUpdate:
  442. err = s.MemberAdd(t.Member)
  443. case *GuildMemberRemove:
  444. err = s.MemberRemove(t.Member)
  445. case *GuildEmojisUpdate:
  446. err = s.EmojisAdd(t.GuildID, t.Emojis)
  447. case *ChannelCreate:
  448. err = s.ChannelAdd(t.Channel)
  449. case *ChannelUpdate:
  450. err = s.ChannelAdd(t.Channel)
  451. case *ChannelDelete:
  452. err = s.ChannelRemove(t.Channel)
  453. case *MessageCreate:
  454. err = s.MessageAdd(t.Message)
  455. case *MessageUpdate:
  456. err = s.MessageAdd(t.Message)
  457. case *MessageDelete:
  458. err = s.MessageRemove(t.Message)
  459. case *VoiceStateUpdate:
  460. err = s.voiceStateUpdate(t)
  461. }
  462. return
  463. }