state.go 10 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
// events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import (
  12. "errors"
  13. "fmt"
  14. )
  15. // ErrNilState is returned when the state is nil.
  16. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  17. // NewState creates an empty state.
  18. func NewState() *State {
  19. return &State{
  20. Ready: Ready{
  21. PrivateChannels: []*Channel{},
  22. Guilds: []*Guild{},
  23. },
  24. }
  25. }
  26. // OnReady takes a Ready event and updates all internal state.
  27. func (s *State) OnReady(r *Ready) error {
  28. if s == nil {
  29. return ErrNilState
  30. }
  31. s.Lock()
  32. defer s.Unlock()
  33. s.Ready = *r
  34. return nil
  35. }
  36. // GuildAdd adds a guild to the current world state, or
  37. // updates it if it already exists.
  38. func (s *State) GuildAdd(guild *Guild) error {
  39. if s == nil {
  40. return ErrNilState
  41. }
  42. s.Lock()
  43. defer s.Unlock()
  44. // If the guild exists, replace it.
  45. for i, g := range s.Guilds {
  46. if g.ID == guild.ID {
  47. // Don't stomp on properties that don't come in updates.
  48. guild.Members = g.Members
  49. guild.Presences = g.Presences
  50. guild.Channels = g.Channels
  51. guild.VoiceStates = g.VoiceStates
  52. s.Guilds[i] = guild
  53. return nil
  54. }
  55. }
  56. s.Guilds = append(s.Guilds, guild)
  57. return nil
  58. }
  59. // GuildRemove removes a guild from current world state.
  60. func (s *State) GuildRemove(guild *Guild) error {
  61. if s == nil {
  62. return ErrNilState
  63. }
  64. s.Lock()
  65. defer s.Unlock()
  66. for i, g := range s.Guilds {
  67. if g.ID == guild.ID {
  68. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  69. return nil
  70. }
  71. }
  72. return errors.New("Guild not found.")
  73. }
  74. // Guild gets a guild by ID.
  75. // Useful for querying if @me is in a guild:
  76. // _, err := discordgo.Session.State.Guild(guildID)
  77. // isInGuild := err == nil
  78. func (s *State) Guild(guildID string) (*Guild, error) {
  79. if s == nil {
  80. return nil, ErrNilState
  81. }
  82. s.RLock()
  83. defer s.RUnlock()
  84. for _, g := range s.Guilds {
  85. if g.ID == guildID {
  86. return g, nil
  87. }
  88. }
  89. return nil, errors.New("Guild not found.")
  90. }
  91. // TODO: Consider moving Guild state update methods onto *Guild.
  92. // MemberAdd adds a member to the current world state, or
  93. // updates it if it already exists.
  94. func (s *State) MemberAdd(member *Member) error {
  95. if s == nil {
  96. return ErrNilState
  97. }
  98. guild, err := s.Guild(member.GuildID)
  99. if err != nil {
  100. return err
  101. }
  102. s.Lock()
  103. defer s.Unlock()
  104. for i, m := range guild.Members {
  105. if m.User.ID == member.User.ID {
  106. guild.Members[i] = member
  107. return nil
  108. }
  109. }
  110. guild.Members = append(guild.Members, member)
  111. return nil
  112. }
  113. // MemberRemove removes a member from current world state.
  114. func (s *State) MemberRemove(member *Member) error {
  115. if s == nil {
  116. return ErrNilState
  117. }
  118. guild, err := s.Guild(member.GuildID)
  119. if err != nil {
  120. return err
  121. }
  122. s.Lock()
  123. defer s.Unlock()
  124. for i, m := range guild.Members {
  125. if m.User.ID == member.User.ID {
  126. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  127. return nil
  128. }
  129. }
  130. return errors.New("Member not found.")
  131. }
  132. // Member gets a member by ID from a guild.
  133. func (s *State) Member(guildID, userID string) (*Member, error) {
  134. if s == nil {
  135. return nil, ErrNilState
  136. }
  137. guild, err := s.Guild(guildID)
  138. if err != nil {
  139. return nil, err
  140. }
  141. s.RLock()
  142. defer s.RUnlock()
  143. for _, m := range guild.Members {
  144. if m.User.ID == userID {
  145. return m, nil
  146. }
  147. }
  148. return nil, errors.New("Member not found.")
  149. }
  150. // ChannelAdd adds a guild to the current world state, or
  151. // updates it if it already exists.
  152. // Channels may exist either as PrivateChannels or inside
  153. // a guild.
  154. func (s *State) ChannelAdd(channel *Channel) error {
  155. if s == nil {
  156. return ErrNilState
  157. }
  158. if channel.IsPrivate {
  159. s.Lock()
  160. defer s.Unlock()
  161. // If the channel exists, replace it.
  162. for i, c := range s.PrivateChannels {
  163. if c.ID == channel.ID {
  164. // Don't stomp on messages.
  165. channel.Messages = c.Messages
  166. s.PrivateChannels[i] = channel
  167. return nil
  168. }
  169. }
  170. s.PrivateChannels = append(s.PrivateChannels, channel)
  171. } else {
  172. guild, err := s.Guild(channel.GuildID)
  173. if err != nil {
  174. return err
  175. }
  176. s.Lock()
  177. defer s.Unlock()
  178. // If the channel exists, replace it.
  179. for i, c := range guild.Channels {
  180. if c.ID == channel.ID {
  181. // Don't stomp on messages.
  182. channel.Messages = c.Messages
  183. guild.Channels[i] = channel
  184. return nil
  185. }
  186. }
  187. guild.Channels = append(guild.Channels, channel)
  188. }
  189. return nil
  190. }
  191. // ChannelRemove removes a channel from current world state.
  192. func (s *State) ChannelRemove(channel *Channel) error {
  193. if s == nil {
  194. return ErrNilState
  195. }
  196. if channel.IsPrivate {
  197. s.Lock()
  198. defer s.Unlock()
  199. for i, c := range s.PrivateChannels {
  200. if c.ID == channel.ID {
  201. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  202. return nil
  203. }
  204. }
  205. } else {
  206. guild, err := s.Guild(channel.GuildID)
  207. if err != nil {
  208. return err
  209. }
  210. s.Lock()
  211. defer s.Unlock()
  212. for i, c := range guild.Channels {
  213. if c.ID == channel.ID {
  214. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  215. return nil
  216. }
  217. }
  218. }
  219. return errors.New("Channel not found.")
  220. }
  221. // GuildChannel gets a channel by ID from a guild.
  222. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  223. if s == nil {
  224. return nil, ErrNilState
  225. }
  226. guild, err := s.Guild(guildID)
  227. if err != nil {
  228. return nil, err
  229. }
  230. s.RLock()
  231. defer s.RUnlock()
  232. for _, c := range guild.Channels {
  233. if c.ID == channelID {
  234. return c, nil
  235. }
  236. }
  237. return nil, errors.New("Channel not found.")
  238. }
  239. // PrivateChannel gets a private channel by ID.
  240. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  241. if s == nil {
  242. return nil, ErrNilState
  243. }
  244. s.RLock()
  245. defer s.RUnlock()
  246. for _, c := range s.PrivateChannels {
  247. if c.ID == channelID {
  248. return c, nil
  249. }
  250. }
  251. return nil, errors.New("Channel not found.")
  252. }
  253. // Channel gets a channel by ID, it will look in all guilds an private channels.
  254. func (s *State) Channel(channelID string) (*Channel, error) {
  255. if s == nil {
  256. return nil, ErrNilState
  257. }
  258. c, err := s.PrivateChannel(channelID)
  259. if err == nil {
  260. return c, nil
  261. }
  262. for _, g := range s.Guilds {
  263. c, err := s.GuildChannel(g.ID, channelID)
  264. if err == nil {
  265. return c, nil
  266. }
  267. }
  268. return nil, errors.New("Channel not found.")
  269. }
  270. // Emoji returns an emoji for a guild and emoji id.
  271. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  272. if s == nil {
  273. return nil, ErrNilState
  274. }
  275. guild, err := s.Guild(guildID)
  276. if err != nil {
  277. return nil, err
  278. }
  279. s.RLock()
  280. defer s.RUnlock()
  281. for _, e := range guild.Emojis {
  282. if e.ID == emojiID {
  283. return e, nil
  284. }
  285. }
  286. return nil, errors.New("Emoji not found.")
  287. }
  288. // EmojiAdd adds an emoji to the current world state.
  289. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  290. if s == nil {
  291. return ErrNilState
  292. }
  293. guild, err := s.Guild(guildID)
  294. if err != nil {
  295. return err
  296. }
  297. s.Lock()
  298. defer s.Unlock()
  299. for i, e := range guild.Emojis {
  300. if e.ID == emoji.ID {
  301. guild.Emojis[i] = emoji
  302. return nil
  303. }
  304. }
  305. guild.Emojis = append(guild.Emojis, emoji)
  306. return nil
  307. }
  308. // EmojisAdd adds multiple emojis to the world state.
  309. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  310. for _, e := range emojis {
  311. if err := s.EmojiAdd(guildID, e); err != nil {
  312. return err
  313. }
  314. }
  315. return nil
  316. }
  317. // MessageAdd adds a message to the current world state, or updates it if it exists.
  318. // If the channel cannot be found, the message is discarded.
  319. // Messages are kept in state up to s.MaxMessageCount
  320. func (s *State) MessageAdd(message *Message) error {
  321. if s == nil {
  322. return ErrNilState
  323. }
  324. c, err := s.Channel(message.ChannelID)
  325. if err != nil {
  326. return err
  327. }
  328. s.Lock()
  329. defer s.Unlock()
  330. // If the message exists, replace it.
  331. for i, m := range c.Messages {
  332. if m.ID == message.ID {
  333. c.Messages[i] = message
  334. return nil
  335. }
  336. }
  337. c.Messages = append(c.Messages, message)
  338. if len(c.Messages) > s.MaxMessageCount {
  339. s.Unlock()
  340. for len(c.Messages) > s.MaxMessageCount {
  341. err := s.MessageRemove(c.Messages[0])
  342. if err != nil {
  343. fmt.Println("message remove error: ", err)
  344. }
  345. }
  346. s.Lock()
  347. }
  348. return nil
  349. }
  350. // MessageRemove removes a message from the world state.
  351. func (s *State) MessageRemove(message *Message) error {
  352. if s == nil {
  353. return ErrNilState
  354. }
  355. c, err := s.Channel(message.ChannelID)
  356. if err != nil {
  357. return err
  358. }
  359. s.Lock()
  360. defer s.Unlock()
  361. for i, m := range c.Messages {
  362. if m.ID == message.ID {
  363. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  364. return nil
  365. }
  366. }
  367. return errors.New("Message not found.")
  368. }
  369. // Message gets a message by channel and message ID.
  370. func (s *State) Message(channelID, messageID string) (*Message, error) {
  371. if s == nil {
  372. return nil, ErrNilState
  373. }
  374. c, err := s.Channel(channelID)
  375. if err != nil {
  376. return nil, err
  377. }
  378. s.RLock()
  379. defer s.RUnlock()
  380. for _, m := range c.Messages {
  381. if m.ID == messageID {
  382. return m, nil
  383. }
  384. }
  385. return nil, errors.New("Message not found.")
  386. }
  387. // onInterface handles all events related to states.
  388. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  389. if s == nil {
  390. return ErrNilState
  391. }
  392. if !se.StateEnabled {
  393. return nil
  394. }
  395. switch t := i.(type) {
  396. case *Ready:
  397. err = s.OnReady(t)
  398. case *GuildCreate:
  399. err = s.GuildAdd(t.Guild)
  400. case *GuildUpdate:
  401. err = s.GuildAdd(t.Guild)
  402. case *GuildDelete:
  403. err = s.GuildRemove(t.Guild)
  404. case *GuildMemberAdd:
  405. err = s.MemberAdd(t.Member)
  406. case *GuildMemberUpdate:
  407. err = s.MemberAdd(t.Member)
  408. case *GuildMemberRemove:
  409. err = s.MemberRemove(t.Member)
  410. case *GuildEmojisUpdate:
  411. err = s.EmojisAdd(t.GuildID, t.Emojis)
  412. case *ChannelCreate:
  413. err = s.ChannelAdd(t.Channel)
  414. case *ChannelUpdate:
  415. err = s.ChannelAdd(t.Channel)
  416. case *ChannelDelete:
  417. err = s.ChannelRemove(t.Channel)
  418. case *MessageCreate:
  419. err = s.MessageAdd(t.Message)
  420. case *MessageUpdate:
  421. err = s.MessageAdd(t.Message)
  422. case *MessageDelete:
  423. err = s.MessageRemove(t.Message)
  424. }
  425. return
  426. }