Skip to content

Commit

Permalink
MQTT: adds keepalive pinging, disconnect, and graceful goroutine cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jmacwhyte authored and deadprogram committed Aug 10, 2022
1 parent 41b7f06 commit 5bd814c
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 31 deletions.
146 changes: 116 additions & 30 deletions net/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package mqtt
import (
"errors"
"strings"
"sync"
"time"

"github.com/eclipse/paho.mqtt.golang/packets"
Expand All @@ -19,20 +20,32 @@ import (
func NewClient(o *ClientOptions) Client {
c := &mqttclient{opts: o, adaptor: o.Adaptor}
c.msgRouter, c.stopRouter = newRouter()

c.inboundPacketChan = make(chan packets.ControlPacket, 10)
c.stopInbound = make(chan struct{})
c.incomingPubChan = make(chan *packets.PublishPacket, 10)
// this launches a goroutine, so only call once per client:
c.msgRouter.matchAndDispatch(c.incomingPubChan, c.opts.Order, c)
return c
}

type mqttclient struct {
adaptor net.Adapter
conn net.Conn
connected bool
opts *ClientOptions
mid uint16
inbound chan packets.ControlPacket
stop chan struct{}
msgRouter *router
stopRouter chan bool
incomingPubChan chan *packets.PublishPacket
adaptor net.Adapter
conn net.Conn
connected bool
opts *ClientOptions
mid uint16
inboundPacketChan chan packets.ControlPacket
stopInbound chan struct{}
msgRouter *router
stopRouter chan bool
incomingPubChan chan *packets.PublishPacket
// stats for keepalive
lastReceive time.Time
lastSend time.Time
// keep track of routines and signal a shutdown
workers sync.WaitGroup
shutdown bool
}

// AddRoute allows you to add a handler for messages on a specific topic
Expand All @@ -56,6 +69,9 @@ func (c *mqttclient) IsConnectionOpen() bool {

// Connect will create a connection to the message broker.
func (c *mqttclient) Connect() Token {
if c.IsConnected() {
return &mqtttoken{}
}
var err error

// make connection
Expand All @@ -77,10 +93,6 @@ func (c *mqttclient) Connect() Token {
}

c.mid = 1
c.inbound = make(chan packets.ControlPacket, 10)
c.stop = make(chan struct{})
c.incomingPubChan = make(chan *packets.PublishPacket, 10)
c.msgRouter.matchAndDispatch(c.incomingPubChan, c.opts.Order, c)

// send the MQTT connect message
connectPkt := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
Expand All @@ -98,7 +110,7 @@ func (c *mqttclient) Connect() Token {
connectPkt.ClientIdentifier = c.opts.ClientID
connectPkt.ProtocolVersion = byte(c.opts.ProtocolVersion)
connectPkt.ProtocolName = "MQTT"
connectPkt.Keepalive = 60
connectPkt.Keepalive = uint16(c.opts.KeepAlive)

connectPkt.WillFlag = c.opts.WillEnabled
connectPkt.WillTopic = c.opts.WillTopic
Expand All @@ -110,6 +122,7 @@ func (c *mqttclient) Connect() Token {
if err != nil {
return &mqtttoken{err: err}
}
c.lastSend = time.Now()

// TODO: handle timeout as ReadPacket blocks until it gets a packet.
// CONNECT response.
Expand All @@ -127,20 +140,36 @@ func (c *mqttclient) Connect() Token {
}
}

go readMessages(c)
go processInbound(c)
go readMessages(c)
go keepAlive(c)

return &mqtttoken{}
}

// Disconnect will end the connection with the server, but not before waiting
// the specified number of milliseconds to wait for existing work to be
// completed.
// completed. Blocks until disconnected.
func (c *mqttclient) Disconnect(quiesce uint) {
c.conn.Close()
c.shutdownRoutines()
// block until all done
for c.connected {
time.Sleep(time.Millisecond * 10)
}
return
}

// shutdownRoutines will disconnect and shut down all processes. If you want to trigger a
// disconnect internally, make sure you call this instead of Disconnect() to avoid deadlocks
func (c *mqttclient) shutdownRoutines() {
if c.shutdown {
return
}
c.shutdown = true
c.conn.Close()
c.stopInbound <- struct{}{}
}

// Publish will publish a message with the specified QoS and content
// to the specified topic.
// Returns a token to track delivery of the message to the broker
Expand All @@ -153,6 +182,7 @@ func (c *mqttclient) Publish(topic string, qos byte, retained bool, payload inte
pub.Qos = qos
pub.TopicName = topic
pub.Retain = retained

switch payload.(type) {
case string:
pub.Payload = []byte(payload.(string))
Expand All @@ -168,6 +198,8 @@ func (c *mqttclient) Publish(topic string, qos byte, retained bool, payload inte
if err != nil {
return &mqtttoken{err: err}
}
// update this for every control message that is sent successfully, for keepalive
c.lastSend = time.Now()

return &mqtttoken{}
}
Expand Down Expand Up @@ -195,6 +227,7 @@ func (c *mqttclient) Subscribe(topic string, qos byte, callback MessageHandler)
if err != nil {
return &mqtttoken{err: err}
}
c.lastSend = time.Now()

return &mqtttoken{}
}
Expand All @@ -220,12 +253,13 @@ func (c *mqttclient) OptionsReader() ClientOptionsReader {
}

func processInbound(c *mqttclient) {
PROCESS:
for {
select {
case msg := <-c.inbound:
case msg := <-c.inboundPacketChan:
switch m := msg.(type) {
case *packets.PingrespPacket:
// TODO: handle this
// println("pong")
case *packets.SubackPacket:
// TODO: handle this
case *packets.UnsubackPacket:
Expand All @@ -242,33 +276,85 @@ func processInbound(c *mqttclient) {
case *packets.PubcompPacket:
// TODO: handle this
}
case <-c.stop:
return
case <-c.stopInbound:
break PROCESS
}
}

// as this routine could be the last to finish (if a lot of messages are queued in the
// channel), it is the last to turn out the lights

c.workers.Wait()
c.connected = false
c.shutdown = false
}

// readMessages reads incoming messages off the wire.
// incoming messages are then send into inbound channel.
// incoming messages are then send into inbound buffered channel.
func readMessages(c *mqttclient) {
c.workers.Add(1)
defer c.workers.Done()

var err error
var cp packets.ControlPacket

PROCESS:
for {
for !c.shutdown {
if cp, err = c.ReadPacket(); err != nil {
break PROCESS
c.shutdownRoutines()
return
}
if cp != nil {
c.inbound <- cp
// TODO: Notify keepalive logic that we recently received a packet
c.inboundPacketChan <- cp
// notify keepalive logic that we recently received a packet
c.lastReceive = time.Now()
}

time.Sleep(100 * time.Millisecond)
}
}

// TODO: handle if we received an error on read.
// If disconnect is in progress, swallow error and return
// keepAlive is a goroutine to handle sending ping requests according to the MQTT spec. If the keepalive time has
// been reached with no messages being sent, we will send a ping request and check back to see if we've
// had any activity by the timeout. If not, disconnect.
func keepAlive(c *mqttclient) {
c.workers.Add(1)
defer c.workers.Done()

var err error
var ping *packets.PingreqPacket
var timeout, pingsent time.Time

for !c.shutdown {
// As long as we haven't reached the keepalive value...
if time.Since(c.lastSend) < time.Duration(c.opts.KeepAlive)*time.Second {
// ...sleep and check shutdown status again
time.Sleep(time.Millisecond * 100)
continue
}

// value has been reached, so send a ping request
ping = packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
if err = ping.Write(c.conn); err != nil {
// if connection is lost, report disconnect
c.shutdownRoutines()
return
}
// println("ping")

c.lastSend = time.Now()
pingsent = time.Now()
timeout = pingsent.Add(c.opts.PingTimeout)

// as long as we are still connected and haven't received anything after the ping...
for !c.shutdown && c.lastReceive.Before(pingsent) {
// if the timeout has passed, disconnect
if time.Now().After(timeout) {
c.shutdownRoutines()
return
}
time.Sleep(time.Millisecond * 100)
}
}
}

func (c *mqttclient) ackFunc(packet *packets.PublishPacket) func() {
Expand Down
19 changes: 18 additions & 1 deletion net/mqtt/paho.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ type ClientOptions struct {

// NewClientOptions returns a new ClientOptions struct.
func NewClientOptions() *ClientOptions {
return &ClientOptions{Adaptor: net.ActiveDevice, ProtocolVersion: 4}
return &ClientOptions{Adaptor: net.ActiveDevice, ProtocolVersion: 4, KeepAlive: 60, PingTimeout: time.Second * 10}
}

// AddBroker adds a broker URI to the list of brokers to be used. The format should be
Expand Down Expand Up @@ -257,6 +257,23 @@ func (o *ClientOptions) SetPassword(p string) *ClientOptions {
return o
}

// SetKeepAlive will set the amount of time (in seconds) that the client
// should wait before sending a PING request to the broker. This will
// allow the client to know that a connection has not been lost with the
// server.
func (o *ClientOptions) SetKeepAlive(k time.Duration) *ClientOptions {
o.KeepAlive = int64(k / time.Second)
return o
}

// SetPingTimeout will set the amount of time (in seconds) that the client
// will wait after sending a PING request to the broker, before deciding
// that the connection has been lost. Default is 10 seconds.
func (o *ClientOptions) SetPingTimeout(k time.Duration) *ClientOptions {
o.PingTimeout = k
return o
}

// SetWill accepts a string will message to be set. When the client connects,
// it will give this will message to the broker, which will then publish the
// provided payload (the will) to any clients that are subscribed to the provided
Expand Down

0 comments on commit 5bd814c

Please sign in to comment.