Просмотр исходного кода

Prevent concurrent writes to gateway websocket.

Bruce Marriner 8 лет назад
Родитель
Сommit
e9e9ef86b3
2 измененных файлов с 14 добавлено и 0 удалено
  1. 3 0
      structs.go
  2. 11 0
      wsapi.go

+ 3 - 0
structs.go

@@ -89,6 +89,9 @@ type Session struct {
 
 	// stores session ID of current Gateway connection
 	sessionID string
+
+	// used to make sure gateway websocket writes do not happen concurrently
+	wsMutex sync.Mutex
 }
 
 type rateLimitMutex struct {

+ 11 - 0
wsapi.go

@@ -237,7 +237,10 @@ func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{}
 	var err error
 	ticker := time.NewTicker(i * time.Millisecond)
 	for {
+
+		s.wsMutex.Lock()
 		err = wsConn.WriteJSON(heartbeatOp{1, s.sequence})
+		s.wsMutex.Unlock()
 		if err != nil {
 			log.Println("Error sending heartbeat:", err)
 			return
@@ -284,7 +287,9 @@ func (s *Session) UpdateStatus(idle int, game string) (err error) {
 		usd.Game = &updateStatusGame{game}
 	}
 
+	s.wsMutex.Lock()
 	err = s.wsConn.WriteJSON(updateStatusOp{3, usd})
+	s.wsMutex.Unlock()
 
 	return
 }
@@ -340,7 +345,9 @@ func (s *Session) onEvent(messageType int, message []byte) {
 	// Must respond with a heartbeat packet within 5 seconds
 	if e.Operation == 1 {
 		s.log(LogInformational, "sending heartbeat in response to Op1")
+		s.wsMutex.Lock()
 		err = s.wsConn.WriteJSON(heartbeatOp{1, s.sequence})
+		s.wsMutex.Unlock()
 		if err != nil {
 			s.log(LogError, "error sending heartbeat in response to Op1")
 			return
@@ -358,7 +365,9 @@ func (s *Session) onEvent(messageType int, message []byte) {
 	if e.Operation == 9 {
 
 		s.log(LogInformational, "sending identify packet to gateway in response to Op9")
+		s.wsMutex.Lock()
 		err = s.wsConn.WriteJSON(handshakeOp{2, handshakeData{s.Token, handshakeProperties{runtime.GOOS, "Discordgo v" + VERSION, "", "", ""}, 250, s.Compress}})
+		s.wsMutex.Unlock()
 		if err != nil {
 			s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err)
 			return
@@ -468,7 +477,9 @@ func (s *Session) ChannelVoiceJoin(gID, cID string, mute, deaf bool) (voice *Voi
 
 	// Send the request to Discord that we want to join the voice channel
 	data := voiceChannelJoinOp{4, voiceChannelJoinData{&gID, &cID, mute, deaf}}
+	s.wsMutex.Lock()
 	err = s.wsConn.WriteJSON(data)
+	s.wsMutex.Unlock()
 	if err != nil {
 		s.log(LogInformational, "Deleting VoiceConnection %s", gID)
 		delete(s.VoiceConnections, gID)