Parcourir la source

Added support for tls cert swapping real time

Thisnthat il y a 1 an
Parent
commit
24d702dcca
1 fichiers modifiés avec 29 ajouts et 8 suppressions
  1. 29 8
      api.go

+ 29 - 8
api.go

@@ -2,6 +2,7 @@ package api
 
 import (
 	"context"
+	"crypto/tls"
 	"net/http"
 	"os"
 	"os/signal"
@@ -22,6 +23,9 @@ type Api struct {
 	ssl        bool
 	cert       string
 	key        string
+
+	certMu  sync.Mutex
+	tlscert *tls.Certificate
 }
 
 type ApiConfig struct {
@@ -43,19 +47,20 @@ func NewApi(config ApiConfig) *Api {
 
 	logrus.Warn(srv.Addr)
 
-	webserver := &Api{
+	apiWebServer := &Api{
 		httpServer: srv,
 		mux:        mux.NewRouter(),
 		online:     false,
 	}
 
 	if config.SSL {
-		webserver.ssl = true
-		webserver.cert = config.Cert
-		webserver.key = config.Key
+		apiWebServer.ssl = true
+		apiWebServer.cert = config.Cert
+		apiWebServer.key = config.Key
+		apiWebServer.tlscert = nil
 	}
 
-	return webserver
+	return apiWebServer
 }
 
 func (api *Api) AddHandler(path string, handler func(http.ResponseWriter, *http.Request), method string) {
@@ -66,6 +71,9 @@ func (api *Api) Start(wg *sync.WaitGroup) error {
 	defer wg.Done()
 
 	api.httpServer.Handler = api.mux
+	api.httpServer.TLSConfig = &tls.Config{
+		GetCertificate: api.getCertificate,
+	}
 
 	go func() {
 		if api.ssl {
@@ -81,9 +89,9 @@ func (api *Api) Start(wg *sync.WaitGroup) error {
 	go logrus.Info("API webserver has started")
 	api.online = true
 	api.running = make(chan os.Signal, 1)
-	signal.Notify(api.running, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill)
-	<-api.running
+	signal.Notify(api.running, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
 
+	<-api.running
 	ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
 
 	if err := api.httpServer.Shutdown(ctx); err != nil {
@@ -97,7 +105,20 @@ func (api *Api) Start(wg *sync.WaitGroup) error {
 
 func (api *Api) Stop() {
 	if api.online {
-		go logrus.Infof("Shutting down API webserver")
+		go logrus.Info("Shutting down API webserver")
 		api.running <- syscall.SIGINT
 	}
 }
+
+func (api *Api) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	api.certMu.Lock()
+	defer api.certMu.Unlock()
+
+	cert, err := tls.LoadX509KeyPair(api.cert, api.key)
+	if err != nil {
+		go logrus.Error(err)
+		return nil, err
+	}
+
+	return &cert, nil
+}