Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions lib/srpc/setupclient/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,54 @@ package setupclient

import (
"crypto/tls"
"crypto/x509"
"os"
"time"

"github.com/Cloud-Foundations/Dominator/lib/format"
"github.com/Cloud-Foundations/Dominator/lib/log"
"github.com/Cloud-Foundations/Dominator/lib/log/nulllogger"
"github.com/Cloud-Foundations/Dominator/lib/srpc"
)

const dateTime = time.DateTime + " MST"

func getEarliestCertExpiration(certs []tls.Certificate) *x509.Certificate {
var earliest *x509.Certificate
for _, cert := range certs {
if cert.Leaf != nil && !cert.Leaf.NotAfter.IsZero() {
if earliest == nil {
earliest = cert.Leaf
} else if cert.Leaf.NotAfter.Before(earliest.NotAfter) {
earliest = cert.Leaf
}
}
}
return earliest
}

func getSleepInterval(cert *x509.Certificate) time.Duration {
day := 24 * time.Hour
week := 7 * day
lifetime := cert.NotAfter.Sub(cert.NotBefore)
refreshIn := time.Until(cert.NotBefore.Add(7 * lifetime >> 3))
if refreshIn > 0 {
return refreshIn
}
expiresIn := time.Until(cert.NotAfter)
if expiresIn > 2*week {
return week
} else if expiresIn > 2*day {
return day
} else if expiresIn > 2*time.Hour {
return time.Hour
} else if expiresIn > 2*time.Minute {
return time.Minute
} else {
return 5 * time.Second
}
}

func loadCerts() ([]tls.Certificate, error) {
if *certDirectory == "" {
cert, err := srpc.LoadCertificatesFromMetadata(100*time.Millisecond,
Expand Down Expand Up @@ -41,22 +83,75 @@ func loadCerts() ([]tls.Certificate, error) {
return []tls.Certificate{*cert}, nil
}

func loadLoop(logger log.DebugLogger, cert *x509.Certificate) {
for {
sleepInterval := getSleepInterval(cert)
logger.Printf("Certificate refetch at: %s (%s)\n",
time.Now().Add(sleepInterval).Format(dateTime),
format.Duration(sleepInterval))
time.Sleep(sleepInterval)
if c, err := setupTlsOnce(logger, true); err != nil {
logger.Println(err)
} else {
cert = c
}
}
}

func setupTls(ignoreMissingCerts bool) error {
certs, err := loadCerts()
logger := nulllogger.New()
cert, err := setupTlsOnce(logger, !ignoreMissingCerts)
if err != nil {
return err
}
if cert != nil {
go loadLoop(logger, cert)
}
return nil
}

func setupTlsOnce(logger log.DebugLogger, failIfExpired bool) (
*x509.Certificate, error) {
certs, err := loadCerts()
if err != nil {
return nil, err
}
if certs == nil {
if ignoreMissingCerts {
return nil
if failIfExpired {
return nil, srpc.ErrorMissingCertificate
}
return nil, nil
}
// Check expiration of earliest certificate.
earliest := getEarliestCertExpiration(certs)
if earliest != nil {
now := time.Now()
if notYet := earliest.NotBefore.Sub(now); notYet > 0 {
msg := "certificate will not be valid for " +
format.Duration(notYet)
if failIfExpired {
logger.Println(msg)
return nil, srpc.ErrorMissingCertificate
}
logger.Println(msg)
} else if expired := now.Sub(earliest.NotAfter); expired > 0 {
msg := "certificate expired " + format.Duration(expired) + " ago"
if failIfExpired {
logger.Println(msg)
return nil, srpc.ErrorMissingCertificate
}
logger.Println(msg)
} else {
logger.Printf("Certificate expires at: %s (%s)\n",
earliest.NotAfter.Local().Format(dateTime),
format.Duration(time.Until(earliest.NotAfter)))
}
return srpc.ErrorMissingCertificate
}
// Setup client.
clientConfig := new(tls.Config)
clientConfig.InsecureSkipVerify = true
clientConfig.MinVersion = tls.VersionTLS12
clientConfig.Certificates = certs
srpc.RegisterClientTlsConfig(clientConfig)
return nil
return earliest, nil
}