Browse Source

Open() func now validates connection, fixed #198

Now the open function will follow through a bit more and insure that the
proper sequence of events happens during the Open call.  This required
some refactoring and a few mild changes in the onEvent func.
Bruce Marriner 7 years ago
parent
commit
7d1657e59b
1 changed files with 124 additions and 68 deletions
  1. 124 68
      wsapi.go

+ 124 - 68
wsapi.go

@@ -15,6 +15,7 @@ import (
 	"compress/zlib"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
 	"runtime"
@@ -45,65 +46,93 @@ type resumePacket struct {
 	} `json:"d"`
 }
 
-// Open opens a websocket connection to Discord.
-func (s *Session) Open() (err error) {
-
+// Open creates a websocket connection to Discord.
+// See: https://discordapp.com/developers/docs/topics/gateway#connecting
+func (s *Session) Open() error {
 	s.log(LogInformational, "called")
 
-	s.Lock()
-	defer func() {
-		if err != nil {
-			s.Unlock()
-		}
-	}()
+	var err error
 
-	// A basic state is a hard requirement for Voice.
-	if s.State == nil {
-		state := NewState()
-		state.TrackChannels = false
-		state.TrackEmojis = false
-		state.TrackMembers = false
-		state.TrackRoles = false
-		state.TrackVoice = false
-		s.State = state
-	}
+	// Prevent Open or other major Session functions from
+	// being called while Open is still running.
+	s.Lock()
+	defer s.Unlock()
 
+	// If the websock is already open, bail out here.
 	if s.wsConn != nil {
-		err = ErrWSAlreadyOpen
-		return
-	}
-
-	if s.VoiceConnections == nil {
-		s.log(LogInformational, "creating new VoiceConnections map")
-		s.VoiceConnections = make(map[string]*VoiceConnection)
+		return ErrWSAlreadyOpen
 	}
 
 	// Get the gateway to use for the Websocket connection
 	if s.gateway == "" {
 		s.gateway, err = s.Gateway()
 		if err != nil {
-			return
+			return err
 		}
 
 		// Add the version and encoding to the URL
 		s.gateway = s.gateway + "?v=" + APIVersion + "&encoding=json"
 	}
 
+	// Connect to the Gateway
+	s.log(LogInformational, "connecting to gateway %s", s.gateway)
 	header := http.Header{}
 	header.Add("accept-encoding", "zlib")
-
-	s.log(LogInformational, "connecting to gateway %s", s.gateway)
 	s.wsConn, _, err = websocket.DefaultDialer.Dial(s.gateway, header)
 	if err != nil {
 		s.log(LogWarning, "error connecting to gateway %s, %s", s.gateway, err)
 		s.gateway = "" // clear cached gateway
-		// TODO: should we add a retry block here?
-		return
+		s.wsConn = nil // Just to be safe.
+		return err
 	}
 
+	defer func() {
+		// because of this, all code below must set err to the error
+		// when exiting with an error :)  Maybe someone has a better
+		// way :)
+		if err != nil {
+			s.wsConn.Close()
+			s.wsConn = nil
+		}
+	}()
+
+	// The first response from Discord should be an Op 10 (Hello) Packet.
+	// When processed by onEvent the heartbeat goroutine will be started.
+	mt, m, err := s.wsConn.ReadMessage()
+	if err != nil {
+		return err
+	}
+	e, err := s.onEvent(mt, m)
+	if err != nil {
+		return err
+	}
+	if e.Operation != 10 {
+		err = fmt.Errorf("Expecting Op 10, got Op %d instead.", e.Operation)
+		return err
+	}
+	s.log(LogInformational, "Op 10 Hello Packet received from Discord")
+	s.LastHeartbeatAck = time.Now().UTC()
+	var h helloOp
+	if err = json.Unmarshal(e.RawData, &h); err != nil {
+		err = fmt.Errorf("error unmarshalling helloOp, %s", err)
+		return err
+	}
+
+	// Now we send either an Op 2 Identity if this is a brand new
+	// connection or Op 6 Resume if we are resuming an existing connection.
 	sequence := atomic.LoadInt64(s.sequence)
-	if s.sessionID != "" && sequence > 0 {
+	if s.sessionID == "" && sequence == 0 {
 
+		// Send Op 2 Identity Packet
+		err = s.identify()
+		if err != nil {
+			err = fmt.Errorf("error sending identify packet to gateway, %s, %s", s.gateway, err)
+			return err
+		}
+
+	} else {
+
+		// Send Op 6 Resume Packet
 		p := resumePacket{}
 		p.Op = 6
 		p.Data.Token = s.Token
@@ -111,34 +140,66 @@ func (s *Session) Open() (err error) {
 		p.Data.Sequence = sequence
 
 		s.log(LogInformational, "sending resume packet to gateway")
+		s.wsMutex.Lock()
 		err = s.wsConn.WriteJSON(p)
+		s.wsMutex.Unlock()
 		if err != nil {
-			s.log(LogWarning, "error sending gateway resume packet, %s, %s", s.gateway, err)
-			return
+			err = fmt.Errorf("error sending gateway resume packet, %s, %s", s.gateway, err)
+			return err
 		}
 
-	} else {
-
-		err = s.identify()
-		if err != nil {
-			s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
-			return
-		}
 	}
 
-	// Create listening outside of listen, as it needs to happen inside the mutex
-	// lock.
-	s.listening = make(chan interface{})
-	go s.listen(s.wsConn, s.listening)
-	s.LastHeartbeatAck = time.Now().UTC()
+	// A basic state is a hard requirement for Voice.
+	// We create it here so the below READY/RESUMED packet can populate
+	// the state :)
+	// XXX: Move to New() func?
+	if s.State == nil {
+		state := NewState()
+		state.TrackChannels = false
+		state.TrackEmojis = false
+		state.TrackMembers = false
+		state.TrackRoles = false
+		state.TrackVoice = false
+		s.State = state
+	}
 
-	s.Unlock()
+	// Now Discord should send us a READY or RESUMED packet.
+	mt, m, err = s.wsConn.ReadMessage()
+	if err != nil {
+		return err
+	}
+	e, err = s.onEvent(mt, m)
+	if err != nil {
+		return err
+	}
+	if e.Type != `READY` && e.Type != `RESUMED` {
+		// This is not fatal, but it does not follow their API documentation.
+		s.log(LogWarning, "Expected READY/RESUMED, instead got:\n%#v\n", e)
+	}
+	s.log(LogInformational, "First Packet:\n%#v\n", e)
 
-	s.log(LogInformational, "emit connect event")
+	s.log(LogInformational, "We are now connected to Discord, emitting connect event")
 	s.handleEvent(connectEventType, &Connect{})
 
+	// A VoiceConnections map is a hard requirement for Voice.
+	// XXX: can this be moved to when opening a voice connection?
+	if s.VoiceConnections == nil {
+		s.log(LogInformational, "creating new VoiceConnections map")
+		s.VoiceConnections = make(map[string]*VoiceConnection)
+	}
+
+	// Create listening chan outside of listen, as it needs to happen inside the
+	// mutex lock and needs to exist before calling heartbeat and listen
+	// go rountines.
+	s.listening = make(chan interface{})
+
+	// Start sending heartbeats and reading messages from Discord.
+	go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval)
+	go s.listen(s.wsConn, s.listening)
+
 	s.log(LogInformational, "exiting")
-	return
+	return nil
 }
 
 // listen polls the websocket connection for events, it will stop when the
@@ -364,9 +425,7 @@ func (s *Session) RequestGuildMembers(guildID, query string, limit int) (err err
 //
 // If you use the AddHandler() function to register a handler for the
 // "OnEvent" event then all events will be passed to that handler.
-//
-// TODO: You may also register a custom event handler entirely using...
-func (s *Session) onEvent(messageType int, message []byte) {
+func (s *Session) onEvent(messageType int, message []byte) (*Event, error) {
 
 	var err error
 	var reader io.Reader
@@ -378,7 +437,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		z, err2 := zlib.NewReader(reader)
 		if err2 != nil {
 			s.log(LogError, "error uncompressing websocket message, %s", err)
-			return
+			return nil, err2
 		}
 
 		defer func() {
@@ -396,7 +455,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
 	decoder := json.NewDecoder(reader)
 	if err = decoder.Decode(&e); err != nil {
 		s.log(LogError, "error decoding websocket message, %s", err)
-		return
+		return e, err
 	}
 
 	s.log(LogDebug, "Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Operation, e.Sequence, e.Type, string(e.RawData))
@@ -410,10 +469,10 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		s.wsMutex.Unlock()
 		if err != nil {
 			s.log(LogError, "error sending heartbeat in response to Op1")
-			return
+			return e, err
 		}
 
-		return
+		return e, nil
 	}
 
 	// Reconnect
@@ -422,7 +481,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		s.log(LogInformational, "Closing and reconnecting in response to Op7")
 		s.Close()
 		s.reconnect()
-		return
+		return e, nil
 	}
 
 	// Invalid Session
@@ -434,20 +493,15 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		err = s.identify()
 		if err != nil {
 			s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
-			return
+			return e, err
 		}
 
-		return
+		return e, nil
 	}
 
 	if e.Operation == 10 {
-		var h helloOp
-		if err = json.Unmarshal(e.RawData, &h); err != nil {
-			s.log(LogError, "error unmarshalling helloOp, %s", err)
-		} else {
-			go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval)
-		}
-		return
+		// Op10 is handled by Open()
+		return e, nil
 	}
 
 	if e.Operation == 11 {
@@ -455,7 +509,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		s.LastHeartbeatAck = time.Now().UTC()
 		s.Unlock()
 		s.log(LogInformational, "got heartbeat ACK")
-		return
+		return e, nil
 	}
 
 	// Do not try to Dispatch a non-Dispatch Message
@@ -463,7 +517,7 @@ func (s *Session) onEvent(messageType int, message []byte) {
 		// But we probably should be doing something with them.
 		// TEMP
 		s.log(LogWarning, "unknown Op: %d, Seq: %d, Type: %s, Data: %s, message: %s", e.Operation, e.Sequence, e.Type, string(e.RawData), string(message))
-		return
+		return e, nil
 	}
 
 	// Store the message sequence
@@ -492,6 +546,8 @@ func (s *Session) onEvent(messageType int, message []byte) {
 
 	// For legacy reasons, we send the raw event also, this could be useful for handling unknown events.
 	s.handleEvent(eventEventType, e)
+
+	return e, nil
 }
 
 // ------------------------------------------------------------------------------------------------