state.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "fmt"
  14. )
  15. // ErrNilState is returned when the state is nil.
  16. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  17. // NewState creates an empty state.
  18. func NewState() *State {
  19. return &State{
  20. Ready: Ready{
  21. PrivateChannels: []*Channel{},
  22. Guilds: []*Guild{},
  23. },
  24. }
  25. }
  26. // OnReady takes a Ready event and updates all internal state.
  27. func (s *State) OnReady(r *Ready) error {
  28. if s == nil {
  29. return ErrNilState
  30. }
  31. s.Lock()
  32. defer s.Unlock()
  33. s.Ready = *r
  34. for _, g := range s.Guilds {
  35. for _, c := range g.Channels {
  36. c.GuildID = g.ID
  37. }
  38. }
  39. return nil
  40. }
  41. // GuildAdd adds a guild to the current world state, or
  42. // updates it if it already exists.
  43. func (s *State) GuildAdd(guild *Guild) error {
  44. if s == nil {
  45. return ErrNilState
  46. }
  47. s.Lock()
  48. defer s.Unlock()
  49. // If the guild exists, replace it.
  50. for i, g := range s.Guilds {
  51. if g.ID == guild.ID {
  52. // Don't stomp on properties that don't come in updates.
  53. guild.Members = g.Members
  54. guild.Presences = g.Presences
  55. guild.Channels = g.Channels
  56. guild.VoiceStates = g.VoiceStates
  57. s.Guilds[i] = guild
  58. return nil
  59. }
  60. }
  61. s.Guilds = append(s.Guilds, guild)
  62. return nil
  63. }
  64. // GuildRemove removes a guild from current world state.
  65. func (s *State) GuildRemove(guild *Guild) error {
  66. if s == nil {
  67. return ErrNilState
  68. }
  69. s.Lock()
  70. defer s.Unlock()
  71. for i, g := range s.Guilds {
  72. if g.ID == guild.ID {
  73. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  74. return nil
  75. }
  76. }
  77. return errors.New("Guild not found.")
  78. }
  79. // Guild gets a guild by ID.
  80. // Useful for querying if @me is in a guild:
  81. // _, err := discordgo.Session.State.Guild(guildID)
  82. // isInGuild := err == nil
  83. func (s *State) Guild(guildID string) (*Guild, error) {
  84. if s == nil {
  85. return nil, ErrNilState
  86. }
  87. s.RLock()
  88. defer s.RUnlock()
  89. for _, g := range s.Guilds {
  90. if g.ID == guildID {
  91. return g, nil
  92. }
  93. }
  94. return nil, errors.New("Guild not found.")
  95. }
  96. // TODO: Consider moving Guild state update methods onto *Guild.
  97. // MemberAdd adds a member to the current world state, or
  98. // updates it if it already exists.
  99. func (s *State) MemberAdd(member *Member) error {
  100. if s == nil {
  101. return ErrNilState
  102. }
  103. guild, err := s.Guild(member.GuildID)
  104. if err != nil {
  105. return err
  106. }
  107. s.Lock()
  108. defer s.Unlock()
  109. for i, m := range guild.Members {
  110. if m.User.ID == member.User.ID {
  111. guild.Members[i] = member
  112. return nil
  113. }
  114. }
  115. guild.Members = append(guild.Members, member)
  116. return nil
  117. }
  118. // MemberRemove removes a member from current world state.
  119. func (s *State) MemberRemove(member *Member) error {
  120. if s == nil {
  121. return ErrNilState
  122. }
  123. guild, err := s.Guild(member.GuildID)
  124. if err != nil {
  125. return err
  126. }
  127. s.Lock()
  128. defer s.Unlock()
  129. for i, m := range guild.Members {
  130. if m.User.ID == member.User.ID {
  131. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  132. return nil
  133. }
  134. }
  135. return errors.New("Member not found.")
  136. }
  137. // Member gets a member by ID from a guild.
  138. func (s *State) Member(guildID, userID string) (*Member, error) {
  139. if s == nil {
  140. return nil, ErrNilState
  141. }
  142. guild, err := s.Guild(guildID)
  143. if err != nil {
  144. return nil, err
  145. }
  146. s.RLock()
  147. defer s.RUnlock()
  148. for _, m := range guild.Members {
  149. if m.User.ID == userID {
  150. return m, nil
  151. }
  152. }
  153. return nil, errors.New("Member not found.")
  154. }
  155. // ChannelAdd adds a guild to the current world state, or
  156. // updates it if it already exists.
  157. // Channels may exist either as PrivateChannels or inside
  158. // a guild.
  159. func (s *State) ChannelAdd(channel *Channel) error {
  160. if s == nil {
  161. return ErrNilState
  162. }
  163. if channel.IsPrivate {
  164. s.Lock()
  165. defer s.Unlock()
  166. // If the channel exists, replace it.
  167. for i, c := range s.PrivateChannels {
  168. if c.ID == channel.ID {
  169. // Don't stomp on messages.
  170. channel.Messages = c.Messages
  171. s.PrivateChannels[i] = channel
  172. return nil
  173. }
  174. }
  175. s.PrivateChannels = append(s.PrivateChannels, channel)
  176. } else {
  177. guild, err := s.Guild(channel.GuildID)
  178. if err != nil {
  179. return err
  180. }
  181. s.Lock()
  182. defer s.Unlock()
  183. // If the channel exists, replace it.
  184. for i, c := range guild.Channels {
  185. if c.ID == channel.ID {
  186. // Don't stomp on messages.
  187. channel.Messages = c.Messages
  188. guild.Channels[i] = channel
  189. return nil
  190. }
  191. }
  192. guild.Channels = append(guild.Channels, channel)
  193. }
  194. return nil
  195. }
  196. // ChannelRemove removes a channel from current world state.
  197. func (s *State) ChannelRemove(channel *Channel) error {
  198. if s == nil {
  199. return ErrNilState
  200. }
  201. if channel.IsPrivate {
  202. s.Lock()
  203. defer s.Unlock()
  204. for i, c := range s.PrivateChannels {
  205. if c.ID == channel.ID {
  206. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  207. return nil
  208. }
  209. }
  210. } else {
  211. guild, err := s.Guild(channel.GuildID)
  212. if err != nil {
  213. return err
  214. }
  215. s.Lock()
  216. defer s.Unlock()
  217. for i, c := range guild.Channels {
  218. if c.ID == channel.ID {
  219. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  220. return nil
  221. }
  222. }
  223. }
  224. return errors.New("Channel not found.")
  225. }
  226. // GuildChannel gets a channel by ID from a guild.
  227. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  228. if s == nil {
  229. return nil, ErrNilState
  230. }
  231. guild, err := s.Guild(guildID)
  232. if err != nil {
  233. return nil, err
  234. }
  235. s.RLock()
  236. defer s.RUnlock()
  237. for _, c := range guild.Channels {
  238. if c.ID == channelID {
  239. return c, nil
  240. }
  241. }
  242. return nil, errors.New("Channel not found.")
  243. }
  244. // PrivateChannel gets a private channel by ID.
  245. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  246. if s == nil {
  247. return nil, ErrNilState
  248. }
  249. s.RLock()
  250. defer s.RUnlock()
  251. for _, c := range s.PrivateChannels {
  252. if c.ID == channelID {
  253. return c, nil
  254. }
  255. }
  256. return nil, errors.New("Channel not found.")
  257. }
  258. // Channel gets a channel by ID, it will look in all guilds an private channels.
  259. func (s *State) Channel(channelID string) (*Channel, error) {
  260. if s == nil {
  261. return nil, ErrNilState
  262. }
  263. c, err := s.PrivateChannel(channelID)
  264. if err == nil {
  265. return c, nil
  266. }
  267. for _, g := range s.Guilds {
  268. c, err := s.GuildChannel(g.ID, channelID)
  269. if err == nil {
  270. return c, nil
  271. }
  272. }
  273. return nil, errors.New("Channel not found.")
  274. }
  275. // Emoji returns an emoji for a guild and emoji id.
  276. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  277. if s == nil {
  278. return nil, ErrNilState
  279. }
  280. guild, err := s.Guild(guildID)
  281. if err != nil {
  282. return nil, err
  283. }
  284. s.RLock()
  285. defer s.RUnlock()
  286. for _, e := range guild.Emojis {
  287. if e.ID == emojiID {
  288. return e, nil
  289. }
  290. }
  291. return nil, errors.New("Emoji not found.")
  292. }
  293. // EmojiAdd adds an emoji to the current world state.
  294. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  295. if s == nil {
  296. return ErrNilState
  297. }
  298. guild, err := s.Guild(guildID)
  299. if err != nil {
  300. return err
  301. }
  302. s.Lock()
  303. defer s.Unlock()
  304. for i, e := range guild.Emojis {
  305. if e.ID == emoji.ID {
  306. guild.Emojis[i] = emoji
  307. return nil
  308. }
  309. }
  310. guild.Emojis = append(guild.Emojis, emoji)
  311. return nil
  312. }
  313. // EmojisAdd adds multiple emojis to the world state.
  314. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  315. for _, e := range emojis {
  316. if err := s.EmojiAdd(guildID, e); err != nil {
  317. return err
  318. }
  319. }
  320. return nil
  321. }
  322. // MessageAdd adds a message to the current world state, or updates it if it exists.
  323. // If the channel cannot be found, the message is discarded.
  324. // Messages are kept in state up to s.MaxMessageCount
  325. func (s *State) MessageAdd(message *Message) error {
  326. if s == nil {
  327. return ErrNilState
  328. }
  329. c, err := s.Channel(message.ChannelID)
  330. if err != nil {
  331. return err
  332. }
  333. s.Lock()
  334. defer s.Unlock()
  335. // If the message exists, replace it.
  336. for i, m := range c.Messages {
  337. if m.ID == message.ID {
  338. c.Messages[i] = message
  339. return nil
  340. }
  341. }
  342. c.Messages = append(c.Messages, message)
  343. if len(c.Messages) > s.MaxMessageCount {
  344. s.Unlock()
  345. for len(c.Messages) > s.MaxMessageCount {
  346. err := s.MessageRemove(c.Messages[0])
  347. if err != nil {
  348. fmt.Println("message remove error: ", err)
  349. }
  350. }
  351. s.Lock()
  352. }
  353. return nil
  354. }
  355. // MessageRemove removes a message from the world state.
  356. func (s *State) MessageRemove(message *Message) error {
  357. if s == nil {
  358. return ErrNilState
  359. }
  360. c, err := s.Channel(message.ChannelID)
  361. if err != nil {
  362. return err
  363. }
  364. s.Lock()
  365. defer s.Unlock()
  366. for i, m := range c.Messages {
  367. if m.ID == message.ID {
  368. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  369. return nil
  370. }
  371. }
  372. return errors.New("Message not found.")
  373. }
  374. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  375. guild, err := s.Guild(update.GuildID)
  376. if err != nil {
  377. return err
  378. }
  379. s.Lock()
  380. defer s.Unlock()
  381. // Handle Leaving Channel
  382. if update.ChannelID == "" {
  383. for i, state := range guild.VoiceStates {
  384. if state.UserID == update.UserID {
  385. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  386. return nil
  387. }
  388. }
  389. } else {
  390. for i, state := range guild.VoiceStates {
  391. if state.UserID == update.UserID {
  392. guild.VoiceStates[i] = update.VoiceState
  393. return nil
  394. }
  395. }
  396. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  397. }
  398. return nil
  399. }
  400. // Message gets a message by channel and message ID.
  401. func (s *State) Message(channelID, messageID string) (*Message, error) {
  402. if s == nil {
  403. return nil, ErrNilState
  404. }
  405. c, err := s.Channel(channelID)
  406. if err != nil {
  407. return nil, err
  408. }
  409. s.RLock()
  410. defer s.RUnlock()
  411. for _, m := range c.Messages {
  412. if m.ID == messageID {
  413. return m, nil
  414. }
  415. }
  416. return nil, errors.New("Message not found.")
  417. }
  418. // onInterface handles all events related to states.
  419. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  420. if s == nil {
  421. return ErrNilState
  422. }
  423. if !se.StateEnabled {
  424. return nil
  425. }
  426. switch t := i.(type) {
  427. case *Ready:
  428. err = s.OnReady(t)
  429. case *GuildCreate:
  430. err = s.GuildAdd(t.Guild)
  431. case *GuildUpdate:
  432. err = s.GuildAdd(t.Guild)
  433. case *GuildDelete:
  434. err = s.GuildRemove(t.Guild)
  435. case *GuildMemberAdd:
  436. err = s.MemberAdd(t.Member)
  437. case *GuildMemberUpdate:
  438. err = s.MemberAdd(t.Member)
  439. case *GuildMemberRemove:
  440. err = s.MemberRemove(t.Member)
  441. case *GuildEmojisUpdate:
  442. err = s.EmojisAdd(t.GuildID, t.Emojis)
  443. case *ChannelCreate:
  444. err = s.ChannelAdd(t.Channel)
  445. case *ChannelUpdate:
  446. err = s.ChannelAdd(t.Channel)
  447. case *ChannelDelete:
  448. err = s.ChannelRemove(t.Channel)
  449. case *MessageCreate:
  450. err = s.MessageAdd(t.Message)
  451. case *MessageUpdate:
  452. err = s.MessageAdd(t.Message)
  453. case *MessageDelete:
  454. err = s.MessageRemove(t.Message)
  455. case *VoiceStateUpdate:
  456. err = s.voiceStateUpdate(t)
  457. }
  458. return
  459. }