diff --git a/regression-tests.dnsdist/test_Dnstap.py b/regression-tests.dnsdist/test_Dnstap.py index c9457afbe6d8..0d902f5ef46c 100644 --- a/regression-tests.dnsdist/test_Dnstap.py +++ b/regression-tests.dnsdist/test_Dnstap.py @@ -77,6 +77,29 @@ def checkDnstapResponse(testinstance, dnstap, protocol, response, initiator='127 testinstance.assertEqual(wire_message, response) +def getFirstMatchingMessageFromQueue(queue, messageType=None): + unused_messages = [] + selected = None + while True: + data = queue.get(True, timeout=2.0) + if not data: + break + decoded_message = dnstap_pb2.Dnstap() + decoded_message.ParseFromString(data) + if not selected and (not messageType or decoded_message.message.type == messageType): + selected = decoded_message + else: + unused_messages.append(data) + + if queue.empty(): + break + + # put back non-matching messages for later + for msg in reversed(unused_messages): + queue.put(msg) + + return selected + class TestDnstapOverRemoteLogger(DNSDistTest): _remoteLoggerServerPort = pickAvailablePort() _remoteLoggerQueue = Queue() @@ -155,12 +178,7 @@ def startResponders(cls): cls._remoteLoggerListener.start() def getFirstDnstap(self): - self.assertFalse(self._remoteLoggerQueue.empty()) - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue) def testDnstap(self): """ @@ -380,13 +398,8 @@ def startResponders(cls): cls._remoteLoggerListener.daemon = True cls._remoteLoggerListener.start() - def getFirstDnstap(self): - self.assertFalse(self._remoteLoggerQueue.empty()) - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + def getFirstDnstap(self, messageType=None): + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue, messageType=messageType) def testDnstap(self): """ @@ -423,13 +436,13 @@ def testDnstap(self): time.sleep(1) # check the dnstap message corresponding to the UDP query - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_QUERY) checkDnstapQuery(self, dnstap, dnstap_pb2.UDP, query) checkDnstapNoExtra(self, dnstap) # check the dnstap message corresponding to the UDP response - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_RESPONSE) checkDnstapResponse(self, dnstap, dnstap_pb2.UDP, response) checkDnstapNoExtra(self, dnstap) @@ -444,13 +457,13 @@ def testDnstap(self): time.sleep(1) # check the dnstap message corresponding to the TCP query - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_QUERY) checkDnstapQuery(self, dnstap, dnstap_pb2.TCP, query) checkDnstapNoExtra(self, dnstap) # check the dnstap message corresponding to the TCP response - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_RESPONSE) checkDnstapResponse(self, dnstap, dnstap_pb2.TCP, response) checkDnstapNoExtra(self, dnstap) @@ -489,12 +502,12 @@ def testDnstapExtra(self): time.sleep(1) # check the dnstap message corresponding to the UDP query - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_QUERY) checkDnstapQuery(self, dnstap, dnstap_pb2.UDP, query) checkDnstapExtra(self, dnstap, b"Type,Query") # check the dnstap message corresponding to the UDP response - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_RESPONSE) checkDnstapResponse(self, dnstap, dnstap_pb2.UDP, response) checkDnstapExtra(self, dnstap, b"Type,Response") @@ -509,12 +522,12 @@ def testDnstapExtra(self): time.sleep(1) # check the dnstap message corresponding to the TCP query - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_QUERY) checkDnstapQuery(self, dnstap, dnstap_pb2.TCP, query) checkDnstapExtra(self, dnstap, b"Type,Query") # check the dnstap message corresponding to the TCP response - dnstap = self.getFirstDnstap() + dnstap = self.getFirstDnstap(dnstap_pb2.Message.CLIENT_RESPONSE) checkDnstapResponse(self, dnstap, dnstap_pb2.TCP, response) checkDnstapExtra(self, dnstap, b"Type,Response") @@ -617,11 +630,7 @@ def startResponders(cls): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -713,11 +722,7 @@ def startResponders(cls): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -796,11 +801,7 @@ def startResponders(cls): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """ @@ -888,11 +889,7 @@ def startResponders(cls): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """