Browse Source

Make state tracking optional.

Chris Rhodes 9 years ago
parent
commit
9ba6d5b7c1
4 changed files with 104 additions and 30 deletions
  1. 2 1
      discord.go
  2. 70 18
      state.go
  3. 2 1
      structs.go
  4. 30 10
      wsapi.go

+ 2 - 1
discord.go

@@ -62,7 +62,8 @@ func New(args ...interface{}) (s *Session, err error) {
 
 	// Create an empty Session interface.
 	s = &Session{
-		State: NewState(),
+		State:        NewState(),
+		StateEnabled: true,
 	}
 
 	// If no arguments are passed return the empty Session interface.

+ 70 - 18
state.go

@@ -2,6 +2,8 @@ package discordgo
 
 import "errors"
 
+var nilError error = errors.New("State not instantiated, please use discordgo.New() or assign session.State.")
+
 // NewState creates an empty state.
 func NewState() *State {
 	return &State{
@@ -13,13 +15,22 @@ func NewState() *State {
 }
 
 // OnReady takes a Ready event and updates all internal state.
-func (s *State) OnReady(r *Ready) {
+func (s *State) OnReady(r *Ready) error {
+	if s == nil {
+		return nilError
+	}
+
 	s.Ready = *r
+	return nil
 }
 
 // AddGuild adds a guild to the current world state, or
 // updates it if it already exists.
-func (s *State) AddGuild(guild *Guild) {
+func (s *State) AddGuild(guild *Guild) error {
+	if s == nil {
+		return nilError
+	}
+
 	for _, g := range s.Guilds {
 		if g.ID == guild.ID {
 			// This could be a little faster ;)
@@ -29,14 +40,19 @@ func (s *State) AddGuild(guild *Guild) {
 			for _, c := range guild.Channels {
 				s.AddChannel(&c)
 			}
-			return
+			return nil
 		}
 	}
 	s.Guilds = append(s.Guilds, *guild)
+	return nil
 }
 
 // RemoveGuild removes a guild from current world state.
 func (s *State) RemoveGuild(guild *Guild) error {
+	if s == nil {
+		return nilError
+	}
+
 	for i, g := range s.Guilds {
 		if g.ID == guild.ID {
 			s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
@@ -46,11 +62,15 @@ func (s *State) RemoveGuild(guild *Guild) error {
 	return errors.New("Guild not found.")
 }
 
-// GetGuildByID gets a guild by ID.
+// Guild gets a guild by ID.
 // Useful for querying if @me is in a guild:
-//     _, err := discordgo.Session.State.GetGuildById(guildID)
+//     _, err := discordgo.Session.State.Guild(guildID)
 //     isInGuild := err == nil
-func (s *State) GetGuildByID(guildID string) (*Guild, error) {
+func (s *State) Guild(guildID string) (*Guild, error) {
+	if s == nil {
+		return nil, nilError
+	}
+
 	for _, g := range s.Guilds {
 		if g.ID == guildID {
 			return &g, nil
@@ -64,7 +84,11 @@ func (s *State) GetGuildByID(guildID string) (*Guild, error) {
 // AddMember adds a member to the current world state, or
 // updates it if it already exists.
 func (s *State) AddMember(member *Member) error {
-	guild, err := s.GetGuildByID(member.GuildID)
+	if s == nil {
+		return nilError
+	}
+
+	guild, err := s.Guild(member.GuildID)
 	if err != nil {
 		return err
 	}
@@ -82,7 +106,11 @@ func (s *State) AddMember(member *Member) error {
 
 // RemoveMember removes a member from current world state.
 func (s *State) RemoveMember(member *Member) error {
-	guild, err := s.GetGuildByID(member.GuildID)
+	if s == nil {
+		return nilError
+	}
+
+	guild, err := s.Guild(member.GuildID)
 	if err != nil {
 		return err
 	}
@@ -93,12 +121,17 @@ func (s *State) RemoveMember(member *Member) error {
 			return nil
 		}
 	}
+
 	return errors.New("Member not found.")
 }
 
-// GetMemberByID gets a member by ID from a guild.
-func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) {
-	guild, err := s.GetGuildByID(guildID)
+// Member gets a member by ID from a guild.
+func (s *State) Member(guildID string, userID string) (*Member, error) {
+	if s == nil {
+		return nil, nilError
+	}
+
+	guild, err := s.Guild(guildID)
 	if err != nil {
 		return nil, err
 	}
@@ -108,6 +141,7 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) {
 			return &m, nil
 		}
 	}
+
 	return nil, errors.New("Member not found.")
 }
 
@@ -116,6 +150,10 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) {
 // Channels may exist either as PrivateChannels or inside
 // a guild.
 func (s *State) AddChannel(channel *Channel) error {
+	if s == nil {
+		return nilError
+	}
+
 	if channel.IsPrivate {
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
@@ -126,7 +164,7 @@ func (s *State) AddChannel(channel *Channel) error {
 
 		s.PrivateChannels = append(s.PrivateChannels, *channel)
 	} else {
-		guild, err := s.GetGuildByID(channel.GuildID)
+		guild, err := s.Guild(channel.GuildID)
 		if err != nil {
 			return err
 		}
@@ -145,6 +183,10 @@ func (s *State) AddChannel(channel *Channel) error {
 
 // RemoveChannel removes a channel from current world state.
 func (s *State) RemoveChannel(channel *Channel) error {
+	if s == nil {
+		return nilError
+	}
+
 	if channel.IsPrivate {
 		for i, c := range s.PrivateChannels {
 			if c.ID == channel.ID {
@@ -153,7 +195,7 @@ func (s *State) RemoveChannel(channel *Channel) error {
 			}
 		}
 	} else {
-		guild, err := s.GetGuildByID(channel.GuildID)
+		guild, err := s.Guild(channel.GuildID)
 		if err != nil {
 			return err
 		}
@@ -169,9 +211,13 @@ func (s *State) RemoveChannel(channel *Channel) error {
 	return errors.New("Channel not found.")
 }
 
-// GetGuildChannelById gets a channel by ID from a guild.
-func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel, error) {
-	guild, err := s.GetGuildByID(guildID)
+// GuildChannel gets a channel by ID from a guild.
+func (s *State) GuildChannel(guildID string, channelID string) (*Channel, error) {
+	if s == nil {
+		return nil, nilError
+	}
+
+	guild, err := s.Guild(guildID)
 	if err != nil {
 		return nil, err
 	}
@@ -181,15 +227,21 @@ func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel,
 			return &c, nil
 		}
 	}
+
 	return nil, errors.New("Channel not found.")
 }
 
-// GetPrivateChannelByID gets a private channel by ID.
-func (s *State) GetPrivateChannelByID(channelID string) (*Channel, error) {
+// PrivateChannel gets a private channel by ID.
+func (s *State) PrivateChannel(channelID string) (*Channel, error) {
+	if s == nil {
+		return nil, nilError
+	}
+
 	for _, c := range s.PrivateChannels {
 		if c.ID == channelID {
 			return &c, nil
 		}
 	}
+
 	return nil, errors.New("Channel not found.")
 }

+ 2 - 1
structs.go

@@ -86,7 +86,8 @@ type Session struct {
 	UDPConn    *net.UDPConn
 
 	// Managed state object, updated with events.
-	State *State
+	State        *State
+	StateEnabled bool
 }
 
 // A Message stores all data related to a specific Discord message.

+ 30 - 10
wsapi.go

@@ -153,7 +153,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "READY":
 		var st Ready
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.OnReady(&st)
+			if s.StateEnabled {
+				s.State.OnReady(&st)
+			}
 			if s.OnReady != nil {
 				s.OnReady(s, st)
 			}
@@ -238,7 +240,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "CHANNEL_CREATE":
 		var st Channel
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddChannel(&st)
+			if s.StateEnabled {
+				s.State.AddChannel(&st)
+			}
 			if s.OnChannelCreate != nil {
 				s.OnChannelCreate(s, st)
 			}
@@ -247,7 +251,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "CHANNEL_UPDATE":
 		var st Channel
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddChannel(&st)
+			if s.StateEnabled {
+				s.State.AddChannel(&st)
+			}
 			if s.OnChannelUpdate != nil {
 				s.OnChannelUpdate(s, st)
 			}
@@ -256,7 +262,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "CHANNEL_DELETE":
 		var st Channel
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.RemoveChannel(&st)
+			if s.StateEnabled {
+				s.State.RemoveChannel(&st)
+			}
 			if s.OnChannelDelete != nil {
 				s.OnChannelDelete(s, st)
 			}
@@ -265,7 +273,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_CREATE":
 		var st Guild
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddGuild(&st)
+			if s.StateEnabled {
+				s.State.AddGuild(&st)
+			}
 			if s.OnGuildCreate != nil {
 				s.OnGuildCreate(s, st)
 			}
@@ -274,7 +284,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_UPDATE":
 		var st Guild
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddGuild(&st)
+			if s.StateEnabled {
+				s.State.AddGuild(&st)
+			}
 			if s.OnGuildCreate != nil {
 				s.OnGuildUpdate(s, st)
 			}
@@ -283,7 +295,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_DELETE":
 		var st Guild
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.RemoveGuild(&st)
+			if s.StateEnabled {
+				s.State.RemoveGuild(&st)
+			}
 			if s.OnGuildDelete != nil {
 				s.OnGuildDelete(s, st)
 			}
@@ -292,7 +306,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_MEMBER_ADD":
 		var st Member
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddMember(&st)
+			if s.StateEnabled {
+				s.State.AddMember(&st)
+			}
 			if s.OnGuildMemberAdd != nil {
 				s.OnGuildMemberAdd(s, st)
 			}
@@ -301,7 +317,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_MEMBER_REMOVE":
 		var st Member
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.RemoveMember(&st)
+			if s.StateEnabled {
+				s.State.RemoveMember(&st)
+			}
 			if s.OnGuildMemberRemove != nil {
 				s.OnGuildMemberRemove(s, st)
 			}
@@ -310,7 +328,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
 	case "GUILD_MEMBER_UPDATE":
 		var st Member
 		if err = unmarshalEvent(e, &st); err == nil {
-			s.State.AddMember(&st)
+			if s.StateEnabled {
+				s.State.AddMember(&st)
+			}
 			if s.OnGuildMemberUpdate != nil {
 				s.OnGuildMemberUpdate(s, st)
 			}