From 222d3c1329666ff7483b012996211f5b08f4bb75 Mon Sep 17 00:00:00 2001 From: Matt Brittan Date: Thu, 29 Apr 2021 17:31:24 +1200 Subject: [PATCH] Handle connection loss during call to Disconnect() (including tests). Also reduce noise from tests. Ref issue #501 --- client.go | 18 ++++++++++++++---- fvt_client_test.go | 42 ++++++++++++++++++++++++++++++++++++++++- unit_client_test.go | 10 +++++----- unit_messageids_test.go | 3 +-- 4 files changed, 61 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 53097471..60d8e7f3 100644 --- a/client.go +++ b/client.go @@ -439,12 +439,22 @@ func (c *client) Disconnect(quiesce uint) { dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket) dt := newToken(packets.Disconnect) - c.oboundP <- &PacketAndToken{p: dm, t: dt} + disconnectSent := false + select { + case c.oboundP <- &PacketAndToken{p: dm, t: dt}: + disconnectSent = true + case <-c.commsStopped: + WARN.Println("Disconnect packet could not be sent because comms stopped") + case <-time.After(time.Duration(quiesce) * time.Millisecond): + WARN.Println("Disconnect packet not sent due to timeout") + } // wait for work to finish, or quiesce time consumed - DEBUG.Println(CLI, "calling WaitTimeout") - dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond) - DEBUG.Println(CLI, "WaitTimeout done") + if disconnectSent { + DEBUG.Println(CLI, "calling WaitTimeout") + dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond) + DEBUG.Println(CLI, "WaitTimeout done") + } } else { WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)") c.setConnected(disconnected) diff --git a/fvt_client_test.go b/fvt_client_test.go index 31d71fed..4e3793c1 100644 --- a/fvt_client_test.go +++ b/fvt_client_test.go @@ -31,7 +31,19 @@ func Test_Start(t *testing.T) { t.Fatalf("Error on Client.Connect(): %v", token.Error()) } - c.Disconnect(250) + // Disconnect should return within 250ms and calling a second time should not block + disconnectC := make(chan struct{}, 1) + go func() { + c.Disconnect(250) + c.Disconnect(5) + close(disconnectC) + }() + + select { + case <-time.After(time.Millisecond * 300): + t.Errorf("disconnect did not finnish within 300ms") + case <-disconnectC: + } } /* uncomment this if you have connection policy disallowing FailClientID @@ -90,6 +102,34 @@ func Test_Start(t *testing.T) { } */ +// Disconnect should not block under any circumstance +// This is triggered by issue #501; there is a very slight chance that Disconnect could get through the +// `status == connected` check and then the connection drops... +func Test_Disconnect(t *testing.T) { + ops := NewClientOptions().SetClientID("Disconnect").AddBroker(FVTTCP) + c := NewClient(ops) + + if token := c.Connect(); token.Wait() && token.Error() != nil { + t.Fatalf("Error on Client.Connect(): %v", token.Error()) + } + + // Attempt to disconnect twice simultaneously and ensure this does not block + disconnectC := make(chan struct{}, 1) + go func() { + c.Disconnect(250) + cli := c.(*client) + cli.status = connected + c.Disconnect(250) + close(disconnectC) + }() + + select { + case <-time.After(time.Millisecond * 300): + t.Errorf("disconnect did not finnish within 300ms") + case <-disconnectC: + } +} + func Test_Publish_1(t *testing.T) { ops := NewClientOptions() ops.AddBroker(FVTTCP) diff --git a/unit_client_test.go b/unit_client_test.go index 778026c6..939fd2cf 100644 --- a/unit_client_test.go +++ b/unit_client_test.go @@ -18,15 +18,15 @@ import ( "log" "net/http" _ "net/http/pprof" - "os" "testing" ) func init() { - DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime) - WARN = log.New(os.Stderr, "WARNING ", log.Ltime) - CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime) - ERROR = log.New(os.Stderr, "ERROR ", log.Ltime) + // Logging is off by default as this makes things simpler when you just want to confirm that tests pass + // DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime) + // WARN = log.New(os.Stderr, "WARNING ", log.Ltime) + // CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime) + // ERROR = log.New(os.Stderr, "ERROR ", log.Ltime) go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) diff --git a/unit_messageids_test.go b/unit_messageids_test.go index a6f9709c..37229d3a 100644 --- a/unit_messageids_test.go +++ b/unit_messageids_test.go @@ -16,7 +16,6 @@ package mqtt import ( "fmt" - "log" "testing" ) @@ -63,7 +62,7 @@ func Test_noFreeID(t *testing.T) { mids := &messageIds{index: make(map[uint16]tokenCompletor)} for i := midMin; i != 0; i++ { - log.Println(i) + // Uncomment to see all message IDS log.Println(i) mids.index[i] = &d }