Browse Source

Only Shard when ShardCount > 1

Also cleaned up identify sending so there's now a function that handles
it instead of duplicate code.  Renamed handshake* structs to identify*
structs to make naming match up.
Bruce Marriner 8 years ago
parent
commit
9dc51d1c49
3 changed files with 65 additions and 63 deletions
  1. 1 1
      discord.go
  2. 2 2
      structs.go
  3. 62 60
      wsapi.go

+ 1 - 1
discord.go

@@ -41,7 +41,7 @@ func New(args ...interface{}) (s *Session, err error) {
 		Compress:               true,
 		ShouldReconnectOnError: true,
 		ShardID:                0,
-		NumShards:              1,
+		ShardCount:             1,
 	}
 
 	// If no arguments are passed return the empty Session interface.

+ 2 - 2
structs.go

@@ -40,8 +40,8 @@ type Session struct {
 	Compress bool
 
 	// Sharding
-	ShardID   int
-	NumShards int
+	ShardID    int
+	ShardCount int
 
 	// Should state tracking be enabled.
 	// State tracking is the best way for getting the the users

+ 62 - 60
wsapi.go

@@ -26,27 +26,6 @@ import (
 	"github.com/gorilla/websocket"
 )
 
-type handshakeProperties struct {
-	OS              string `json:"$os"`
-	Browser         string `json:"$browser"`
-	Device          string `json:"$device"`
-	Referer         string `json:"$referer"`
-	ReferringDomain string `json:"$referring_domain"`
-}
-
-type handshakeData struct {
-	Token          string              `json:"token"`
-	Properties     handshakeProperties `json:"properties"`
-	LargeThreshold int                 `json:"large_threshold"`
-	Compress       bool                `json:"compress"`
-	Shard          [2]int              `json:"shard"`
-}
-
-type handshakeOp struct {
-	Op   int           `json:"op"`
-	Data handshakeData `json:"d"`
-}
-
 type resumePacket struct {
 	Op   int `json:"op"`
 	Data struct {
@@ -73,16 +52,6 @@ func (s *Session) Open() (err error) {
 		return
 	}
 
-	if s.NumShards <= 0 {
-		err = errors.New("NumShards must be greater or equal to 1")
-		return
-	}
-
-	if s.ShardID >= s.NumShards {
-		err = errors.New("ShardID must be less than NumShards")
-		return
-	}
-
 	if s.VoiceConnections == nil {
 		s.log(LogInformational, "creating new VoiceConnections map")
 		s.VoiceConnections = make(map[string]*VoiceConnection)
@@ -128,24 +97,7 @@ func (s *Session) Open() (err error) {
 
 	} else {
 
-		data := handshakeOp{
-			2,
-			handshakeData{
-				s.Token,
-				handshakeProperties{
-					runtime.GOOS,
-					"Discordgo v" + VERSION,
-					"",
-					"",
-					"",
-				},
-				250,
-				s.Compress,
-				[2]int{s.ShardID, s.NumShards},
-			},
-		}
-		s.log(LogInformational, "sending identify packet to gateway")
-		err = s.wsConn.WriteJSON(data)
+		err = s.identify()
 		if err != nil {
 			s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
 			return
@@ -384,17 +336,8 @@ func (s *Session) onEvent(messageType int, message []byte) {
 	if e.Operation == 9 {
 
 		s.log(LogInformational, "sending identify packet to gateway in response to Op9")
-		s.wsMutex.Lock()
-		err = s.wsConn.WriteJSON(handshakeOp{
-			2,
-			handshakeData{
-				s.Token,
-				handshakeProperties{runtime.GOOS, "Discordgo v" + VERSION, "", "", ""},
-				250,
-				s.Compress,
-				[2]int{s.ShardID, s.NumShards},
-			}})
-		s.wsMutex.Unlock()
+
+		err = s.identify()
 		if err != nil {
 			s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
 			return
@@ -592,6 +535,65 @@ func (s *Session) onVoiceServerUpdate(se *Session, st *VoiceServerUpdate) {
 	}
 }
 
+type identifyProperties struct {
+	OS              string `json:"$os"`
+	Browser         string `json:"$browser"`
+	Device          string `json:"$device"`
+	Referer         string `json:"$referer"`
+	ReferringDomain string `json:"$referring_domain"`
+}
+
+type identifyData struct {
+	Token          string             `json:"token"`
+	Properties     identifyProperties `json:"properties"`
+	LargeThreshold int                `json:"large_threshold"`
+	Compress       bool               `json:"compress"`
+	Shard          *[2]int            `json:"shard,omitempty"`
+}
+
+type identifyOp struct {
+	Op   int          `json:"op"`
+	Data identifyData `json:"d"`
+}
+
+// identify sends the identify packet to the gateway
+func (s *Session) identify() error {
+
+	properties := identifyProperties{runtime.GOOS,
+		"Discordgo v" + VERSION,
+		"",
+		"",
+		"",
+	}
+
+	data := identifyData{s.Token,
+		properties,
+		250,
+		s.Compress,
+		nil,
+	}
+
+	if s.ShardCount > 1 {
+
+		if s.ShardID >= s.ShardCount {
+			return errors.New("ShardID must be less than ShardCount")
+		}
+
+		data.Shard = &[2]int{s.ShardID, s.ShardCount}
+	}
+
+	op := identifyOp{2, data}
+
+	s.wsMutex.Lock()
+	err := s.wsConn.WriteJSON(op)
+	s.wsMutex.Unlock()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func (s *Session) reconnect() {
 
 	s.log(LogInformational, "called")