state.go 11 KB


  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
// events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import "errors"
  12. // ErrNilState is returned when the state is nil.
  13. var ErrNilState = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  14. // NewState creates an empty state.
  15. func NewState() *State {
  16. return &State{
  17. Ready: Ready{
  18. PrivateChannels: []*Channel{},
  19. Guilds: []*Guild{},
  20. },
  21. guildMap: make(map[string]*Guild),
  22. channelMap: make(map[string]*Channel),
  23. }
  24. }
  25. // OnReady takes a Ready event and updates all internal state.
  26. func (s *State) OnReady(r *Ready) error {
  27. if s == nil {
  28. return ErrNilState
  29. }
  30. s.Lock()
  31. defer s.Unlock()
  32. s.Ready = *r
  33. for _, g := range s.Guilds {
  34. s.guildMap[g.ID] = g
  35. for _, c := range g.Channels {
  36. c.GuildID = g.ID
  37. s.channelMap[c.ID] = c
  38. }
  39. }
  40. for _, c := range s.PrivateChannels {
  41. s.channelMap[c.ID] = c
  42. }
  43. return nil
  44. }
  45. // GuildAdd adds a guild to the current world state, or
  46. // updates it if it already exists.
  47. func (s *State) GuildAdd(guild *Guild) error {
  48. if s == nil {
  49. return ErrNilState
  50. }
  51. s.Lock()
  52. defer s.Unlock()
  53. // Update the channels to point to the right guild, adding them to the channelMap as we go
  54. for _, c := range guild.Channels {
  55. c.GuildID = guild.ID
  56. s.channelMap[c.ID] = c
  57. }
  58. // If the guild exists, replace it.
  59. if g, ok := s.guildMap[guild.ID]; ok {
  60. // If this guild already exists with data, don't stomp on props.
  61. if g.Unavailable != nil && !*g.Unavailable {
  62. guild.Members = g.Members
  63. guild.Presences = g.Presences
  64. guild.Channels = g.Channels
  65. guild.VoiceStates = g.VoiceStates
  66. }
  67. *g = *guild
  68. return nil
  69. }
  70. s.Guilds = append(s.Guilds, guild)
  71. s.guildMap[guild.ID] = guild
  72. return nil
  73. }
  74. // GuildRemove removes a guild from current world state.
  75. func (s *State) GuildRemove(guild *Guild) error {
  76. if s == nil {
  77. return ErrNilState
  78. }
  79. _, err := s.Guild(guild.ID)
  80. if err != nil {
  81. return err
  82. }
  83. s.Lock()
  84. defer s.Unlock()
  85. for i, g := range s.Guilds {
  86. if g.ID == guild.ID {
  87. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  88. return nil
  89. }
  90. }
  91. delete(s.guildMap, guild.ID)
  92. return nil
  93. }
  94. // Guild gets a guild by ID.
  95. // Useful for querying if @me is in a guild:
  96. // _, err := discordgo.Session.State.Guild(guildID)
  97. // isInGuild := err == nil
  98. func (s *State) Guild(guildID string) (*Guild, error) {
  99. if s == nil {
  100. return nil, ErrNilState
  101. }
  102. s.RLock()
  103. defer s.RUnlock()
  104. if g, ok := s.guildMap[guildID]; ok {
  105. return g, nil
  106. }
  107. return nil, errors.New("Guild not found.")
  108. }
  109. // TODO: Consider moving Guild state update methods onto *Guild.
  110. // MemberAdd adds a member to the current world state, or
  111. // updates it if it already exists.
  112. func (s *State) MemberAdd(member *Member) error {
  113. if s == nil {
  114. return ErrNilState
  115. }
  116. guild, err := s.Guild(member.GuildID)
  117. if err != nil {
  118. return err
  119. }
  120. s.Lock()
  121. defer s.Unlock()
  122. for i, m := range guild.Members {
  123. if m.User.ID == member.User.ID {
  124. guild.Members[i] = member
  125. return nil
  126. }
  127. }
  128. guild.Members = append(guild.Members, member)
  129. return nil
  130. }
  131. // MemberRemove removes a member from current world state.
  132. func (s *State) MemberRemove(member *Member) error {
  133. if s == nil {
  134. return ErrNilState
  135. }
  136. guild, err := s.Guild(member.GuildID)
  137. if err != nil {
  138. return err
  139. }
  140. s.Lock()
  141. defer s.Unlock()
  142. for i, m := range guild.Members {
  143. if m.User.ID == member.User.ID {
  144. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  145. return nil
  146. }
  147. }
  148. return errors.New("Member not found.")
  149. }
  150. // Member gets a member by ID from a guild.
  151. func (s *State) Member(guildID, userID string) (*Member, error) {
  152. if s == nil {
  153. return nil, ErrNilState
  154. }
  155. guild, err := s.Guild(guildID)
  156. if err != nil {
  157. return nil, err
  158. }
  159. s.RLock()
  160. defer s.RUnlock()
  161. for _, m := range guild.Members {
  162. if m.User.ID == userID {
  163. return m, nil
  164. }
  165. }
  166. return nil, errors.New("Member not found.")
  167. }
  168. // ChannelAdd adds a guild to the current world state, or
  169. // updates it if it already exists.
  170. // Channels may exist either as PrivateChannels or inside
  171. // a guild.
  172. func (s *State) ChannelAdd(channel *Channel) error {
  173. if s == nil {
  174. return ErrNilState
  175. }
  176. s.Lock()
  177. defer s.Unlock()
  178. // If the channel exists, replace it
  179. if c, ok := s.channelMap[channel.ID]; ok {
  180. channel.Messages = c.Messages
  181. channel.PermissionOverwrites = c.PermissionOverwrites
  182. *c = *channel
  183. return nil
  184. }
  185. if channel.IsPrivate {
  186. s.PrivateChannels = append(s.PrivateChannels, channel)
  187. } else {
  188. guild, ok := s.guildMap[channel.GuildID]
  189. if !ok {
  190. return errors.New("Guild for channel not found.")
  191. }
  192. guild.Channels = append(guild.Channels, channel)
  193. }
  194. s.channelMap[channel.ID] = channel
  195. return nil
  196. }
  197. // ChannelRemove removes a channel from current world state.
  198. func (s *State) ChannelRemove(channel *Channel) error {
  199. if s == nil {
  200. return ErrNilState
  201. }
  202. _, err := s.Channel(channel.ID)
  203. if err != nil {
  204. return err
  205. }
  206. if channel.IsPrivate {
  207. s.Lock()
  208. defer s.Unlock()
  209. for i, c := range s.PrivateChannels {
  210. if c.ID == channel.ID {
  211. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  212. return nil
  213. }
  214. }
  215. } else {
  216. guild, err := s.Guild(channel.GuildID)
  217. if err != nil {
  218. return err
  219. }
  220. s.Lock()
  221. defer s.Unlock()
  222. for i, c := range guild.Channels {
  223. if c.ID == channel.ID {
  224. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  225. return nil
  226. }
  227. }
  228. }
  229. delete(s.channelMap, channel.ID)
  230. return nil
  231. }
  232. // GuildChannel gets a channel by ID from a guild.
  233. // This method is Deprecated, use Channel(channelID)
  234. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  235. return s.Channel(channelID)
  236. }
  237. // PrivateChannel gets a private channel by ID.
  238. // This method is Deprecated, use Channel(channelID)
  239. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  240. return s.Channel(channelID)
  241. }
  242. // Channel gets a channel by ID, it will look in all guilds an private channels.
  243. func (s *State) Channel(channelID string) (*Channel, error) {
  244. if s == nil {
  245. return nil, ErrNilState
  246. }
  247. if c, ok := s.channelMap[channelID]; ok {
  248. return c, nil
  249. }
  250. return nil, errors.New("Channel not found.")
  251. }
  252. // Emoji returns an emoji for a guild and emoji id.
  253. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  254. if s == nil {
  255. return nil, ErrNilState
  256. }
  257. guild, err := s.Guild(guildID)
  258. if err != nil {
  259. return nil, err
  260. }
  261. s.RLock()
  262. defer s.RUnlock()
  263. for _, e := range guild.Emojis {
  264. if e.ID == emojiID {
  265. return e, nil
  266. }
  267. }
  268. return nil, errors.New("Emoji not found.")
  269. }
  270. // EmojiAdd adds an emoji to the current world state.
  271. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  272. if s == nil {
  273. return ErrNilState
  274. }
  275. guild, err := s.Guild(guildID)
  276. if err != nil {
  277. return err
  278. }
  279. s.Lock()
  280. defer s.Unlock()
  281. for i, e := range guild.Emojis {
  282. if e.ID == emoji.ID {
  283. guild.Emojis[i] = emoji
  284. return nil
  285. }
  286. }
  287. guild.Emojis = append(guild.Emojis, emoji)
  288. return nil
  289. }
  290. // EmojisAdd adds multiple emojis to the world state.
  291. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  292. for _, e := range emojis {
  293. if err := s.EmojiAdd(guildID, e); err != nil {
  294. return err
  295. }
  296. }
  297. return nil
  298. }
  299. // MessageAdd adds a message to the current world state, or updates it if it exists.
  300. // If the channel cannot be found, the message is discarded.
  301. // Messages are kept in state up to s.MaxMessageCount
  302. func (s *State) MessageAdd(message *Message) error {
  303. if s == nil {
  304. return ErrNilState
  305. }
  306. if s.MaxMessageCount == 0 {
  307. return nil
  308. }
  309. c, err := s.Channel(message.ChannelID)
  310. if err != nil {
  311. return err
  312. }
  313. s.Lock()
  314. defer s.Unlock()
  315. // If the message exists, replace it.
  316. for i, m := range c.Messages {
  317. if m.ID == message.ID {
  318. c.Messages[i] = message
  319. return nil
  320. }
  321. }
  322. c.Messages = append(c.Messages, message)
  323. if len(c.Messages) > s.MaxMessageCount {
  324. c.Messages = c.Messages[len(c.Messages)-s.MaxMessageCount:]
  325. }
  326. return nil
  327. }
  328. // MessageRemove removes a message from the world state.
  329. func (s *State) MessageRemove(message *Message) error {
  330. if s == nil {
  331. return ErrNilState
  332. }
  333. if s.MaxMessageCount == 0 {
  334. return nil
  335. }
  336. c, err := s.Channel(message.ChannelID)
  337. if err != nil {
  338. return err
  339. }
  340. s.Lock()
  341. defer s.Unlock()
  342. for i, m := range c.Messages {
  343. if m.ID == message.ID {
  344. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  345. return nil
  346. }
  347. }
  348. return errors.New("Message not found.")
  349. }
  350. func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error {
  351. guild, err := s.Guild(update.GuildID)
  352. if err != nil {
  353. return err
  354. }
  355. s.Lock()
  356. defer s.Unlock()
  357. // Handle Leaving Channel
  358. if update.ChannelID == "" {
  359. for i, state := range guild.VoiceStates {
  360. if state.UserID == update.UserID {
  361. guild.VoiceStates = append(guild.VoiceStates[:i], guild.VoiceStates[i+1:]...)
  362. return nil
  363. }
  364. }
  365. } else {
  366. for i, state := range guild.VoiceStates {
  367. if state.UserID == update.UserID {
  368. guild.VoiceStates[i] = update.VoiceState
  369. return nil
  370. }
  371. }
  372. guild.VoiceStates = append(guild.VoiceStates, update.VoiceState)
  373. }
  374. return nil
  375. }
  376. // Message gets a message by channel and message ID.
  377. func (s *State) Message(channelID, messageID string) (*Message, error) {
  378. if s == nil {
  379. return nil, ErrNilState
  380. }
  381. c, err := s.Channel(channelID)
  382. if err != nil {
  383. return nil, err
  384. }
  385. s.RLock()
  386. defer s.RUnlock()
  387. for _, m := range c.Messages {
  388. if m.ID == messageID {
  389. return m, nil
  390. }
  391. }
  392. return nil, errors.New("Message not found.")
  393. }
  394. // onInterface handles all events related to states.
  395. func (s *State) onInterface(se *Session, i interface{}) (err error) {
  396. if s == nil {
  397. return ErrNilState
  398. }
  399. if !se.StateEnabled {
  400. return nil
  401. }
  402. switch t := i.(type) {
  403. case *Ready:
  404. err = s.OnReady(t)
  405. case *GuildCreate:
  406. err = s.GuildAdd(t.Guild)
  407. case *GuildUpdate:
  408. err = s.GuildAdd(t.Guild)
  409. case *GuildDelete:
  410. err = s.GuildRemove(t.Guild)
  411. case *GuildMemberAdd:
  412. err = s.MemberAdd(t.Member)
  413. case *GuildMemberUpdate:
  414. err = s.MemberAdd(t.Member)
  415. case *GuildMemberRemove:
  416. err = s.MemberRemove(t.Member)
  417. case *GuildEmojisUpdate:
  418. err = s.EmojisAdd(t.GuildID, t.Emojis)
  419. case *ChannelCreate:
  420. err = s.ChannelAdd(t.Channel)
  421. case *ChannelUpdate:
  422. err = s.ChannelAdd(t.Channel)
  423. case *ChannelDelete:
  424. err = s.ChannelRemove(t.Channel)
  425. case *MessageCreate:
  426. err = s.MessageAdd(t.Message)
  427. case *MessageUpdate:
  428. err = s.MessageAdd(t.Message)
  429. case *MessageDelete:
  430. err = s.MessageRemove(t.Message)
  431. case *VoiceStateUpdate:
  432. err = s.voiceStateUpdate(t)
  433. }
  434. return
  435. }