state.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. // Discordgo - Discord bindings for Go
  2. // Available at https://github.com/bwmarrin/discordgo
  3. // Copyright 2015-2016 Bruce Marriner <bruce@sqls.net>. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // This file contains code related to state tracking. If enabled, state
  7. // tracking will capture the initial READY packet and many other websocket
  8. // events and maintain an in-memory state of guilds, channels, users, and
  9. // so forth. This information can be accessed through the Session.State struct.
  10. package discordgo
  11. import "errors"
  // nilError is returned by State methods that are invoked on a nil *State
  // receiver, i.e. when state tracking was never set up.
  var nilError = errors.New("State not instantiated, please use discordgo.New() or assign Session.State.")
  13. // NewState creates an empty state.
  14. func NewState() *State {
  15. return &State{
  16. Ready: Ready{
  17. PrivateChannels: []*Channel{},
  18. Guilds: []*Guild{},
  19. },
  20. }
  21. }
  22. // OnReady takes a Ready event and updates all internal state.
  23. func (s *State) OnReady(r *Ready) error {
  24. if s == nil {
  25. return nilError
  26. }
  27. s.Lock()
  28. defer s.Unlock()
  29. s.Ready = *r
  30. return nil
  31. }
  32. // GuildAdd adds a guild to the current world state, or
  33. // updates it if it already exists.
  34. func (s *State) GuildAdd(guild *Guild) error {
  35. if s == nil {
  36. return nilError
  37. }
  38. s.Lock()
  39. defer s.Unlock()
  40. // If the guild exists, replace it.
  41. for i, g := range s.Guilds {
  42. if g.ID == guild.ID {
  43. s.Guilds[i] = guild
  44. return nil
  45. }
  46. }
  47. s.Guilds = append(s.Guilds, guild)
  48. return nil
  49. }
  50. // GuildRemove removes a guild from current world state.
  51. func (s *State) GuildRemove(guild *Guild) error {
  52. if s == nil {
  53. return nilError
  54. }
  55. s.Lock()
  56. defer s.Unlock()
  57. for i, g := range s.Guilds {
  58. if g.ID == guild.ID {
  59. s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
  60. return nil
  61. }
  62. }
  63. return errors.New("Guild not found.")
  64. }
  65. // Guild gets a guild by ID.
  66. // Useful for querying if @me is in a guild:
  67. // _, err := discordgo.Session.State.Guild(guildID)
  68. // isInGuild := err == nil
  69. func (s *State) Guild(guildID string) (*Guild, error) {
  70. if s == nil {
  71. return nil, nilError
  72. }
  73. s.RLock()
  74. defer s.RUnlock()
  75. for _, g := range s.Guilds {
  76. if g.ID == guildID {
  77. return g, nil
  78. }
  79. }
  80. return nil, errors.New("Guild not found.")
  81. }
  82. // TODO: Consider moving Guild state update methods onto *Guild.
  83. // MemberAdd adds a member to the current world state, or
  84. // updates it if it already exists.
  85. func (s *State) MemberAdd(member *Member) error {
  86. if s == nil {
  87. return nilError
  88. }
  89. s.Lock()
  90. defer s.Unlock()
  91. guild, err := s.Guild(member.GuildID)
  92. if err != nil {
  93. return err
  94. }
  95. for i, m := range guild.Members {
  96. if m.User.ID == member.User.ID {
  97. guild.Members[i] = member
  98. return nil
  99. }
  100. }
  101. guild.Members = append(guild.Members, member)
  102. return nil
  103. }
  104. // MemberRemove removes a member from current world state.
  105. func (s *State) MemberRemove(member *Member) error {
  106. if s == nil {
  107. return nilError
  108. }
  109. s.Lock()
  110. defer s.Unlock()
  111. guild, err := s.Guild(member.GuildID)
  112. if err != nil {
  113. return err
  114. }
  115. for i, m := range guild.Members {
  116. if m.User.ID == member.User.ID {
  117. guild.Members = append(guild.Members[:i], guild.Members[i+1:]...)
  118. return nil
  119. }
  120. }
  121. return errors.New("Member not found.")
  122. }
  123. // Member gets a member by ID from a guild.
  124. func (s *State) Member(guildID, userID string) (*Member, error) {
  125. if s == nil {
  126. return nil, nilError
  127. }
  128. s.RLock()
  129. defer s.RUnlock()
  130. guild, err := s.Guild(guildID)
  131. if err != nil {
  132. return nil, err
  133. }
  134. for _, m := range guild.Members {
  135. if m.User.ID == userID {
  136. return m, nil
  137. }
  138. }
  139. return nil, errors.New("Member not found.")
  140. }
  141. // ChannelAdd adds a guild to the current world state, or
  142. // updates it if it already exists.
  143. // Channels may exist either as PrivateChannels or inside
  144. // a guild.
  145. func (s *State) ChannelAdd(channel *Channel) error {
  146. if s == nil {
  147. return nilError
  148. }
  149. s.Lock()
  150. defer s.Unlock()
  151. if channel.IsPrivate {
  152. // If the channel exists, replace it.
  153. for i, c := range s.PrivateChannels {
  154. if c.ID == channel.ID {
  155. s.PrivateChannels[i] = channel
  156. return nil
  157. }
  158. }
  159. s.PrivateChannels = append(s.PrivateChannels, channel)
  160. } else {
  161. guild, err := s.Guild(channel.GuildID)
  162. if err != nil {
  163. return err
  164. }
  165. // If the channel exists, replace it.
  166. for i, c := range guild.Channels {
  167. if c.ID == channel.ID {
  168. guild.Channels[i] = channel
  169. return nil
  170. }
  171. }
  172. guild.Channels = append(guild.Channels, channel)
  173. }
  174. return nil
  175. }
  176. // ChannelRemove removes a channel from current world state.
  177. func (s *State) ChannelRemove(channel *Channel) error {
  178. if s == nil {
  179. return nilError
  180. }
  181. s.Lock()
  182. defer s.Unlock()
  183. if channel.IsPrivate {
  184. for i, c := range s.PrivateChannels {
  185. if c.ID == channel.ID {
  186. s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)
  187. return nil
  188. }
  189. }
  190. } else {
  191. guild, err := s.Guild(channel.GuildID)
  192. if err != nil {
  193. return err
  194. }
  195. for i, c := range guild.Channels {
  196. if c.ID == channel.ID {
  197. guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...)
  198. return nil
  199. }
  200. }
  201. }
  202. return errors.New("Channel not found.")
  203. }
  204. // GuildChannel gets a channel by ID from a guild.
  205. func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
  206. if s == nil {
  207. return nil, nilError
  208. }
  209. s.RLock()
  210. defer s.RUnlock()
  211. guild, err := s.Guild(guildID)
  212. if err != nil {
  213. return nil, err
  214. }
  215. for _, c := range guild.Channels {
  216. if c.ID == channelID {
  217. return c, nil
  218. }
  219. }
  220. return nil, errors.New("Channel not found.")
  221. }
  222. // PrivateChannel gets a private channel by ID.
  223. func (s *State) PrivateChannel(channelID string) (*Channel, error) {
  224. if s == nil {
  225. return nil, nilError
  226. }
  227. s.RLock()
  228. defer s.RUnlock()
  229. for _, c := range s.PrivateChannels {
  230. if c.ID == channelID {
  231. return c, nil
  232. }
  233. }
  234. return nil, errors.New("Channel not found.")
  235. }
  236. // Channel gets a channel by ID, it will look in all guilds an private channels.
  237. func (s *State) Channel(channelID string) (*Channel, error) {
  238. if s == nil {
  239. return nil, nilError
  240. }
  241. c, err := s.PrivateChannel(channelID)
  242. if err == nil {
  243. return c, nil
  244. }
  245. for _, g := range s.Guilds {
  246. c, err := s.GuildChannel(g.ID, channelID)
  247. if err == nil {
  248. return c, nil
  249. }
  250. }
  251. return nil, errors.New("Channel not found.")
  252. }
  253. // Emoji returns an emoji for a guild and emoji id.
  254. func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) {
  255. if s == nil {
  256. return nil, nilError
  257. }
  258. s.RLock()
  259. defer s.RUnlock()
  260. guild, err := s.Guild(guildID)
  261. if err != nil {
  262. return nil, err
  263. }
  264. for _, e := range guild.Emojis {
  265. if e.ID == emojiID {
  266. return e, nil
  267. }
  268. }
  269. return nil, errors.New("Emoji not found.")
  270. }
  271. // EmojiAdd adds an emoji to the current world state.
  272. func (s *State) EmojiAdd(guildID string, emoji *Emoji) error {
  273. if s == nil {
  274. return nilError
  275. }
  276. s.Lock()
  277. defer s.Unlock()
  278. guild, err := s.Guild(guildID)
  279. if err != nil {
  280. return err
  281. }
  282. for i, e := range guild.Emojis {
  283. if e.ID == emoji.ID {
  284. guild.Emojis[i] = emoji
  285. return nil
  286. }
  287. }
  288. guild.Emojis = append(guild.Emojis, emoji)
  289. return nil
  290. }
  291. // EmojisAdd adds multiple emojis to the world state.
  292. func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
  293. for _, e := range emojis {
  294. if err := s.EmojiAdd(guildID, e); err != nil {
  295. return err
  296. }
  297. }
  298. return nil
  299. }
  300. // MessageAdd adds a message to the current world state, or updates it if it exists.
  301. // If the channel cannot be found, the message is discarded.
  302. // Messages are kept in state up to s.MaxMessageCount
  303. func (s *State) MessageAdd(message *Message) error {
  304. if s == nil {
  305. return nilError
  306. }
  307. c, err := s.Channel(message.ChannelID)
  308. if err != nil {
  309. return err
  310. }
  311. s.Lock()
  312. defer s.Unlock()
  313. // If the message exists, replace it.
  314. for i, m := range c.Messages {
  315. if m.ID == message.ID {
  316. c.Messages[i] = message
  317. return nil
  318. }
  319. }
  320. c.Messages = append(c.Messages, message)
  321. if len(c.Messages) > s.MaxMessageCount {
  322. s.Unlock()
  323. for len(c.Messages) > s.MaxMessageCount {
  324. s.MessageRemove(c.Messages[0])
  325. }
  326. s.Lock()
  327. }
  328. return nil
  329. }
  330. // MessageRemove removes a message from the world state.
  331. func (s *State) MessageRemove(message *Message) error {
  332. if s == nil {
  333. return nilError
  334. }
  335. c, err := s.Channel(message.ChannelID)
  336. if err != nil {
  337. return err
  338. }
  339. s.Lock()
  340. defer s.Unlock()
  341. for i, m := range c.Messages {
  342. if m.ID == message.ID {
  343. c.Messages = append(c.Messages[:i], c.Messages[i+1:]...)
  344. return nil
  345. }
  346. }
  347. return errors.New("Message not found.")
  348. }
  349. // Message gets a message by channel and message ID.
  350. func (s *State) Message(channelID, messageID string) (*Message, error) {
  351. if s == nil {
  352. return nil, nilError
  353. }
  354. c, err := s.Channel(channelID)
  355. if err != nil {
  356. return nil, err
  357. }
  358. s.RLock()
  359. defer s.RUnlock()
  360. for _, m := range c.Messages {
  361. if m.ID == messageID {
  362. return m, nil
  363. }
  364. }
  365. return nil, errors.New("Message not found.")
  366. }